GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_msum.cpp Lines: 131 144 91.0 %
Date: 2021-03-22 Branches: 329 642 51.2 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file arith_msum.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Implementation of arith_msum
13
 **/
14
15
#include "theory/arith/arith_msum.h"
16
17
#include "theory/rewriter.h"
18
19
using namespace CVC4::kind;
20
21
namespace CVC4 {
22
namespace theory {
23
24
530
/**
 * Split n into a constant coefficient and a remaining factor.
 *
 * Succeeds only when n is a MULT node with exactly two children whose first
 * child is a constant. On success c receives the first child and v the
 * second; on failure neither output is modified.
 *
 * @param n the term to inspect
 * @param c output: the constant first child of n
 * @param v output: the second child of n
 * @return true iff n has the shape described above
 */
bool ArithMSum::getMonomial(Node n, Node& c, Node& v)
{
  // Reject anything that is not a binary product led by a constant.
  if (n.getKind() != MULT || n.getNumChildren() != 2 || !n[0].isConst())
  {
    return false;
  }
  c = n[0];
  v = n[1];
  return true;
}
34
35
689580
/**
 * Record the single monomial n in msum.
 *
 * The entry that is added depends on the shape of n:
 *  - a constant n is stored under the null key, with n as the value;
 *  - a two-child MULT whose first child is constant is stored under its
 *    second child, with the first child as the value;
 *  - any other n is stored under n itself, with a null value.
 *
 * @param n the monomial to record
 * @param msum the map to extend
 * @return true iff the key was not yet present (msum is left untouched and
 * false is returned when it already was)
 */
bool ArithMSum::getMonomial(Node n, std::map<Node, Node>& msum)
{
  // Work out the (key, value) pair first, then insert it exactly once.
  Node key = Node::null();
  Node value = Node::null();
  if (n.isConst())
  {
    value = n;
  }
  else if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst())
  {
    key = n[1];
    value = n[0];
  }
  else
  {
    key = n;
  }
  // emplace leaves an existing entry alone and reports whether it inserted.
  return msum.emplace(key, value).second;
}
63
64
374496
/**
 * Record n in msum as a sum of monomials.
 *
 * If n is a PLUS node, each child is recorded via getMonomial; any other n
 * is recorded as a single monomial.
 *
 * @param n the term to decompose
 * @param msum the map to extend
 * @return false as soon as one monomial cannot be recorded (entries added
 * before that point remain in msum), true otherwise
 */
bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
{
  if (n.getKind() != PLUS)
  {
    // not a sum: the whole term is one monomial
    return getMonomial(n, msum);
  }
  for (const Node& summand : n)
  {
    if (!getMonomial(summand, msum))
    {
      return false;
    }
  }
  return true;
}
79
80
112740
/**
 * Record the literal lit in msum as the monomial sum of (lit[0] - lit[1]).
 *
 * Only GEQ and EQUAL literals are handled. The left-hand side is recorded
 * first; unless the right-hand side is the constant zero, its monomials are
 * then subtracted from msum.
 *
 * @param lit the literal to decompose
 * @param msum the map to extend
 * @return true on success; false if lit has another kind or if either side
 * cannot be recorded as a monomial sum. On failure msum may already hold
 * entries coming from the left-hand side.
 */
bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
{
  const Kind litKind = lit.getKind();
  if (litKind != GEQ && litKind != EQUAL)
  {
    return false;
  }
  if (!getMonomialSum(lit[0], msum))
  {
    return false;
  }
  if (lit[1].isConst() && lit[1].getConst<Rational>().isZero())
  {
    // nothing to subtract
    return true;
  }
  // subtract the other side
  std::map<Node, Node> rhsSum;
  NodeManager* nm = NodeManager::currentNM();
  if (!getMonomialSum(lit[1], rhsSum))
  {
    return false;
  }
  for (const auto& rhsTerm : rhsSum)
  {
    std::map<Node, Node>::iterator lhsPos = msum.find(rhsTerm.first);
    if (lhsPos == msum.end())
    {
      // Only on the right: store the negated coefficient, reading a null
      // coefficient as one.
      msum[rhsTerm.first] = rhsTerm.second.isNull()
                                ? nm->mkConst(Rational(-1))
                                : negate(rhsTerm.second);
      continue;
    }
    // On both sides: the new coefficient is the rewritten difference of
    // the two, again reading a null coefficient as one.
    Node lhsCoeff =
        lhsPos->second.isNull() ? nm->mkConst(Rational(1)) : lhsPos->second;
    Node rhsCoeff =
        rhsTerm.second.isNull() ? nm->mkConst(Rational(1)) : rhsTerm.second;
    Node diff = nm->mkNode(MINUS, lhsCoeff, rhsCoeff);
    lhsPos->second = Rewriter::rewrite(diff);
  }
  return true;
}
123
124
385
/**
 * Build a term from the monomial sum msum.
 *
 * An entry with a non-null key contributes mkCoeffTerm(value, key); the
 * entry with the null key contributes its value directly, which must be
 * non-null.
 *
 * @param msum the monomial sum to convert
 * @return the constant 0 if msum is empty, the single summand if there is
 * exactly one, and a PLUS node over all summands otherwise
 */
Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
{
  NodeManager* nm = NodeManager::currentNM();
  std::vector<Node> summands;
  for (const auto& entry : msum)
  {
    if (entry.first.isNull())
    {
      // the null-key entry must carry an actual value
      Assert(!entry.second.isNull());
      summands.push_back(entry.second);
    }
    else
    {
      summands.push_back(mkCoeffTerm(entry.second, entry.first));
    }
  }
  if (summands.empty())
  {
    return nm->mkConst(Rational(0));
  }
  if (summands.size() == 1)
  {
    return summands[0];
  }
  return nm->mkNode(PLUS, summands);
}
147
148
72726
/**
 * Isolate v in the monomial sum msum.
 *
 * Let r be the coefficient stored for v (a null coefficient is read as 1).
 * When v occurs with a non-zero r, val is built from the remaining entries
 * of msum: negated if r is positive, rewritten as-is if r is negative. If r
 * is neither 1 nor -1, then either veq_c is set to |r| (when v has integer
 * type) or val is multiplied by 1/|r| (otherwise).
 *
 * @param v the variable to isolate
 * @param msum the monomial sum
 * @param veq_c output: |r| when it could not be divided out; must be null
 * on entry and is left null otherwise
 * @param val output: the term on the other side of v
 * @param k the kind of the relation being solved
 * @return 0 if v does not occur in msum or has a zero coefficient (outputs
 * untouched); 1 if r is positive or k is EQUAL; -1 otherwise
 */
int ArithMSum::isolate(
    Node v, const std::map<Node, Node>& msum, Node& veq_c, Node& val, Kind k)
{
  Assert(veq_c.isNull());
  std::map<Node, Node>::const_iterator vPos = msum.find(v);
  if (vPos == msum.end())
  {
    return 0;
  }
  Rational r =
      vPos->second.isNull() ? Rational(1) : vPos->second.getConst<Rational>();
  if (r.sgn() == 0)
  {
    return 0;
  }
  NodeManager* nm = NodeManager::currentNM();

  // Collect every summand other than the one for v.
  std::vector<Node> others;
  for (const auto& entry : msum)
  {
    if (entry.first == v)
    {
      continue;
    }
    if (entry.first.isNull())
    {
      others.push_back(entry.second);
    }
    else
    {
      others.push_back(mkCoeffTerm(entry.second, entry.first));
    }
  }
  if (others.empty())
  {
    val = nm->mkConst(Rational(0));
  }
  else if (others.size() == 1)
  {
    val = others[0];
  }
  else
  {
    val = nm->mkNode(PLUS, others);
  }

  // Deal with a coefficient whose absolute value is not one.
  if (!r.isOne() && !r.isNegativeOne())
  {
    if (v.getType().isInteger())
    {
      // keep |r| on the side of v instead of dividing
      veq_c = nm->mkConst(r.abs());
    }
    else
    {
      val = nm->mkNode(MULT, val, nm->mkConst(Rational(1) / r.abs()));
    }
  }

  const bool positiveCoeff = r.sgn() == 1;
  if (positiveCoeff)
  {
    // moving the remaining summands across flips their sign
    val = negate(val);
  }
  else
  {
    val = Rewriter::rewrite(val);
  }
  return (positiveCoeff || k == EQUAL) ? 1 : -1;
}
203
204
11423
/**
 * Isolate v in msum and build the resulting relation of kind k.
 *
 * Delegates to the overload that returns a coefficient and a value. If that
 * overload reports a leftover coefficient on v, the coefficient is
 * multiplied onto v when doCoeff is true; when doCoeff is false the call
 * fails.
 *
 * @param v the variable to isolate
 * @param msum the monomial sum
 * @param veq output: the relation, with the v side first when the result is
 * 1 and second when it is -1; untouched when 0 is returned
 * @param k the kind of relation to build
 * @param doCoeff whether a leftover coefficient on v is acceptable
 * @return 0 on failure, otherwise the result (1 or -1) of the delegate
 */
int ArithMSum::isolate(
    Node v, const std::map<Node, Node>& msum, Node& veq, Kind k, bool doCoeff)
{
  Node veq_c;
  Node val;
  // isolate v in the (in)equality
  const int ires = isolate(v, msum, veq_c, val, k);
  if (ires == 0)
  {
    return ires;
  }
  NodeManager* nm = NodeManager::currentNM();
  Node vSide = v;
  if (!veq_c.isNull())
  {
    if (!doCoeff)
    {
      // the caller does not accept a coefficient on v
      return 0;
    }
    vSide = nm->mkNode(MULT, veq_c, vSide);
  }
  if (ires == 1)
  {
    veq = nm->mkNode(k, vSide, val);
  }
  else
  {
    veq = nm->mkNode(k, val, vSide);
  }
  return ires;
}
231
232
22794
/**
 * Solve the equality lit for v.
 *
 * @param lit an EQUAL node
 * @param v the term to solve for
 * @return the other side of lit if one side is syntactically v; otherwise,
 * when isReal() holds for the type of lit[0], the value obtained by
 * isolating v in the monomial sum of lit; the null node if v cannot be
 * solved for
 */
Node ArithMSum::solveEqualityFor(Node lit, Node v)
{
  Assert(lit.getKind() == EQUAL);
  TypeNode tn = lit[0].getType();
  // first look directly at sides
  if (lit[0] == v)
  {
    return lit[1];
  }
  if (lit[1] == v)
  {
    return lit[0];
  }
  if (!tn.isReal())
  {
    return Node::null();
  }
  std::map<Node, Node> msum;
  if (!ArithMSum::getMonomialSumLit(lit, msum))
  {
    return Node::null();
  }
  Node val, veqc;
  if (ArithMSum::isolate(v, msum, veqc, val, EQUAL) == 0)
  {
    return Node::null();
  }
  if (!veqc.isNull())
  {
    // in this case, we have an integer equality with a coefficient
    // on the variable we solved for that could not be eliminated,
    // hence we fail.
    return Node::null();
  }
  return val;
}
264
265
/**
 * Split n into the coefficient of v and the remainder of the sum.
 *
 * @param n the term to decompose
 * @param v the term whose entry is extracted
 * @param coeff output: the value stored for v in the monomial sum of n
 * (null when n records v without an explicit coefficient)
 * @param rem output: the term built from all other monomials of n
 * @return true iff n is a monomial sum that contains v; the outputs are
 * only written in that case
 */
bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
{
  std::map<Node, Node> msum;
  if (!getMonomialSum(n, msum))
  {
    return false;
  }
  std::map<Node, Node>::iterator vPos = msum.find(v);
  if (vPos == msum.end())
  {
    return false;
  }
  coeff = vPos->second;
  // drop the entry for v so that only the remainder is rebuilt
  msum.erase(vPos);
  rem = mkNode(msum);
  return true;
}
285
286
109726
/**
 * Negate a term.
 *
 * @param t the term to negate
 * @return the rewritten form of (MULT -1 t)
 */
Node ArithMSum::negate(Node t)
{
  NodeManager* nm = NodeManager::currentNM();
  Node minusOne = nm->mkConst(Rational(-1));
  Node product = nm->mkNode(MULT, minusOne, t);
  return Rewriter::rewrite(product);
}
293
294
373
/**
 * Add an integer offset to a term.
 *
 * @param t the term to shift
 * @param i the amount to add
 * @return the rewritten form of (PLUS i t)
 */
Node ArithMSum::offset(Node t, int i)
{
  NodeManager* nm = NodeManager::currentNM();
  Node shift = nm->mkConst(Rational(i));
  Node sum = nm->mkNode(PLUS, shift, t);
  return Rewriter::rewrite(sum);
}
301
302
1536
void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
303
{
304
4608
  for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
305
  {
306
3072
    Trace(c) << "  ";
307
3072
    if (!it->second.isNull())
308
    {
309
1426
      Trace(c) << it->second;
310
1426
      if (!it->first.isNull())
311
      {
312
840
        Trace(c) << " * ";
313
      }
314
    }
315
3072
    if (!it->first.isNull())
316
    {
317
2486
      Trace(c) << it->first;
318
    }
319
3072
    Trace(c) << std::endl;
320
  }
321
1536
  Trace(c) << std::endl;
322
1536
}
323
324
} /* CVC4::theory namespace */
325
26676
} /* CVC4 namespace */