GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_msum.cpp Lines: 130 144 90.3 %
Date: 2021-09-30 Branches: 338 640 52.8 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Arithmetic utilities regarding monomial sums.
14
 */
15
16
#include "theory/arith/arith_msum.h"
17
18
#include "theory/rewriter.h"
19
#include "util/rational.h"
20
21
using namespace cvc5::kind;
22
23
namespace cvc5 {
24
namespace theory {
25
26
325
bool ArithMSum::getMonomial(Node n, Node& c, Node& v)
27
{
28
325
  if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
29
  {
30
325
    c = n[0];
31
325
    v = n[1];
32
325
    return true;
33
  }
34
  return false;
35
}
36
37
396489
bool ArithMSum::getMonomial(Node n, std::map<Node, Node>& msum)
38
{
39
396489
  if (n.isConst())
40
  {
41
116734
    if (msum.find(Node::null()) == msum.end())
42
    {
43
116734
      msum[Node::null()] = n;
44
116734
      return true;
45
    }
46
  }
47
279755
  else if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
48
  {
49
113113
    if (msum.find(n[1]) == msum.end())
50
    {
51
113113
      msum[n[1]] = n[0];
52
113113
      return true;
53
    }
54
  }
55
  else
56
  {
57
166642
    if (msum.find(n) == msum.end())
58
    {
59
166642
      msum[n] = Node::null();
60
166642
      return true;
61
    }
62
  }
63
  return false;
64
}
65
66
249889
bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
67
{
68
249889
  if (n.getKind() == PLUS)
69
  {
70
351086
    for (Node nc : n)
71
    {
72
248843
      if (!getMonomial(nc, msum))
73
      {
74
        return false;
75
      }
76
    }
77
102243
    return true;
78
  }
79
147646
  return getMonomial(n, msum);
80
}
81
82
101144
bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
83
{
84
101144
  if (lit.getKind() == GEQ || lit.getKind() == EQUAL)
85
  {
86
83002
    if (getMonomialSum(lit[0], msum))
87
    {
88
83002
      if (lit[1].isConst() && lit[1].getConst<Rational>().isZero())
89
      {
90
15446
        return true;
91
      }
92
      else
93
      {
94
        // subtract the other side
95
67556
        std::map<Node, Node> msum2;
96
67556
        NodeManager* nm = NodeManager::currentNM();
97
67556
        if (getMonomialSum(lit[1], msum2))
98
        {
99
157650
          for (std::map<Node, Node>::iterator it = msum2.begin();
100
157650
               it != msum2.end();
101
               ++it)
102
          {
103
90094
            std::map<Node, Node>::iterator it2 = msum.find(it->first);
104
90094
            if (it2 != msum.end())
105
            {
106
              Node r = nm->mkNode(
107
                  MINUS,
108
30
                  it2->second.isNull() ? nm->mkConst(Rational(1)) : it2->second,
109
60
                  it->second.isNull() ? nm->mkConst(Rational(1)) : it->second);
110
15
              msum[it->first] = Rewriter::rewrite(r);
111
            }
112
            else
113
            {
114
142418
              msum[it->first] = it->second.isNull() ? nm->mkConst(Rational(-1))
115
52339
                                                    : negate(it->second);
116
            }
117
          }
118
67556
          return true;
119
        }
120
      }
121
    }
122
  }
123
18142
  return false;
124
}
125
126
102
Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
127
{
128
102
  NodeManager* nm = NodeManager::currentNM();
129
204
  std::vector<Node> children;
130
204
  for (std::map<Node, Node>::const_iterator it = msum.begin(); it != msum.end();
131
       ++it)
132
  {
133
204
    Node m;
134
102
    if (!it->first.isNull())
135
    {
136
68
      m = mkCoeffTerm(it->second, it->first);
137
    }
138
    else
139
    {
140
34
      Assert(!it->second.isNull());
141
34
      m = it->second;
142
    }
143
102
    children.push_back(m);
144
  }
145
102
  return children.size() > 1
146
             ? nm->mkNode(PLUS, children)
147
204
             : (children.size() == 1 ? children[0] : nm->mkConst(Rational(0)));
148
}
149
150
59725
int ArithMSum::isolate(
151
    Node v, const std::map<Node, Node>& msum, Node& veq_c, Node& val, Kind k)
152
{
153
59725
  Assert(veq_c.isNull());
154
59725
  std::map<Node, Node>::const_iterator itv = msum.find(v);
155
59725
  if (itv != msum.end())
156
  {
157
57995
    std::vector<Node> children;
158
    Rational r =
159
57995
        itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
160
57987
    if (r.sgn() != 0)
161
    {
162
208858
      for (std::map<Node, Node>::const_iterator it = msum.begin();
163
208858
           it != msum.end();
164
           ++it)
165
      {
166
150879
        if (it->first != v)
167
        {
168
185800
          Node m;
169
92900
          if (!it->first.isNull())
170
          {
171
63993
            m = mkCoeffTerm(it->second, it->first);
172
          }
173
          else
174
          {
175
28907
            m = it->second;
176
          }
177
92900
          children.push_back(m);
178
        }
179
      }
180
115958
      val = children.size() > 1
181
178032
                ? NodeManager::currentNM()->mkNode(PLUS, children)
182
31037
                : (children.size() == 1
183
26299
                       ? children[0]
184
62717
                       : NodeManager::currentNM()->mkConst(Rational(0)));
185
57979
      if (!r.isOne() && !r.isNegativeOne())
186
      {
187
1583
        if (v.getType().isInteger())
188
        {
189
924
          veq_c = NodeManager::currentNM()->mkConst(r.abs());
190
        }
191
        else
192
        {
193
1318
          val = NodeManager::currentNM()->mkNode(
194
              MULT,
195
              val,
196
1318
              NodeManager::currentNM()->mkConst(Rational(1) / r.abs()));
197
        }
198
      }
199
57979
      val = r.sgn() == 1 ? negate(val) : Rewriter::rewrite(val);
200
57979
      return (r.sgn() == 1 || k == EQUAL) ? 1 : -1;
201
    }
202
  }
203
1746
  return 0;
204
}
205
206
9157
int ArithMSum::isolate(
207
    Node v, const std::map<Node, Node>& msum, Node& veq, Kind k, bool doCoeff)
208
{
209
18314
  Node veq_c;
210
18314
  Node val;
211
  // isolate v in the (in)equality
212
9157
  int ires = isolate(v, msum, veq_c, val, k);
213
9157
  if (ires != 0)
214
  {
215
18306
    Node vc = v;
216
9155
    if (!veq_c.isNull())
217
    {
218
48
      if (doCoeff)
219
      {
220
44
        vc = NodeManager::currentNM()->mkNode(MULT, veq_c, vc);
221
      }
222
      else
223
      {
224
4
        return 0;
225
      }
226
    }
227
9151
    bool inOrder = ires == 1;
228
9151
    veq = NodeManager::currentNM()->mkNode(
229
        k, inOrder ? vc : val, inOrder ? val : vc);
230
  }
231
9153
  return ires;
232
}
233
234
22824
Node ArithMSum::solveEqualityFor(Node lit, Node v)
235
{
236
22824
  Assert(lit.getKind() == EQUAL);
237
  // first look directly at sides
238
45648
  TypeNode tn = lit[0].getType();
239
36804
  for (unsigned r = 0; r < 2; r++)
240
  {
241
29814
    if (lit[r] == v)
242
    {
243
15834
      return lit[1 - r];
244
    }
245
  }
246
6990
  if (tn.isReal())
247
  {
248
7028
    std::map<Node, Node> msum;
249
6990
    if (ArithMSum::getMonomialSumLit(lit, msum))
250
    {
251
7028
      Node val, veqc;
252
6990
      if (ArithMSum::isolate(v, msum, veqc, val, EQUAL) != 0)
253
      {
254
6990
        if (veqc.isNull())
255
        {
256
          // in this case, we have an integer equality with a coefficient
257
          // on the variable we solved for that could not be eliminated,
258
          // hence we fail.
259
6952
          return val;
260
        }
261
      }
262
    }
263
  }
264
38
  return Node::null();
265
}
266
267
bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
268
{
269
  std::map<Node, Node> msum;
270
  if (getMonomialSum(n, msum))
271
  {
272
    std::map<Node, Node>::iterator it = msum.find(v);
273
    if (it == msum.end())
274
    {
275
      return false;
276
    }
277
    else
278
    {
279
      coeff = it->second;
280
      msum.erase(v);
281
      rem = mkNode(msum);
282
      return true;
283
    }
284
  }
285
  return false;
286
}
287
288
88675
Node ArithMSum::negate(Node t)
289
{
290
  Node tt = NodeManager::currentNM()->mkNode(
291
88675
      MULT, NodeManager::currentNM()->mkConst(Rational(-1)), t);
292
88675
  tt = Rewriter::rewrite(tt);
293
88675
  return tt;
294
}
295
296
586
Node ArithMSum::offset(Node t, int i)
297
{
298
  Node tt = NodeManager::currentNM()->mkNode(
299
586
      PLUS, NodeManager::currentNM()->mkConst(Rational(i)), t);
300
586
  tt = Rewriter::rewrite(tt);
301
586
  return tt;
302
}
303
304
2277
void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
305
{
306
7047
  for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
307
  {
308
4770
    Trace(c) << "  ";
309
4770
    if (!it->second.isNull())
310
    {
311
2137
      Trace(c) << it->second;
312
2137
      if (!it->first.isNull())
313
      {
314
1190
        Trace(c) << " * ";
315
      }
316
    }
317
4770
    if (!it->first.isNull())
318
    {
319
3823
      Trace(c) << it->first;
320
    }
321
4770
    Trace(c) << std::endl;
322
  }
323
2277
  Trace(c) << std::endl;
324
2277
}
325
326
}  // namespace theory
327
22755
}  // namespace cvc5