GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_msum.cpp Lines: 130 144 90.3 %
Date: 2021-08-06 Branches: 327 642 50.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Arithmetic utilities regarding monomial sums.
14
 */
15
16
#include "theory/arith/arith_msum.h"
17
18
#include "theory/rewriter.h"
19
#include "util/rational.h"
20
21
using namespace cvc5::kind;
22
23
namespace cvc5 {
24
namespace theory {
25
26
664
bool ArithMSum::getMonomial(Node n, Node& c, Node& v)
27
{
28
664
  if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
29
  {
30
664
    c = n[0];
31
664
    v = n[1];
32
664
    return true;
33
  }
34
  return false;
35
}
36
37
524582
bool ArithMSum::getMonomial(Node n, std::map<Node, Node>& msum)
38
{
39
524582
  if (n.isConst())
40
  {
41
154366
    if (msum.find(Node::null()) == msum.end())
42
    {
43
154366
      msum[Node::null()] = n;
44
154366
      return true;
45
    }
46
  }
47
370216
  else if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
48
  {
49
157868
    if (msum.find(n[1]) == msum.end())
50
    {
51
157868
      msum[n[1]] = n[0];
52
157868
      return true;
53
    }
54
  }
55
  else
56
  {
57
212348
    if (msum.find(n) == msum.end())
58
    {
59
212348
      msum[n] = Node::null();
60
212348
      return true;
61
    }
62
  }
63
  return false;
64
}
65
66
330018
bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
67
{
68
330018
  if (n.getKind() == PLUS)
69
  {
70
474480
    for (Node nc : n)
71
    {
72
334522
      if (!getMonomial(nc, msum))
73
      {
74
        return false;
75
      }
76
    }
77
139958
    return true;
78
  }
79
190060
  return getMonomial(n, msum);
80
}
81
82
143691
bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
83
{
84
143691
  if (lit.getKind() == GEQ || lit.getKind() == EQUAL)
85
  {
86
111555
    if (getMonomialSum(lit[0], msum))
87
    {
88
111555
      if (lit[1].isConst() && lit[1].getConst<Rational>().isZero())
89
      {
90
23640
        return true;
91
      }
92
      else
93
      {
94
        // subtract the other side
95
87915
        std::map<Node, Node> msum2;
96
87915
        NodeManager* nm = NodeManager::currentNM();
97
87915
        if (getMonomialSum(lit[1], msum2))
98
        {
99
205160
          for (std::map<Node, Node>::iterator it = msum2.begin();
100
205160
               it != msum2.end();
101
               ++it)
102
          {
103
117245
            std::map<Node, Node>::iterator it2 = msum.find(it->first);
104
117245
            if (it2 != msum.end())
105
            {
106
              Node r = nm->mkNode(
107
                  MINUS,
108
2
                  it2->second.isNull() ? nm->mkConst(Rational(1)) : it2->second,
109
4
                  it->second.isNull() ? nm->mkConst(Rational(1)) : it->second);
110
1
              msum[it->first] = Rewriter::rewrite(r);
111
            }
112
            else
113
            {
114
188802
              msum[it->first] = it->second.isNull() ? nm->mkConst(Rational(-1))
115
71558
                                                    : negate(it->second);
116
            }
117
          }
118
87915
          return true;
119
        }
120
      }
121
    }
122
  }
123
32136
  return false;
124
}
125
126
348
Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
127
{
128
348
  NodeManager* nm = NodeManager::currentNM();
129
696
  std::vector<Node> children;
130
696
  for (std::map<Node, Node>::const_iterator it = msum.begin(); it != msum.end();
131
       ++it)
132
  {
133
696
    Node m;
134
348
    if (!it->first.isNull())
135
    {
136
246
      m = mkCoeffTerm(it->second, it->first);
137
    }
138
    else
139
    {
140
102
      Assert(!it->second.isNull());
141
102
      m = it->second;
142
    }
143
348
    children.push_back(m);
144
  }
145
348
  return children.size() > 1
146
             ? nm->mkNode(PLUS, children)
147
696
             : (children.size() == 1 ? children[0] : nm->mkConst(Rational(0)));
148
}
149
150
84409
int ArithMSum::isolate(
151
    Node v, const std::map<Node, Node>& msum, Node& veq_c, Node& val, Kind k)
152
{
153
84409
  Assert(veq_c.isNull());
154
84409
  std::map<Node, Node>::const_iterator itv = msum.find(v);
155
84409
  if (itv != msum.end())
156
  {
157
83097
    std::vector<Node> children;
158
    Rational r =
159
83097
        itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
160
83097
    if (r.sgn() != 0)
161
    {
162
300087
      for (std::map<Node, Node>::const_iterator it = msum.begin();
163
300087
           it != msum.end();
164
           ++it)
165
      {
166
216990
        if (it->first != v)
167
        {
168
267786
          Node m;
169
133893
          if (!it->first.isNull())
170
          {
171
91649
            m = mkCoeffTerm(it->second, it->first);
172
          }
173
          else
174
          {
175
42244
            m = it->second;
176
          }
177
133893
          children.push_back(m);
178
        }
179
      }
180
166194
      val = children.size() > 1
181
255468
                ? NodeManager::currentNM()->mkNode(PLUS, children)
182
44637
                : (children.size() == 1
183
37288
                       ? children[0]
184
90446
                       : NodeManager::currentNM()->mkConst(Rational(0)));
185
83097
      if (!r.isOne() && !r.isNegativeOne())
186
      {
187
3846
        if (v.getType().isInteger())
188
        {
189
1683
          veq_c = NodeManager::currentNM()->mkConst(r.abs());
190
        }
191
        else
192
        {
193
4326
          val = NodeManager::currentNM()->mkNode(
194
              MULT,
195
              val,
196
4326
              NodeManager::currentNM()->mkConst(Rational(1) / r.abs()));
197
        }
198
      }
199
83097
      val = r.sgn() == 1 ? negate(val) : Rewriter::rewrite(val);
200
83097
      return (r.sgn() == 1 || k == EQUAL) ? 1 : -1;
201
    }
202
  }
203
1312
  return 0;
204
}
205
206
12219
int ArithMSum::isolate(
207
    Node v, const std::map<Node, Node>& msum, Node& veq, Kind k, bool doCoeff)
208
{
209
24438
  Node veq_c;
210
24438
  Node val;
211
  // isolate v in the (in)equality
212
12219
  int ires = isolate(v, msum, veq_c, val, k);
213
12219
  if (ires != 0)
214
  {
215
24410
    Node vc = v;
216
12211
    if (!veq_c.isNull())
217
    {
218
80
      if (doCoeff)
219
      {
220
68
        vc = NodeManager::currentNM()->mkNode(MULT, veq_c, vc);
221
      }
222
      else
223
      {
224
12
        return 0;
225
      }
226
    }
227
12199
    bool inOrder = ires == 1;
228
12199
    veq = NodeManager::currentNM()->mkNode(
229
        k, inOrder ? vc : val, inOrder ? val : vc);
230
  }
231
12207
  return ires;
232
}
233
234
32934
Node ArithMSum::solveEqualityFor(Node lit, Node v)
235
{
236
32934
  Assert(lit.getKind() == EQUAL);
237
  // first look directly at sides
238
65868
  TypeNode tn = lit[0].getType();
239
53930
  for (unsigned r = 0; r < 2; r++)
240
  {
241
43432
    if (lit[r] == v)
242
    {
243
22436
      return lit[1 - r];
244
    }
245
  }
246
10498
  if (tn.isReal())
247
  {
248
10594
    std::map<Node, Node> msum;
249
10498
    if (ArithMSum::getMonomialSumLit(lit, msum))
250
    {
251
10594
      Node val, veqc;
252
10498
      if (ArithMSum::isolate(v, msum, veqc, val, EQUAL) != 0)
253
      {
254
10498
        if (veqc.isNull())
255
        {
256
          // in this case, we have an integer equality with a coefficient
257
          // on the variable we solved for that could not be eliminated,
258
          // hence we fail.
259
10402
          return val;
260
        }
261
      }
262
    }
263
  }
264
96
  return Node::null();
265
}
266
267
bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
268
{
269
  std::map<Node, Node> msum;
270
  if (getMonomialSum(n, msum))
271
  {
272
    std::map<Node, Node>::iterator it = msum.find(v);
273
    if (it == msum.end())
274
    {
275
      return false;
276
    }
277
    else
278
    {
279
      coeff = it->second;
280
      msum.erase(v);
281
      rem = mkNode(msum);
282
      return true;
283
    }
284
  }
285
  return false;
286
}
287
288
123207
Node ArithMSum::negate(Node t)
289
{
290
  Node tt = NodeManager::currentNM()->mkNode(
291
123207
      MULT, NodeManager::currentNM()->mkConst(Rational(-1)), t);
292
123207
  tt = Rewriter::rewrite(tt);
293
123207
  return tt;
294
}
295
296
964
Node ArithMSum::offset(Node t, int i)
297
{
298
  Node tt = NodeManager::currentNM()->mkNode(
299
964
      PLUS, NodeManager::currentNM()->mkConst(Rational(i)), t);
300
964
  tt = Rewriter::rewrite(tt);
301
964
  return tt;
302
}
303
304
3274
void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
305
{
306
10210
  for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
307
  {
308
6936
    Trace(c) << "  ";
309
6936
    if (!it->second.isNull())
310
    {
311
3112
      Trace(c) << it->second;
312
3112
      if (!it->first.isNull())
313
      {
314
1666
        Trace(c) << " * ";
315
      }
316
    }
317
6936
    if (!it->first.isNull())
318
    {
319
5490
      Trace(c) << it->first;
320
    }
321
6936
    Trace(c) << std::endl;
322
  }
323
3274
  Trace(c) << std::endl;
324
3274
}
325
326
}  // namespace theory
327
29322
}  // namespace cvc5