GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_msum.cpp Lines: 130 144 90.3 %
Date: 2021-05-22 Branches: 327 642 50.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Arithmetic utilities regarding monomial sums.
14
 */
15
16
#include "theory/arith/arith_msum.h"
17
18
#include "theory/rewriter.h"
19
20
using namespace cvc5::kind;
21
22
namespace cvc5 {
23
namespace theory {
24
25
612
bool ArithMSum::getMonomial(Node n, Node& c, Node& v)
26
{
27
612
  if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
28
  {
29
612
    c = n[0];
30
612
    v = n[1];
31
612
    return true;
32
  }
33
  return false;
34
}
35
36
641101
bool ArithMSum::getMonomial(Node n, std::map<Node, Node>& msum)
37
{
38
641101
  if (n.isConst())
39
  {
40
179113
    if (msum.find(Node::null()) == msum.end())
41
    {
42
179113
      msum[Node::null()] = n;
43
179113
      return true;
44
    }
45
  }
46
461988
  else if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
47
  {
48
203565
    if (msum.find(n[1]) == msum.end())
49
    {
50
203565
      msum[n[1]] = n[0];
51
203565
      return true;
52
    }
53
  }
54
  else
55
  {
56
258423
    if (msum.find(n) == msum.end())
57
    {
58
258423
      msum[n] = Node::null();
59
258423
      return true;
60
    }
61
  }
62
  return false;
63
}
64
65
343952
bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
66
{
67
343952
  if (n.getKind() == PLUS)
68
  {
69
655925
    for (Node nc : n)
70
    {
71
476537
      if (!getMonomial(nc, msum))
72
      {
73
        return false;
74
      }
75
    }
76
179388
    return true;
77
  }
78
164564
  return getMonomial(n, msum);
79
}
80
81
104243
bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
82
{
83
104243
  if (lit.getKind() == GEQ || lit.getKind() == EQUAL)
84
  {
85
94910
    if (getMonomialSum(lit[0], msum))
86
    {
87
94910
      if (lit[1].isConst() && lit[1].getConst<Rational>().isZero())
88
      {
89
23119
        return true;
90
      }
91
      else
92
      {
93
        // subtract the other side
94
71791
        std::map<Node, Node> msum2;
95
71791
        NodeManager* nm = NodeManager::currentNM();
96
71791
        if (getMonomialSum(lit[1], msum2))
97
        {
98
161709
          for (std::map<Node, Node>::iterator it = msum2.begin();
99
161709
               it != msum2.end();
100
               ++it)
101
          {
102
89918
            std::map<Node, Node>::iterator it2 = msum.find(it->first);
103
89918
            if (it2 != msum.end())
104
            {
105
              Node r = nm->mkNode(
106
                  MINUS,
107
2
                  it2->second.isNull() ? nm->mkConst(Rational(1)) : it2->second,
108
4
                  it->second.isNull() ? nm->mkConst(Rational(1)) : it->second);
109
1
              msum[it->first] = Rewriter::rewrite(r);
110
            }
111
            else
112
            {
113
150024
              msum[it->first] = it->second.isNull() ? nm->mkConst(Rational(-1))
114
60107
                                                    : negate(it->second);
115
            }
116
          }
117
71791
          return true;
118
        }
119
      }
120
    }
121
  }
122
9333
  return false;
123
}
124
125
348
Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
126
{
127
348
  NodeManager* nm = NodeManager::currentNM();
128
696
  std::vector<Node> children;
129
696
  for (std::map<Node, Node>::const_iterator it = msum.begin(); it != msum.end();
130
       ++it)
131
  {
132
696
    Node m;
133
348
    if (!it->first.isNull())
134
    {
135
246
      m = mkCoeffTerm(it->second, it->first);
136
    }
137
    else
138
    {
139
102
      Assert(!it->second.isNull());
140
102
      m = it->second;
141
    }
142
348
    children.push_back(m);
143
  }
144
348
  return children.size() > 1
145
             ? nm->mkNode(PLUS, children)
146
696
             : (children.size() == 1 ? children[0] : nm->mkConst(Rational(0)));
147
}
148
149
68045
int ArithMSum::isolate(
150
    Node v, const std::map<Node, Node>& msum, Node& veq_c, Node& val, Kind k)
151
{
152
68045
  Assert(veq_c.isNull());
153
68045
  std::map<Node, Node>::const_iterator itv = msum.find(v);
154
68045
  if (itv != msum.end())
155
  {
156
66725
    std::vector<Node> children;
157
    Rational r =
158
66725
        itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
159
66725
    if (r.sgn() != 0)
160
    {
161
238562
      for (std::map<Node, Node>::const_iterator it = msum.begin();
162
238562
           it != msum.end();
163
           ++it)
164
      {
165
171837
        if (it->first != v)
166
        {
167
210224
          Node m;
168
105112
          if (!it->first.isNull())
169
          {
170
69159
            m = mkCoeffTerm(it->second, it->first);
171
          }
172
          else
173
          {
174
35953
            m = it->second;
175
          }
176
105112
          children.push_back(m);
177
        }
178
      }
179
133450
      val = children.size() > 1
180
201786
                ? NodeManager::currentNM()->mkNode(PLUS, children)
181
34168
                : (children.size() == 1
182
27023
                       ? children[0]
183
73870
                       : NodeManager::currentNM()->mkConst(Rational(0)));
184
66725
      if (!r.isOne() && !r.isNegativeOne())
185
      {
186
3936
        if (v.getType().isInteger())
187
        {
188
1675
          veq_c = NodeManager::currentNM()->mkConst(r.abs());
189
        }
190
        else
191
        {
192
4522
          val = NodeManager::currentNM()->mkNode(
193
              MULT,
194
              val,
195
4522
              NodeManager::currentNM()->mkConst(Rational(1) / r.abs()));
196
        }
197
      }
198
66725
      val = r.sgn() == 1 ? negate(val) : Rewriter::rewrite(val);
199
66725
      return (r.sgn() == 1 || k == EQUAL) ? 1 : -1;
200
    }
201
  }
202
1320
  return 0;
203
}
204
205
11679
int ArithMSum::isolate(
206
    Node v, const std::map<Node, Node>& msum, Node& veq, Kind k, bool doCoeff)
207
{
208
23358
  Node veq_c;
209
23358
  Node val;
210
  // isolate v in the (in)equality
211
11679
  int ires = isolate(v, msum, veq_c, val, k);
212
11679
  if (ires != 0)
213
  {
214
23330
    Node vc = v;
215
11671
    if (!veq_c.isNull())
216
    {
217
63
      if (doCoeff)
218
      {
219
51
        vc = NodeManager::currentNM()->mkNode(MULT, veq_c, vc);
220
      }
221
      else
222
      {
223
12
        return 0;
224
      }
225
    }
226
11659
    bool inOrder = ires == 1;
227
11659
    veq = NodeManager::currentNM()->mkNode(
228
        k, inOrder ? vc : val, inOrder ? val : vc);
229
  }
230
11667
  return ires;
231
}
232
233
21434
Node ArithMSum::solveEqualityFor(Node lit, Node v)
234
{
235
21434
  Assert(lit.getKind() == EQUAL);
236
  // first look directly at sides
237
42868
  TypeNode tn = lit[0].getType();
238
34490
  for (unsigned r = 0; r < 2; r++)
239
  {
240
27962
    if (lit[r] == v)
241
    {
242
14906
      return lit[1 - r];
243
    }
244
  }
245
6528
  if (tn.isReal())
246
  {
247
6608
    std::map<Node, Node> msum;
248
6528
    if (ArithMSum::getMonomialSumLit(lit, msum))
249
    {
250
6608
      Node val, veqc;
251
6528
      if (ArithMSum::isolate(v, msum, veqc, val, EQUAL) != 0)
252
      {
253
6528
        if (veqc.isNull())
254
        {
255
          // in this case, we have an integer equality with a coefficient
256
          // on the variable we solved for that could not be eliminated,
257
          // hence we fail.
258
6448
          return val;
259
        }
260
      }
261
    }
262
  }
263
80
  return Node::null();
264
}
265
266
bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
267
{
268
  std::map<Node, Node> msum;
269
  if (getMonomialSum(n, msum))
270
  {
271
    std::map<Node, Node>::iterator it = msum.find(v);
272
    if (it == msum.end())
273
    {
274
      return false;
275
    }
276
    else
277
    {
278
      coeff = it->second;
279
      msum.erase(v);
280
      rem = mkNode(msum);
281
      return true;
282
    }
283
  }
284
  return false;
285
}
286
287
101246
Node ArithMSum::negate(Node t)
288
{
289
  Node tt = NodeManager::currentNM()->mkNode(
290
101246
      MULT, NodeManager::currentNM()->mkConst(Rational(-1)), t);
291
101246
  tt = Rewriter::rewrite(tt);
292
101246
  return tt;
293
}
294
295
498
Node ArithMSum::offset(Node t, int i)
296
{
297
  Node tt = NodeManager::currentNM()->mkNode(
298
498
      PLUS, NodeManager::currentNM()->mkConst(Rational(i)), t);
299
498
  tt = Rewriter::rewrite(tt);
300
498
  return tt;
301
}
302
303
1659
void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
304
{
305
4979
  for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
306
  {
307
3320
    Trace(c) << "  ";
308
3320
    if (!it->second.isNull())
309
    {
310
1548
      Trace(c) << it->second;
311
1548
      if (!it->first.isNull())
312
      {
313
866
        Trace(c) << " * ";
314
      }
315
    }
316
3320
    if (!it->first.isNull())
317
    {
318
2638
      Trace(c) << it->first;
319
    }
320
3320
    Trace(c) << std::endl;
321
  }
322
1659
  Trace(c) << std::endl;
323
1659
}
324
325
}  // namespace theory
326
28191
}  // namespace cvc5