GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_msum.cpp Lines: 130 144 90.3 %
Date: 2021-09-29 Branches: 338 640 52.8 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Arithmetic utilities regarding monomial sums.
14
 */
15
16
#include "theory/arith/arith_msum.h"
17
18
#include "theory/rewriter.h"
19
#include "util/rational.h"
20
21
using namespace cvc5::kind;
22
23
namespace cvc5 {
24
namespace theory {
25
26
325
bool ArithMSum::getMonomial(Node n, Node& c, Node& v)
27
{
28
325
  if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
29
  {
30
325
    c = n[0];
31
325
    v = n[1];
32
325
    return true;
33
  }
34
  return false;
35
}
36
37
388046
bool ArithMSum::getMonomial(Node n, std::map<Node, Node>& msum)
38
{
39
388046
  if (n.isConst())
40
  {
41
113653
    if (msum.find(Node::null()) == msum.end())
42
    {
43
113653
      msum[Node::null()] = n;
44
113653
      return true;
45
    }
46
  }
47
274393
  else if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
48
  {
49
109651
    if (msum.find(n[1]) == msum.end())
50
    {
51
109651
      msum[n[1]] = n[0];
52
109651
      return true;
53
    }
54
  }
55
  else
56
  {
57
164742
    if (msum.find(n) == msum.end())
58
    {
59
164742
      msum[n] = Node::null();
60
164742
      return true;
61
    }
62
  }
63
  return false;
64
}
65
66
245557
bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
67
{
68
245557
  if (n.getKind() == PLUS)
69
  {
70
340811
    for (Node nc : n)
71
    {
72
241650
      if (!getMonomial(nc, msum))
73
      {
74
        return false;
75
      }
76
    }
77
99161
    return true;
78
  }
79
146396
  return getMonomial(n, msum);
80
}
81
82
100620
bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
83
{
84
100620
  if (lit.getKind() == GEQ || lit.getKind() == EQUAL)
85
  {
86
82478
    if (getMonomialSum(lit[0], msum))
87
    {
88
82478
      if (lit[1].isConst() && lit[1].getConst<Rational>().isZero())
89
      {
90
15396
        return true;
91
      }
92
      else
93
      {
94
        // subtract the other side
95
67082
        std::map<Node, Node> msum2;
96
67082
        NodeManager* nm = NodeManager::currentNM();
97
67082
        if (getMonomialSum(lit[1], msum2))
98
        {
99
155854
          for (std::map<Node, Node>::iterator it = msum2.begin();
100
155854
               it != msum2.end();
101
               ++it)
102
          {
103
88772
            std::map<Node, Node>::iterator it2 = msum.find(it->first);
104
88772
            if (it2 != msum.end())
105
            {
106
              Node r = nm->mkNode(
107
                  MINUS,
108
30
                  it2->second.isNull() ? nm->mkConst(Rational(1)) : it2->second,
109
60
                  it->second.isNull() ? nm->mkConst(Rational(1)) : it->second);
110
15
              msum[it->first] = Rewriter::rewrite(r);
111
            }
112
            else
113
            {
114
140304
              msum[it->first] = it->second.isNull() ? nm->mkConst(Rational(-1))
115
51547
                                                    : negate(it->second);
116
            }
117
          }
118
67082
          return true;
119
        }
120
      }
121
    }
122
  }
123
18142
  return false;
124
}
125
126
102
Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
127
{
128
102
  NodeManager* nm = NodeManager::currentNM();
129
204
  std::vector<Node> children;
130
204
  for (std::map<Node, Node>::const_iterator it = msum.begin(); it != msum.end();
131
       ++it)
132
  {
133
204
    Node m;
134
102
    if (!it->first.isNull())
135
    {
136
68
      m = mkCoeffTerm(it->second, it->first);
137
    }
138
    else
139
    {
140
34
      Assert(!it->second.isNull());
141
34
      m = it->second;
142
    }
143
102
    children.push_back(m);
144
  }
145
102
  return children.size() > 1
146
             ? nm->mkNode(PLUS, children)
147
204
             : (children.size() == 1 ? children[0] : nm->mkConst(Rational(0)));
148
}
149
150
59237
int ArithMSum::isolate(
151
    Node v, const std::map<Node, Node>& msum, Node& veq_c, Node& val, Kind k)
152
{
153
59237
  Assert(veq_c.isNull());
154
59237
  std::map<Node, Node>::const_iterator itv = msum.find(v);
155
59237
  if (itv != msum.end())
156
  {
157
57507
    std::vector<Node> children;
158
    Rational r =
159
57507
        itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
160
57499
    if (r.sgn() != 0)
161
    {
162
206522
      for (std::map<Node, Node>::const_iterator it = msum.begin();
163
206522
           it != msum.end();
164
           ++it)
165
      {
166
149031
        if (it->first != v)
167
        {
168
183080
          Node m;
169
91540
          if (!it->first.isNull())
170
          {
171
63096
            m = mkCoeffTerm(it->second, it->first);
172
          }
173
          else
174
          {
175
28444
            m = it->second;
176
          }
177
91540
          children.push_back(m);
178
        }
179
      }
180
114982
      val = children.size() > 1
181
176956
                ? NodeManager::currentNM()->mkNode(PLUS, children)
182
30987
                : (children.size() == 1
183
26263
                       ? children[0]
184
62215
                       : NodeManager::currentNM()->mkConst(Rational(0)));
185
57491
      if (!r.isOne() && !r.isNegativeOne())
186
      {
187
1583
        if (v.getType().isInteger())
188
        {
189
924
          veq_c = NodeManager::currentNM()->mkConst(r.abs());
190
        }
191
        else
192
        {
193
1318
          val = NodeManager::currentNM()->mkNode(
194
              MULT,
195
              val,
196
1318
              NodeManager::currentNM()->mkConst(Rational(1) / r.abs()));
197
        }
198
      }
199
57491
      val = r.sgn() == 1 ? negate(val) : Rewriter::rewrite(val);
200
57491
      return (r.sgn() == 1 || k == EQUAL) ? 1 : -1;
201
    }
202
  }
203
1746
  return 0;
204
}
205
206
9111
int ArithMSum::isolate(
207
    Node v, const std::map<Node, Node>& msum, Node& veq, Kind k, bool doCoeff)
208
{
209
18222
  Node veq_c;
210
18222
  Node val;
211
  // isolate v in the (in)equality
212
9111
  int ires = isolate(v, msum, veq_c, val, k);
213
9111
  if (ires != 0)
214
  {
215
18214
    Node vc = v;
216
9109
    if (!veq_c.isNull())
217
    {
218
48
      if (doCoeff)
219
      {
220
44
        vc = NodeManager::currentNM()->mkNode(MULT, veq_c, vc);
221
      }
222
      else
223
      {
224
4
        return 0;
225
      }
226
    }
227
9105
    bool inOrder = ires == 1;
228
9105
    veq = NodeManager::currentNM()->mkNode(
229
        k, inOrder ? vc : val, inOrder ? val : vc);
230
  }
231
9107
  return ires;
232
}
233
234
21072
Node ArithMSum::solveEqualityFor(Node lit, Node v)
235
{
236
21072
  Assert(lit.getKind() == EQUAL);
237
  // first look directly at sides
238
42144
  TypeNode tn = lit[0].getType();
239
34220
  for (unsigned r = 0; r < 2; r++)
240
  {
241
27646
    if (lit[r] == v)
242
    {
243
14498
      return lit[1 - r];
244
    }
245
  }
246
6574
  if (tn.isReal())
247
  {
248
6612
    std::map<Node, Node> msum;
249
6574
    if (ArithMSum::getMonomialSumLit(lit, msum))
250
    {
251
6612
      Node val, veqc;
252
6574
      if (ArithMSum::isolate(v, msum, veqc, val, EQUAL) != 0)
253
      {
254
6574
        if (veqc.isNull())
255
        {
256
          // in this case, we have an integer equality with a coefficient
257
          // on the variable we solved for that could not be eliminated,
258
          // hence we fail.
259
6536
          return val;
260
        }
261
      }
262
    }
263
  }
264
38
  return Node::null();
265
}
266
267
bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
268
{
269
  std::map<Node, Node> msum;
270
  if (getMonomialSum(n, msum))
271
  {
272
    std::map<Node, Node>::iterator it = msum.find(v);
273
    if (it == msum.end())
274
    {
275
      return false;
276
    }
277
    else
278
    {
279
      coeff = it->second;
280
      msum.erase(v);
281
      rem = mkNode(msum);
282
      return true;
283
    }
284
  }
285
  return false;
286
}
287
288
87813
Node ArithMSum::negate(Node t)
289
{
290
  Node tt = NodeManager::currentNM()->mkNode(
291
87813
      MULT, NodeManager::currentNM()->mkConst(Rational(-1)), t);
292
87813
  tt = Rewriter::rewrite(tt);
293
87813
  return tt;
294
}
295
296
563
Node ArithMSum::offset(Node t, int i)
297
{
298
  Node tt = NodeManager::currentNM()->mkNode(
299
563
      PLUS, NodeManager::currentNM()->mkConst(Rational(i)), t);
300
563
  tt = Rewriter::rewrite(tt);
301
563
  return tt;
302
}
303
304
2193
void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
305
{
306
6769
  for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
307
  {
308
4576
    Trace(c) << "  ";
309
4576
    if (!it->second.isNull())
310
    {
311
2047
      Trace(c) << it->second;
312
2047
      if (!it->first.isNull())
313
      {
314
1142
        Trace(c) << " * ";
315
      }
316
    }
317
4576
    if (!it->first.isNull())
318
    {
319
3671
      Trace(c) << it->first;
320
    }
321
4576
    Trace(c) << std::endl;
322
  }
323
2193
  Trace(c) << std::endl;
324
2193
}
325
326
}  // namespace theory
327
22746
}  // namespace cvc5