GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_msum.cpp Lines: 129 144 89.6 %
Date: 2021-11-07 Branches: 322 640 50.3 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Arithmetic utilities regarding monomial sums.
14
 */
15
16
#include "theory/arith/arith_msum.h"
17
18
#include "theory/rewriter.h"
19
#include "util/rational.h"
20
21
using namespace cvc5::kind;
22
23
namespace cvc5 {
24
namespace theory {
25
26
535
bool ArithMSum::getMonomial(Node n, Node& c, Node& v)
27
{
28
535
  if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
29
  {
30
535
    c = n[0];
31
535
    v = n[1];
32
535
    return true;
33
  }
34
  return false;
35
}
36
37
542388
bool ArithMSum::getMonomial(Node n, std::map<Node, Node>& msum)
38
{
39
542388
  if (n.isConst())
40
  {
41
143218
    if (msum.find(Node::null()) == msum.end())
42
    {
43
143218
      msum[Node::null()] = n;
44
143218
      return true;
45
    }
46
  }
47
399170
  else if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
48
  {
49
120446
    if (msum.find(n[1]) == msum.end())
50
    {
51
120446
      msum[n[1]] = n[0];
52
120446
      return true;
53
    }
54
  }
55
  else
56
  {
57
278724
    if (msum.find(n) == msum.end())
58
    {
59
278724
      msum[n] = Node::null();
60
278724
      return true;
61
    }
62
  }
63
  return false;
64
}
65
66
392330
bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
67
{
68
392330
  if (n.getKind() == PLUS)
69
  {
70
371014
    for (Node nc : n)
71
    {
72
260536
      if (!getMonomial(nc, msum))
73
      {
74
        return false;
75
      }
76
    }
77
110478
    return true;
78
  }
79
281852
  return getMonomial(n, msum);
80
}
81
82
198904
bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
83
{
84
198904
  if (lit.getKind() == GEQ || lit.getKind() == EQUAL)
85
  {
86
171452
    if (getMonomialSum(lit[0], msum))
87
    {
88
171452
      if (lit[1].isConst() && lit[1].getConst<Rational>().isZero())
89
      {
90
38627
        return true;
91
      }
92
      else
93
      {
94
        // subtract the other side
95
132825
        std::map<Node, Node> msum2;
96
132825
        NodeManager* nm = NodeManager::currentNM();
97
132825
        if (getMonomialSum(lit[1], msum2))
98
        {
99
277390
          for (std::map<Node, Node>::iterator it = msum2.begin();
100
277390
               it != msum2.end();
101
               ++it)
102
          {
103
144565
            std::map<Node, Node>::iterator it2 = msum.find(it->first);
104
144565
            if (it2 != msum.end())
105
            {
106
              Node r = nm->mkNode(
107
                  MINUS,
108
20
                  it2->second.isNull() ? nm->mkConst(Rational(1)) : it2->second,
109
40
                  it->second.isNull() ? nm->mkConst(Rational(1)) : it->second);
110
10
              msum[it->first] = Rewriter::rewrite(r);
111
            }
112
            else
113
            {
114
228551
              msum[it->first] = it->second.isNull() ? nm->mkConst(Rational(-1))
115
83996
                                                    : negate(it->second);
116
            }
117
          }
118
132825
          return true;
119
        }
120
      }
121
    }
122
  }
123
27452
  return false;
124
}
125
126
347
Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
127
{
128
347
  NodeManager* nm = NodeManager::currentNM();
129
694
  std::vector<Node> children;
130
694
  for (std::map<Node, Node>::const_iterator it = msum.begin(); it != msum.end();
131
       ++it)
132
  {
133
694
    Node m;
134
347
    if (!it->first.isNull())
135
    {
136
246
      m = mkCoeffTerm(it->second, it->first);
137
    }
138
    else
139
    {
140
101
      Assert(!it->second.isNull());
141
101
      m = it->second;
142
    }
143
347
    children.push_back(m);
144
  }
145
347
  return children.size() > 1
146
             ? nm->mkNode(PLUS, children)
147
694
             : (children.size() == 1 ? children[0] : nm->mkConst(Rational(0)));
148
}
149
150
108678
int ArithMSum::isolate(
151
    Node v, const std::map<Node, Node>& msum, Node& veq_c, Node& val, Kind k)
152
{
153
108678
  Assert(veq_c.isNull());
154
108678
  std::map<Node, Node>::const_iterator itv = msum.find(v);
155
108678
  if (itv != msum.end())
156
  {
157
107285
    std::vector<Node> children;
158
    Rational r =
159
107285
        itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
160
107285
    if (r.sgn() != 0)
161
    {
162
361851
      for (std::map<Node, Node>::const_iterator it = msum.begin();
163
361851
           it != msum.end();
164
           ++it)
165
      {
166
254566
        if (it->first != v)
167
        {
168
294562
          Node m;
169
147281
          if (!it->first.isNull())
170
          {
171
101600
            m = mkCoeffTerm(it->second, it->first);
172
          }
173
          else
174
          {
175
45681
            m = it->second;
176
          }
177
147281
          children.push_back(m);
178
        }
179
      }
180
214570
      val = children.size() > 1
181
348730
                ? NodeManager::currentNM()->mkNode(PLUS, children)
182
67080
                : (children.size() == 1
183
57044
                       ? children[0]
184
117321
                       : NodeManager::currentNM()->mkConst(Rational(0)));
185
107285
      if (!r.isOne() && !r.isNegativeOne())
186
      {
187
5220
        if (v.getType().isInteger())
188
        {
189
2319
          veq_c = NodeManager::currentNM()->mkConst(r.abs());
190
        }
191
        else
192
        {
193
5802
          val = NodeManager::currentNM()->mkNode(
194
              MULT,
195
              val,
196
5802
              NodeManager::currentNM()->mkConst(Rational(1) / r.abs()));
197
        }
198
      }
199
107285
      val = r.sgn() == 1 ? negate(val) : Rewriter::rewrite(val);
200
107285
      return (r.sgn() == 1 || k == EQUAL) ? 1 : -1;
201
    }
202
  }
203
1393
  return 0;
204
}
205
206
14580
int ArithMSum::isolate(
207
    Node v, const std::map<Node, Node>& msum, Node& veq, Kind k, bool doCoeff)
208
{
209
29160
  Node veq_c;
210
29160
  Node val;
211
  // isolate v in the (in)equality
212
14580
  int ires = isolate(v, msum, veq_c, val, k);
213
14580
  if (ires != 0)
214
  {
215
29129
    Node vc = v;
216
14572
    if (!veq_c.isNull())
217
    {
218
102
      if (doCoeff)
219
      {
220
87
        vc = NodeManager::currentNM()->mkNode(MULT, veq_c, vc);
221
      }
222
      else
223
      {
224
15
        return 0;
225
      }
226
    }
227
14557
    bool inOrder = ires == 1;
228
14557
    veq = NodeManager::currentNM()->mkNode(
229
        k, inOrder ? vc : val, inOrder ? val : vc);
230
  }
231
14565
  return ires;
232
}
233
234
140
Node ArithMSum::solveEqualityFor(Node lit, Node v)
235
{
236
140
  Assert(lit.getKind() == EQUAL);
237
  // first look directly at sides
238
280
  TypeNode tn = lit[0].getType();
239
160
  for (unsigned r = 0; r < 2; r++)
240
  {
241
150
    if (lit[r] == v)
242
    {
243
130
      return lit[1 - r];
244
    }
245
  }
246
10
  if (tn.isReal())
247
  {
248
10
    std::map<Node, Node> msum;
249
10
    if (ArithMSum::getMonomialSumLit(lit, msum))
250
    {
251
10
      Node val, veqc;
252
10
      if (ArithMSum::isolate(v, msum, veqc, val, EQUAL) != 0)
253
      {
254
10
        if (veqc.isNull())
255
        {
256
          // in this case, we have an integer equality with a coefficient
257
          // on the variable we solved for that could not be eliminated,
258
          // hence we fail.
259
10
          return val;
260
        }
261
      }
262
    }
263
  }
264
  return Node::null();
265
}
266
267
bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
268
{
269
  std::map<Node, Node> msum;
270
  if (getMonomialSum(n, msum))
271
  {
272
    std::map<Node, Node>::iterator it = msum.find(v);
273
    if (it == msum.end())
274
    {
275
      return false;
276
    }
277
    else
278
    {
279
      coeff = it->second;
280
      msum.erase(v);
281
      rem = mkNode(msum);
282
      return true;
283
    }
284
  }
285
  return false;
286
}
287
288
156293
Node ArithMSum::negate(Node t)
289
{
290
  Node tt = NodeManager::currentNM()->mkNode(
291
156293
      MULT, NodeManager::currentNM()->mkConst(Rational(-1)), t);
292
156293
  tt = Rewriter::rewrite(tt);
293
156293
  return tt;
294
}
295
296
1020
Node ArithMSum::offset(Node t, int i)
297
{
298
  Node tt = NodeManager::currentNM()->mkNode(
299
1020
      PLUS, NodeManager::currentNM()->mkConst(Rational(i)), t);
300
1020
  tt = Rewriter::rewrite(tt);
301
1020
  return tt;
302
}
303
304
3522
void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
305
{
306
10892
  for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
307
  {
308
7370
    Trace(c) << "  ";
309
7370
    if (!it->second.isNull())
310
    {
311
3298
      Trace(c) << it->second;
312
3298
      if (!it->first.isNull())
313
      {
314
1792
        Trace(c) << " * ";
315
      }
316
    }
317
7370
    if (!it->first.isNull())
318
    {
319
5864
      Trace(c) << it->first;
320
    }
321
7370
    Trace(c) << std::endl;
322
  }
323
3522
  Trace(c) << std::endl;
324
3522
}
325
326
}  // namespace theory
327
31137
}  // namespace cvc5