GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_msum.cpp Lines: 130 144 90.3 %
Date: 2021-05-21 Branches: 327 642 50.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Arithmetic utilities regarding monomial sums.
14
 */
15
16
#include "theory/arith/arith_msum.h"
17
18
#include "theory/rewriter.h"
19
20
using namespace cvc5::kind;
21
22
namespace cvc5 {
23
namespace theory {
24
25
537
bool ArithMSum::getMonomial(Node n, Node& c, Node& v)
26
{
27
537
  if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
28
  {
29
537
    c = n[0];
30
537
    v = n[1];
31
537
    return true;
32
  }
33
  return false;
34
}
35
36
639938
bool ArithMSum::getMonomial(Node n, std::map<Node, Node>& msum)
37
{
38
639938
  if (n.isConst())
39
  {
40
178856
    if (msum.find(Node::null()) == msum.end())
41
    {
42
178856
      msum[Node::null()] = n;
43
178856
      return true;
44
    }
45
  }
46
461082
  else if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
47
  {
48
203055
    if (msum.find(n[1]) == msum.end())
49
    {
50
203055
      msum[n[1]] = n[0];
51
203055
      return true;
52
    }
53
  }
54
  else
55
  {
56
258027
    if (msum.find(n) == msum.end())
57
    {
58
258027
      msum[n] = Node::null();
59
258027
      return true;
60
    }
61
  }
62
  return false;
63
}
64
65
343403
bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
66
{
67
343403
  if (n.getKind() == PLUS)
68
  {
69
654635
    for (Node nc : n)
70
    {
71
475585
      if (!getMonomial(nc, msum))
72
      {
73
        return false;
74
      }
75
    }
76
179050
    return true;
77
  }
78
164353
  return getMonomial(n, msum);
79
}
80
81
103914
bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
82
{
83
103914
  if (lit.getKind() == GEQ || lit.getKind() == EQUAL)
84
  {
85
94590
    if (getMonomialSum(lit[0], msum))
86
    {
87
94590
      if (lit[1].isConst() && lit[1].getConst<Rational>().isZero())
88
      {
89
23001
        return true;
90
      }
91
      else
92
      {
93
        // subtract the other side
94
71589
        std::map<Node, Node> msum2;
95
71589
        NodeManager* nm = NodeManager::currentNM();
96
71589
        if (getMonomialSum(lit[1], msum2))
97
        {
98
161334
          for (std::map<Node, Node>::iterator it = msum2.begin();
99
161334
               it != msum2.end();
100
               ++it)
101
          {
102
89745
            std::map<Node, Node>::iterator it2 = msum.find(it->first);
103
89745
            if (it2 != msum.end())
104
            {
105
              Node r = nm->mkNode(
106
                  MINUS,
107
2
                  it2->second.isNull() ? nm->mkConst(Rational(1)) : it2->second,
108
4
                  it->second.isNull() ? nm->mkConst(Rational(1)) : it->second);
109
1
              msum[it->first] = Rewriter::rewrite(r);
110
            }
111
            else
112
            {
113
149698
              msum[it->first] = it->second.isNull() ? nm->mkConst(Rational(-1))
114
59954
                                                    : negate(it->second);
115
            }
116
          }
117
71589
          return true;
118
        }
119
      }
120
    }
121
  }
122
9324
  return false;
123
}
124
125
348
Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
126
{
127
348
  NodeManager* nm = NodeManager::currentNM();
128
696
  std::vector<Node> children;
129
696
  for (std::map<Node, Node>::const_iterator it = msum.begin(); it != msum.end();
130
       ++it)
131
  {
132
696
    Node m;
133
348
    if (!it->first.isNull())
134
    {
135
246
      m = mkCoeffTerm(it->second, it->first);
136
    }
137
    else
138
    {
139
102
      Assert(!it->second.isNull());
140
102
      m = it->second;
141
    }
142
348
    children.push_back(m);
143
  }
144
348
  return children.size() > 1
145
             ? nm->mkNode(PLUS, children)
146
696
             : (children.size() == 1 ? children[0] : nm->mkConst(Rational(0)));
147
}
148
149
67736
int ArithMSum::isolate(
150
    Node v, const std::map<Node, Node>& msum, Node& veq_c, Node& val, Kind k)
151
{
152
67736
  Assert(veq_c.isNull());
153
67736
  std::map<Node, Node>::const_iterator itv = msum.find(v);
154
67736
  if (itv != msum.end())
155
  {
156
66416
    std::vector<Node> children;
157
    Rational r =
158
66416
        itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
159
66416
    if (r.sgn() != 0)
160
    {
161
237659
      for (std::map<Node, Node>::const_iterator it = msum.begin();
162
237659
           it != msum.end();
163
           ++it)
164
      {
165
171243
        if (it->first != v)
166
        {
167
209654
          Node m;
168
104827
          if (!it->first.isNull())
169
          {
170
68998
            m = mkCoeffTerm(it->second, it->first);
171
          }
172
          else
173
          {
174
35829
            m = it->second;
175
          }
176
104827
          children.push_back(m);
177
        }
178
      }
179
132832
      val = children.size() > 1
180
200662
                ? NodeManager::currentNM()->mkNode(PLUS, children)
181
33915
                : (children.size() == 1
182
26846
                       ? children[0]
183
73485
                       : NodeManager::currentNM()->mkConst(Rational(0)));
184
66416
      if (!r.isOne() && !r.isNegativeOne())
185
      {
186
3904
        if (v.getType().isInteger())
187
        {
188
1653
          veq_c = NodeManager::currentNM()->mkConst(r.abs());
189
        }
190
        else
191
        {
192
4502
          val = NodeManager::currentNM()->mkNode(
193
              MULT,
194
              val,
195
4502
              NodeManager::currentNM()->mkConst(Rational(1) / r.abs()));
196
        }
197
      }
198
66416
      val = r.sgn() == 1 ? negate(val) : Rewriter::rewrite(val);
199
66416
      return (r.sgn() == 1 || k == EQUAL) ? 1 : -1;
200
    }
201
  }
202
1320
  return 0;
203
}
204
205
11670
int ArithMSum::isolate(
206
    Node v, const std::map<Node, Node>& msum, Node& veq, Kind k, bool doCoeff)
207
{
208
23340
  Node veq_c;
209
23340
  Node val;
210
  // isolate v in the (in)equality
211
11670
  int ires = isolate(v, msum, veq_c, val, k);
212
11670
  if (ires != 0)
213
  {
214
23312
    Node vc = v;
215
11662
    if (!veq_c.isNull())
216
    {
217
63
      if (doCoeff)
218
      {
219
51
        vc = NodeManager::currentNM()->mkNode(MULT, veq_c, vc);
220
      }
221
      else
222
      {
223
12
        return 0;
224
      }
225
    }
226
11650
    bool inOrder = ires == 1;
227
11650
    veq = NodeManager::currentNM()->mkNode(
228
        k, inOrder ? vc : val, inOrder ? val : vc);
229
  }
230
11658
  return ires;
231
}
232
233
21446
Node ArithMSum::solveEqualityFor(Node lit, Node v)
234
{
235
21446
  Assert(lit.getKind() == EQUAL);
236
  // first look directly at sides
237
42892
  TypeNode tn = lit[0].getType();
238
34550
  for (unsigned r = 0; r < 2; r++)
239
  {
240
27998
    if (lit[r] == v)
241
    {
242
14894
      return lit[1 - r];
243
    }
244
  }
245
6552
  if (tn.isReal())
246
  {
247
6632
    std::map<Node, Node> msum;
248
6552
    if (ArithMSum::getMonomialSumLit(lit, msum))
249
    {
250
6632
      Node val, veqc;
251
6552
      if (ArithMSum::isolate(v, msum, veqc, val, EQUAL) != 0)
252
      {
253
6552
        if (veqc.isNull())
254
        {
255
          // in this case, we have an integer equality with a coefficient
256
          // on the variable we solved for that could not be eliminated,
257
          // hence we fail.
258
6472
          return val;
259
        }
260
      }
261
    }
262
  }
263
80
  return Node::null();
264
}
265
266
bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
267
{
268
  std::map<Node, Node> msum;
269
  if (getMonomialSum(n, msum))
270
  {
271
    std::map<Node, Node>::iterator it = msum.find(v);
272
    if (it == msum.end())
273
    {
274
      return false;
275
    }
276
    else
277
    {
278
      coeff = it->second;
279
      msum.erase(v);
280
      rem = mkNode(msum);
281
      return true;
282
    }
283
  }
284
  return false;
285
}
286
287
100814
Node ArithMSum::negate(Node t)
288
{
289
  Node tt = NodeManager::currentNM()->mkNode(
290
100814
      MULT, NodeManager::currentNM()->mkConst(Rational(-1)), t);
291
100814
  tt = Rewriter::rewrite(tt);
292
100814
  return tt;
293
}
294
295
401
Node ArithMSum::offset(Node t, int i)
296
{
297
  Node tt = NodeManager::currentNM()->mkNode(
298
401
      PLUS, NodeManager::currentNM()->mkConst(Rational(i)), t);
299
401
  tt = Rewriter::rewrite(tt);
300
401
  return tt;
301
}
302
303
1655
void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
304
{
305
4963
  for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
306
  {
307
3308
    Trace(c) << "  ";
308
3308
    if (!it->second.isNull())
309
    {
310
1542
      Trace(c) << it->second;
311
1542
      if (!it->first.isNull())
312
      {
313
862
        Trace(c) << " * ";
314
      }
315
    }
316
3308
    if (!it->first.isNull())
317
    {
318
2628
      Trace(c) << it->first;
319
    }
320
3308
    Trace(c) << std::endl;
321
  }
322
1655
  Trace(c) << std::endl;
323
1655
}
324
325
}  // namespace theory
326
27735
}  // namespace cvc5