GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_ite_utils.cpp Lines: 1 279 0.4 %
Date: 2021-03-22 Branches: 2 1346 0.1 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file arith_ite_utils.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Tim King, Aina Niemetz, Piotr Trojanek
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Arithmetic-specific utilities for simplifying ITE terms.
13
 **
14
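 ** Separates the shared variable part out of real-valued ITEs, factors the
 ** common GCD out of ITE trees over integer constants, and learns variable
 ** substitutions from binary disjunctions of integer equalities.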
15
 ** \todo document this file
16
 **/
17
18
#include "theory/arith/arith_ite_utils.h"
19
20
#include <ostream>
21
22
#include "base/output.h"
23
#include "options/smt_options.h"
24
#include "preprocessing/util/ite_utilities.h"
25
#include "theory/arith/arith_utilities.h"
26
#include "theory/arith/normal_form.h"
27
#include "theory/rewriter.h"
28
#include "theory/substitutions.h"
29
#include "theory/theory_model.h"
30
31
using namespace std;
32
33
namespace CVC4 {
34
namespace theory {
35
namespace arith {
36
37
/**
 * Rebuilds n with the same kind (and operator, for parameterized kinds),
 * replacing every child by reduceVariablesInItes(child).
 */
Node ArithIteUtils::applyReduceVariablesInItes(Node n){
  NodeBuilder<> builder(n.getKind());
  if(n.getMetaKind() == kind::metakind::PARAMETERIZED) {
    builder << (n.getOperator());
  }
  for(unsigned i = 0, numChildren = n.getNumChildren(); i < numChildren; ++i){
    builder << reduceVariablesInItes(n[i]);
  }
  Node rebuilt = builder;
  return rebuilt;
}
48
49
/**
 * Splits real-valued terms into a "variable part" (d_varParts) and a
 * "constant part" (d_constants).
 *
 * A real-typed ITE whose then/else branches have the same variable part v is
 * rewritten to (+ v (ite c' k_t k_e)), where k_t and k_e are the branches'
 * constant parts. Any other real-typed ITE is rebuilt from its reduced
 * children and treated as an opaque variable.
 *
 * Results are memoized in d_reduceVar. A null entry there means the term was
 * left unchanged.
 */
Node ArithIteUtils::reduceVariablesInItes(Node n){
  using namespace CVC4::kind;

  // Memoized answer; a null memo stands for "n itself".
  if(d_reduceVar.find(n) != d_reduceVar.end()){
    Node memo = d_reduceVar[n];
    if(memo.isNull()){
      return n;
    }
    return memo;
  }

  const bool realTyped = n.getType().isReal();

  if(n.getKind() == ITE){
    Node cond = n[0];
    Node thenTerm = n[1];
    Node elseTerm = n[2];

    if(!realTyped){
      // Non-arithmetic ITE: only descend when there is a term ITE inside.
      if(!d_contains.containsTermITE(n)){
        // not memoized
        return n;
      }
      Node rebuilt = applyReduceVariablesInItes(n);
      d_reduceVar[n] = (n == rebuilt) ? Node::null() : rebuilt;
      return rebuilt;
    }

    Node condRed = reduceVariablesInItes(cond);
    Node thenRed = reduceVariablesInItes(thenTerm);
    Node elseRed = reduceVariablesInItes(elseTerm);

    Node thenVar = d_varParts[thenTerm];
    Node elseVar = d_varParts[elseTerm];

    if(thenVar != elseVar || thenVar.isNull()){
      // Branches disagree on their variable part: keep the ITE and treat it
      // as a variable with zero constant part.
      Node rebuilt = condRed.iteNode(thenRed, elseRed);
      d_reduceVar[n] = rebuilt;
      d_constants[n] = mkRationalNode(Rational(0));
      d_varParts[n] = rebuilt;
      return rebuilt;
    }

    // Shared variable part: hoist it out of the ITE.
    Node constPart = condRed.iteNode(d_constants[thenTerm], d_constants[elseTerm]);
    Node split = NodeManager::currentNM()->mkNode(kind::PLUS, thenVar, constPart);
    d_reduceVar[n] = split;
    d_constants[n] = constPart;
    d_varParts[n] = thenVar;
    return split;
  }

  if(realTyped && Polynomial::isMember(n)){
    Node reduced = n;
    if(d_contains.containsTermITE(n) && n.getNumChildren() > 0){
      Node rebuilt = applyReduceVariablesInItes(n);
      reduced = Rewriter::rewrite(rebuilt);
      Assert(Polynomial::isMember(reduced));
    }

    Polynomial poly = Polynomial::parsePolynomial(reduced);
    if(poly.isConstant()){
      d_constants[n] = reduced;
      d_varParts[n] = mkRationalNode(Rational(0));
      // not memoized in d_reduceVar
      return reduced;
    }
    if(!poly.containsConstant()){
      d_constants[n] = mkRationalNode(Rational(0));
      d_varParts[n] = reduced;
      d_reduceVar[n] = poly.getNode();
      return poly.getNode();
    }
    // The constant monomial is taken from the head; the rest is the tail.
    Monomial head = poly.getHead();
    d_constants[n] = head.getConstant().getNode();
    d_varParts[n] = poly.getTail().getNode();
    d_reduceVar[n] = reduced;
    return reduced;
  }

  // Any other term: rebuild only if it has children and contains a term ITE.
  if(!d_contains.containsTermITE(n)){
    return n;
  }
  if(n.getNumChildren() == 0){
    return n;
  }
  Node rebuilt = applyReduceVariablesInItes(n);
  d_reduceVar[n] = rebuilt;
  return rebuilt;
}
142
143
/**
 * @param contains visitor used to test whether a term contains a term ITE
 * @param uc       context passed to the substitution map, the substitution
 *                 counter and the skolem map
 * @param model    model that receives every learned substitution
 */
ArithIteUtils::ArithIteUtils(
    preprocessing::util::ContainsTermITEVisitor& contains,
    context::Context* uc,
    TheoryModel* model)
    : d_contains(contains),
      d_subs(nullptr),
      d_model(model),
      d_one(1),
      d_subcount(uc, 0),
      d_skolems(uc)
{
  // Owned by this object; released in the destructor.
  d_subs = new SubstitutionMap(uc);
}
158
159
/** Releases the substitution map allocated by the constructor. */
ArithIteUtils::~ArithIteUtils(){
  // NOTE(review): d_subs is a raw owning pointer; if the class is copyable
  // (header not visible here) a copy would lead to a double delete.
  SubstitutionMap* const owned = d_subs;
  d_subs = nullptr;
  delete owned;
}
163
164
void ArithIteUtils::clear(){
165
  d_reduceVar.clear();
166
  d_constants.clear();
167
  d_varParts.clear();
168
}
169
170
/**
 * GCD of the constant leaves of a real-typed ITE tree.
 *
 * - For an integral CONST_RATIONAL, returns its numerator (cached).
 * - For a real-typed ITE, returns the gcd of both branches (cached). It
 *   short-circuits to 1 when the then-branch already yields 1.
 * - Any other term, including non-integral constants, yields 1 (not cached).
 *
 * The returned reference points either at d_one or into d_gcds.
 */
const Integer& ArithIteUtils::gcdIte(Node n){
  if(d_gcds.find(n) != d_gcds.end()){
    return d_gcds[n];
  }

  if(n.getKind() == kind::CONST_RATIONAL){
    const Rational& value = n.getConst<Rational>();
    if(!value.isIntegral()){
      return d_one;
    }
    d_gcds[n] = value.getNumerator();
    return d_gcds[n];
  }

  if(n.getKind() != kind::ITE || !n.getType().isReal()){
    return d_one;
  }

  const Integer& thenGcd = gcdIte(n[1]);
  if(thenGcd.isOne()){
    d_gcds[n] = d_one;
    return d_one;
  }
  const Integer& elseGcd = gcdIte(n[2]);
  Integer combined = thenGcd.gcd(elseGcd);
  d_gcds[n] = combined;
  return d_gcds[n];
}
196
197
/**
 * Multiplies every constant leaf of an integer ITE tree by q. Conditions are
 * simplified with reduceConstantIteByGCD().
 */
Node ArithIteUtils::reduceIteConstantIteByGCD_rec(Node n, const Rational& q){
  if(!n.isConst()){
    Assert(n.getKind() == kind::ITE);
    Assert(n.getType().isInteger());
    Node cond = reduceConstantIteByGCD(n[0]);
    Node thenBranch = reduceIteConstantIteByGCD_rec(n[1], q);
    Node elseBranch = reduceIteConstantIteByGCD_rec(n[2], q);
    return cond.iteNode(thenBranch, elseBranch);
  }
  Assert(n.getKind() == kind::CONST_RATIONAL);
  const Rational scaled = n.getConst<Rational>() * q;
  return mkRationalNode(scaled);
}
210
211
/**
 * Factors the gcd g of a real ITE's constant leaves out of the ITE and caches
 * the result in d_reduceGcd.
 *
 * - g == 1: only the condition is simplified.
 * - g == 0: every leaf is zero, so the ITE becomes the constant 0.
 * - otherwise: returns (* g ite'), where ite' has each leaf divided by g.
 */
Node ArithIteUtils::reduceIteConstantIteByGCD(Node n){
  Assert(n.getKind() == kind::ITE);
  Assert(n.getType().isReal());
  const Integer& gcd = gcdIte(n);

  Node reduced;
  if(gcd.isOne()){
    reduced = reduceConstantIteByGCD(n[0]).iteNode(n[1], n[2]);
  }else if(gcd.isZero()){
    reduced = mkRationalNode(Rational(0));
  }else{
    Rational inverse(Integer(1), gcd);
    Node scaledIte = reduceIteConstantIteByGCD_rec(n, inverse);
    Node factor = mkRationalNode(Rational(gcd));
    reduced = NodeManager::currentNM()->mkNode(kind::MULT, factor, scaledIte);
  }
  d_reduceGcd[n] = reduced;
  return reduced;
}
232
233
/**
 * Applies reduceIteConstantIteByGCD() to every real-typed ITE inside n,
 * rebuilding n only when some child changed.
 *
 * Results for non-leaf terms are cached in d_reduceGcd; leaves are returned
 * as-is.
 */
Node ArithIteUtils::reduceConstantIteByGCD(Node n){
  if(d_reduceGcd.find(n) != d_reduceGcd.end()){
    return d_reduceGcd[n];
  }
  if(n.getKind() == kind::ITE && n.getType().isReal()){
    return reduceIteConstantIteByGCD(n);
  }
  if(n.getNumChildren() == 0){
    return n;
  }

  NodeBuilder<> builder(n.getKind());
  if(n.getMetaKind() == kind::metakind::PARAMETERIZED) {
    builder << (n.getOperator());
  }
  bool changed = false;
  for(Node::iterator it = n.begin(), end = n.end(); it != end; ++it){
    Node original = *it;
    Node simplified = reduceConstantIteByGCD(original);
    if(!changed && original != simplified){
      changed = true;
    }
    builder << simplified;
  }

  if(!changed){
    d_reduceGcd[n] = n;
    return n;
  }
  Node rebuilt = builder;
  d_reduceGcd[n] = rebuilt;
  return rebuilt;
}
265
266
unsigned ArithIteUtils::getSubCount() const{
267
  return d_subcount;
268
}
269
270
/**
 * Records the substitution f -> t in both the local substitution map and the
 * model, and bumps the substitution counter.
 */
void ArithIteUtils::addSubstitution(TNode f, TNode t){
  Debug("arith::ite") << "adding " << f << " -> " << t << endl;
  d_subcount = d_subcount + 1;
  SubstitutionMap& localSubs = *d_subs;
  localSubs.addSubstitution(f, t);
  d_model->addSubstitution(f, t);
}
276
277
/**
 * Applies every learned substitution to f. Only permitted when incremental
 * solving is disabled (checked by AlwaysAssert).
 */
Node ArithIteUtils::applySubstitutions(TNode f){
  AlwaysAssert(!options::incrementalSolving());
  Node substituted = d_subs->apply(f);
  return substituted;
}
281
282
/**
 * Follows the then-branch of every ITE whose condition is a skolem recorded
 * in d_skolems, and returns the first term that is not such an ITE.
 */
Node ArithIteUtils::selectForCmp(Node n) const{
  Node current = n;
  while(current.getKind() == kind::ITE
        && d_skolems.find(current[0]) != d_skolems.end()){
    current = current[1];
  }
  return current;
}
290
291
void ArithIteUtils::learnSubstitutions(const std::vector<Node>& assertions){
292
  AlwaysAssert(!options::incrementalSolving());
293
  for(size_t i=0, N=assertions.size(); i < N; ++i){
294
    collectAssertions(assertions[i]);
295
  }
296
  bool solvedSomething;
297
  do{
298
    solvedSomething = false;
299
    size_t readPos = 0, writePos = 0, N = d_orBinEqs.size();
300
    for(; readPos < N; readPos++){
301
      Node curr = d_orBinEqs[readPos];
302
      bool solved = solveBinOr(curr);
303
      if(solved){
304
        solvedSomething = true;
305
      }else{
306
        // didn't solve, push back
307
        d_orBinEqs[writePos] = curr;
308
        writePos++;
309
      }
310
    }
311
    Assert(writePos <= N);
312
    d_orBinEqs.resize(writePos);
313
  }while(solvedSomething);
314
315
  d_implies.clear();
316
  d_orBinEqs.clear();
317
}
318
319
/**
 * Records the two implications entailed by (or x y):
 * (not x) => y and (not y) => x.
 */
void ArithIteUtils::addImplications(Node x, Node y){
  Node notX = x.negate();
  Node notY = y.negate();
  std::set<Node>& fromNotX = d_implies[notX];
  fromNotX.insert(y);
  std::set<Node>& fromNotY = d_implies[notY];
  fromNotY.insert(x);
}
329
330
/**
 * Walks through conjunctions and, for every binary OR:
 * - records its implications via addImplications();
 * - queues it in d_orBinEqs when both disjuncts are EQUALs whose left sides
 *   are integer-typed.
 */
void ArithIteUtils::collectAssertions(TNode assertion){
  switch(assertion.getKind()){
    case kind::AND:
      for(TNode::iterator it = assertion.begin(), end = assertion.end();
          it != end; ++it){
        collectAssertions(*it);
      }
      break;
    case kind::OR: {
      if(assertion.getNumChildren() != 2){
        break;
      }
      TNode lhs = assertion[0];
      TNode rhs = assertion[1];
      addImplications(lhs, rhs);
      const bool bothEqualities =
          lhs.getKind() == kind::EQUAL && rhs.getKind() == kind::EQUAL;
      if(bothEqualities && lhs[0].getType().isInteger()
         && rhs[0].getType().isInteger()){
        d_orBinEqs.push_back(assertion);
      }
      break;
    }
    default:
      break;
  }
}
347
348
/**
 * Looks for a literal x such that d_implies records both
 *   (not tb) => (not x)   and   (not fb) => x,
 * i.e. the facts that justify reading (or tb fb) as (ite x tb fb).
 *
 * Example: from (or (not x) y), (or x z), (or y z) this returns x for
 * tb = y, fb = z.
 *
 * @return the first such x in set order, or the null node when none exists.
 */
Node ArithIteUtils::findIteCnd(TNode tb, TNode fb) const{
  Node notTb = tb.negate();
  Node notFb = fb.negate();
  ImpMap::const_iterator fromNotTb = d_implies.find(notTb);
  ImpMap::const_iterator fromNotFb = d_implies.find(notFb);
  if(fromNotTb == d_implies.end() || fromNotFb == d_implies.end()){
    return Node::null();
  }

  const std::set<Node>& tbConsequences = fromNotTb->second;
  const std::set<Node>& fbConsequences = fromNotFb->second;
  for(std::set<Node>::const_iterator it = tbConsequences.begin();
      it != tbConsequences.end(); ++it){
    Node candidate = it->negate();
    if(fbConsequences.count(candidate) != 0){
      return candidate; // candidate implies tb
    }
  }
  return Node::null();
}
379
380
/**
 * Tries to eliminate a variable using a disjunction (or (= a b) (= c d)) of
 * integer equalities.
 *
 * Current substitutions are applied first (and the result rewritten). If it
 * is no longer a binary OR of equalities, the call gives up. Otherwise it
 * searches for a side shared by both equalities, checking l[0]/r[0],
 * l[0]/r[1], l[1]/r[0], l[1]/r[1] in that order. If that side is a non-skolem
 * variable `sel`, and the two other sides differ by a constant polynomial,
 * it records sel -> (ite sk otherL otherR) for a fresh Boolean skolem sk.
 *
 * @return true iff a substitution was added.
 */
bool ArithIteUtils::solveBinOr(TNode binor){
  Assert(binor.getKind() == kind::OR);
  Assert(binor.getNumChildren() == 2);
  Assert(binor[0].getKind() == kind::EQUAL);
  Assert(binor[1].getKind() == kind::EQUAL);

  Node current = applySubstitutions(binor);
  if(current != binor){
    current = Rewriter::rewrite(current);
    const bool stillBinaryEqOr = current.getKind() == kind::OR
                                 && current.getNumChildren() == 2
                                 && current[0].getKind() == kind::EQUAL
                                 && current[1].getKind() == kind::EQUAL;
    if(!stillBinaryEqOr){
      return false;
    }
  }

  Assert(current.getKind() == kind::OR);
  Assert(current.getNumChildren() == 2);
  TNode lhsEq = current[0];
  TNode rhsEq = current[1];
  Assert(lhsEq.getKind() == kind::EQUAL);
  Assert(rhsEq.getKind() == kind::EQUAL);

  Debug("arith::ite") << "bin or " << current << endl;

  const bool lhsIsIntEq =
      lhsEq.getKind() == kind::EQUAL && lhsEq[0].getType().isInteger();
  const bool rhsIsIntEq =
      rhsEq.getKind() == kind::EQUAL && rhsEq[0].getType().isInteger();
  if(!(lhsIsIntEq && rhsIsIntEq)){
    return false;
  }

  // Find a side shared by both equalities, trying (0,0),(0,1),(1,0),(1,1).
  TNode shared = Node::null();
  TNode otherL = Node::null();
  TNode otherR = Node::null();
  bool found = false;
  for(unsigned li = 0; li < 2 && !found; ++li){
    for(unsigned ri = 0; ri < 2 && !found; ++ri){
      if(lhsEq[li] == rhsEq[ri]){
        shared = lhsEq[li];
        otherL = lhsEq[1 - li];
        otherR = rhsEq[1 - ri];
        found = true;
      }
    }
  }
  Debug("arith::ite") << "selected " << shared << endl;

  if(!shared.isVar() || shared.getKind() == kind::SKOLEM){
    return false;
  }

  Debug("arith::ite") << "others l:" << otherL << " r " << otherR << endl;
  Node cmpL = selectForCmp(otherL);
  Node cmpR = selectForCmp(otherR);

  Assert(Polynomial::isMember(shared));
  Assert(Polynomial::isMember(cmpL));
  Assert(Polynomial::isMember(cmpR));
  Polynomial polyL = Polynomial::parsePolynomial(cmpL);
  Polynomial polyR = Polynomial::parsePolynomial(cmpR);
  Polynomial difference = polyL - polyR;

  Debug("arith::ite") << "diff: " << difference.getNode() << endl;
  if(!difference.isConstant()){
    return false;
  }

  // (sel = otherL) or (sel = otherR), with otherL - otherR constant:
  // substitute sel by a fresh skolem-guarded ITE over the two candidates.
  NodeManager* nm = NodeManager::currentNM();
  Node cnd = findIteCnd(binor[0], binor[1]);
  Node sk = nm->mkSkolem("deor", nm->booleanType());
  Node ite = sk.iteNode(otherL, otherR);
  d_skolems.insert(sk, cnd);
  addSubstitution(shared, ite);
  return true;
}
457
458
459
}/* CVC4::theory::arith namespace */
460
}/* CVC4::theory namespace */
461
26676
}/* CVC4 namespace */