GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_ite_utils.cpp Lines: 1 279 0.4 %
Date: 2021-05-22 Branches: 2 1342 0.1 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Tim King, Aina Niemetz, Piotr Trojanek
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Utilities for simplifying arithmetic ITE (if-then-else) terms.
14
 *
15
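 * Provides helpers that split real-valued ITE terms into a shared variable
 * part plus an ITE over constants, factor the GCD out of ITE trees whose
 * leaves are integral constants, and learn substitutions from binary
 * disjunctions of integer equalities.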
16
 * \todo document this file
17
 */
18
19
#include "theory/arith/arith_ite_utils.h"
20
21
#include <ostream>
22
23
#include "base/output.h"
24
#include "expr/skolem_manager.h"
25
#include "options/smt_options.h"
26
#include "preprocessing/util/ite_utilities.h"
27
#include "theory/arith/arith_utilities.h"
28
#include "theory/arith/normal_form.h"
29
#include "theory/rewriter.h"
30
#include "theory/substitutions.h"
31
#include "theory/theory_model.h"
32
33
using namespace std;
34
35
namespace cvc5 {
36
namespace theory {
37
namespace arith {
38
39
/**
 * Rebuilds n with the same kind (and the same operator, when the kind is
 * parameterized), replacing every child by reduceVariablesInItes(child).
 */
Node ArithIteUtils::applyReduceVariablesInItes(Node n)
{
  NodeBuilder builder(n.getKind());
  if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
  {
    builder << n.getOperator();
  }
  for (const Node& child : n)
  {
    builder << reduceVariablesInItes(child);
  }
  Node rebuilt = builder;
  return rebuilt;
}
50
51
/**
 * Rewrites real-valued ITEs whose branches share the same variable part into
 * "variable part + ITE over the branch constants".
 *
 * Side tables filled in along the way:
 *  - d_reduceVar[n]: memoized result (a null entry means "n reduces to n");
 *  - d_constants[n]: the constant part recorded for n;
 *  - d_varParts[n]:  the variable part recorded for n.
 * Constant polynomials and non-arithmetic terms without ITEs are returned
 * without being memoized in d_reduceVar.
 */
Node ArithIteUtils::reduceVariablesInItes(Node n)
{
  // Memoized answer; a null entry records that n reduces to itself.
  if (d_reduceVar.find(n) != d_reduceVar.end())
  {
    Node memo = d_reduceVar[n];
    if (memo.isNull())
    {
      return n;
    }
    return memo;
  }

  const bool isRealTerm = n.getType().isReal();

  if (n.getKind() == kind::ITE)
  {
    if (!isRealTerm)
    {
      // Non-arithmetic ITE: only rebuild it when some ITE occurs inside it.
      if (!d_contains.containsTermITE(n))
      {
        // don't bother adding to d_reduceVar
        return n;
      }
      Node rebuiltIte = applyReduceVariablesInItes(n);
      d_reduceVar[n] = (rebuiltIte == n) ? Node::null() : rebuiltIte;
      return rebuiltIte;
    }

    Node cond = n[0];
    Node thenTerm = n[1];
    Node elseTerm = n[2];
    Node reducedCond = reduceVariablesInItes(cond);
    Node reducedThen = reduceVariablesInItes(thenTerm);
    Node reducedElse = reduceVariablesInItes(elseTerm);

    // The split is only possible when both branches recorded the same,
    // non-null variable part.
    Node thenVars = d_varParts[thenTerm];
    Node elseVars = d_varParts[elseTerm];
    if (thenVars.isNull() || thenVars != elseVars)
    {
      // do not apply: the reduced ITE itself is treated as a variable part
      Node opaqueIte = reducedCond.iteNode(reducedThen, reducedElse);
      d_reduceVar[n] = opaqueIte;
      d_constants[n] = mkRationalNode(Rational(0));
      d_varParts[n] = opaqueIte;
      return opaqueIte;
    }

    // Rewrite n as thenVars + ite(cond, constant(then), constant(else)).
    Node constIte =
        reducedCond.iteNode(d_constants[thenTerm], d_constants[elseTerm]);
    Node sum =
        NodeManager::currentNM()->mkNode(kind::PLUS, thenVars, constIte);
    d_reduceVar[n] = sum;
    d_constants[n] = constIte;
    d_varParts[n] = thenVars;
    return sum;
  }

  if (isRealTerm && Polynomial::isMember(n))
  {
    // Reduce the children first when an ITE occurs below a non-leaf term,
    // then renormalize so the result parses as a polynomial again.
    Node reduced = n;
    if (d_contains.containsTermITE(n) && n.getNumChildren() > 0)
    {
      reduced = Rewriter::rewrite(applyReduceVariablesInItes(n));
      Assert(Polynomial::isMember(reduced));
    }

    Polynomial poly = Polynomial::parsePolynomial(reduced);
    if (poly.isConstant())
    {
      // Pure constant: don't bother adding to d_reduceVar.
      d_constants[n] = reduced;
      d_varParts[n] = mkRationalNode(Rational(0));
      return reduced;
    }
    if (!poly.containsConstant())
    {
      // No constant term: the whole polynomial is the variable part.
      d_constants[n] = mkRationalNode(Rational(0));
      d_varParts[n] = reduced;
      d_reduceVar[n] = poly.getNode();
      return poly.getNode();
    }
    // Mixed: the head monomial's coefficient is taken as the constant part
    // and the tail as the variable part.
    Monomial head = poly.getHead();
    d_constants[n] = head.getConstant().getNode();
    d_varParts[n] = poly.getTail().getNode();
    d_reduceVar[n] = reduced;
    return reduced;
  }

  // Any other term: rebuild only non-leaf terms that contain an ITE.
  if (!d_contains.containsTermITE(n) || n.getNumChildren() == 0)
  {
    return n;
  }
  Node rebuiltTerm = applyReduceVariablesInItes(n);
  d_reduceVar[n] = rebuiltTerm;
  return rebuiltTerm;
}
144
145
/**
 * Binds the ITE-containment visitor and the model. The substitution map, the
 * substitution counter and the skolem map are attached to context uc.
 */
ArithIteUtils::ArithIteUtils(
    preprocessing::util::ContainsTermITEVisitor& contains,
    context::Context* uc,
    TheoryModel* model)
    : d_contains(contains),
      d_subs(nullptr),
      d_model(model),
      d_one(1),
      d_subcount(uc, 0),
      d_skolems(uc),
      d_implies(),
      d_orBinEqs()
{
  // Owned by this object; released in the destructor.
  d_subs = new SubstitutionMap(uc);
}
160
161
/** Releases the substitution map allocated by the constructor. */
ArithIteUtils::~ArithIteUtils()
{
  delete d_subs;
  d_subs = nullptr;
}
165
166
void ArithIteUtils::clear(){
167
  d_reduceVar.clear();
168
  d_constants.clear();
169
  d_varParts.clear();
170
}
171
172
/**
 * Computes the GCD of the constant leaves of a real-valued ITE tree.
 *
 * - An integral constant yields its numerator (cached in d_gcds).
 * - A non-integral constant yields 1 (not cached).
 * - A real ITE yields 1 as soon as its then-branch does; otherwise the GCD of
 *   both branches (cached in d_gcds).
 * - Anything else yields 1.
 *
 * @return a reference into d_gcds or to d_one.
 */
const Integer& ArithIteUtils::gcdIte(Node n)
{
  if (d_gcds.find(n) != d_gcds.end())
  {
    return d_gcds[n];
  }

  if (n.getKind() == kind::CONST_RATIONAL)
  {
    const Rational& value = n.getConst<Rational>();
    if (!value.isIntegral())
    {
      return d_one;
    }
    d_gcds[n] = value.getNumerator();
    return d_gcds[n];
  }

  if (n.getKind() != kind::ITE || !n.getType().isReal())
  {
    return d_one;
  }

  const Integer& thenGcd = gcdIte(n[1]);
  if (thenGcd.isOne())
  {
    // Short-circuit: the else-branch cannot change a GCD of 1.
    d_gcds[n] = d_one;
    return d_one;
  }
  const Integer& elseGcd = gcdIte(n[2]);
  const Integer combined = thenGcd.gcd(elseGcd);
  d_gcds[n] = combined;
  return d_gcds[n];
}
198
199
/**
 * Multiplies every constant leaf of an integer ITE tree by q. ITE conditions
 * are passed through reduceConstantIteByGCD.
 */
Node ArithIteUtils::reduceIteConstantIteByGCD_rec(Node n, const Rational& q)
{
  if (!n.isConst())
  {
    Assert(n.getKind() == kind::ITE);
    Assert(n.getType().isInteger());
    Node reducedCond = reduceConstantIteByGCD(n[0]);
    Node scaledThen = reduceIteConstantIteByGCD_rec(n[1], q);
    Node scaledElse = reduceIteConstantIteByGCD_rec(n[2], q);
    return reducedCond.iteNode(scaledThen, scaledElse);
  }

  Assert(n.getKind() == kind::CONST_RATIONAL);
  const Rational scaled = n.getConst<Rational>() * q;
  return mkRationalNode(scaled);
}
212
213
/**
 * Factors the GCD g of the constant leaves out of the real ITE n:
 *  - g == 0: the result is the constant 0;
 *  - g == 1: only the condition is reduced; the branches are kept;
 *  - otherwise: g * n', where n' is n with every leaf divided by g.
 * The result is memoized in d_reduceGcd.
 */
Node ArithIteUtils::reduceIteConstantIteByGCD(Node n)
{
  Assert(n.getKind() == kind::ITE);
  Assert(n.getType().isReal());

  const Integer& g = gcdIte(n);
  Node result;
  if (g.isZero())
  {
    result = mkRationalNode(Rational(0));
  }
  else if (g.isOne())
  {
    result = reduceConstantIteByGCD(n[0]).iteNode(n[1], n[2]);
  }
  else
  {
    const Rational inverse(Integer(1), g);
    Node scaledIte = reduceIteConstantIteByGCD_rec(n, inverse);
    Node factor = mkRationalNode(Rational(g));
    result = NodeManager::currentNM()->mkNode(kind::MULT, factor, scaledIte);
  }
  d_reduceGcd[n] = result;
  return result;
}
234
235
/**
 * Applies GCD factoring (reduceIteConstantIteByGCD) to the real ITEs inside n.
 * Non-leaf results are memoized in d_reduceGcd; leaves are returned unchanged
 * and are not memoized.
 */
Node ArithIteUtils::reduceConstantIteByGCD(Node n)
{
  if (d_reduceGcd.find(n) != d_reduceGcd.end())
  {
    return d_reduceGcd[n];
  }
  if (n.getKind() == kind::ITE && n.getType().isReal())
  {
    return reduceIteConstantIteByGCD(n);
  }
  if (n.getNumChildren() == 0)
  {
    return n;
  }

  NodeBuilder builder(n.getKind());
  if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
  {
    builder << n.getOperator();
  }
  bool changed = false;
  for (const Node& child : n)
  {
    Node reducedChild = reduceConstantIteByGCD(child);
    if (reducedChild != child)
    {
      changed = true;
    }
    builder << reducedChild;
  }

  // Only materialize a new node when some child actually changed.
  if (!changed)
  {
    d_reduceGcd[n] = n;
    return n;
  }
  Node rebuilt = builder;
  d_reduceGcd[n] = rebuilt;
  return rebuilt;
}
267
268
unsigned ArithIteUtils::getSubCount() const{
269
  return d_subcount;
270
}
271
272
/** Records the substitution f -> t and bumps the substitution counter. */
void ArithIteUtils::addSubstitution(TNode f, TNode t)
{
  Debug("arith::ite") << "adding " << f << " -> " << t << endl;
  d_subcount = getSubCount() + 1;
  d_subs->addSubstitution(f, t);
}
277
278
/**
 * Applies every substitution learned so far to f. This is only permitted
 * outside incremental mode, which is enforced by an AlwaysAssert.
 */
Node ArithIteUtils::applySubstitutions(TNode f)
{
  AlwaysAssert(!options::incrementalSolving());
  Node substituted = d_subs->apply(f);
  return substituted;
}
282
283
/**
 * Follows the then-branch through every ITE whose condition is a key of
 * d_skolems and returns the first term that is not such an ITE.
 */
Node ArithIteUtils::selectForCmp(Node n) const
{
  Node current = n;
  while (current.getKind() == kind::ITE
         && d_skolems.find(current[0]) != d_skolems.end())
  {
    current = current[1];
  }
  return current;
}
291
292
void ArithIteUtils::learnSubstitutions(const std::vector<Node>& assertions){
293
  AlwaysAssert(!options::incrementalSolving());
294
  for(size_t i=0, N=assertions.size(); i < N; ++i){
295
    collectAssertions(assertions[i]);
296
  }
297
  bool solvedSomething;
298
  do{
299
    solvedSomething = false;
300
    size_t readPos = 0, writePos = 0, N = d_orBinEqs.size();
301
    for(; readPos < N; readPos++){
302
      Node curr = d_orBinEqs[readPos];
303
      bool solved = solveBinOr(curr);
304
      if(solved){
305
        solvedSomething = true;
306
      }else{
307
        // didn't solve, push back
308
        d_orBinEqs[writePos] = curr;
309
        writePos++;
310
      }
311
    }
312
    Assert(writePos <= N);
313
    d_orBinEqs.resize(writePos);
314
  }while(solvedSomething);
315
316
  d_implies.clear();
317
  d_orBinEqs.clear();
318
}
319
320
/**
 * Records the disjunction (or x y) in d_implies as its two implications:
 * (=> (not x) y) and (=> (not y) x).
 */
void ArithIteUtils::addImplications(Node x, Node y)
{
  Node notX = x.negate();
  Node notY = y.negate();
  d_implies[notX].insert(y);
  d_implies[notY].insert(x);
}
330
331
/**
 * Walks through nested conjunctions. For every binary disjunction it records
 * the implications. It queues the disjunction in d_orBinEqs when both
 * disjuncts are equalities whose first sides are integer-typed.
 */
void ArithIteUtils::collectAssertions(TNode assertion)
{
  if (assertion.getKind() == kind::AND)
  {
    for (TNode conjunct : assertion)
    {
      collectAssertions(conjunct);
    }
    return;
  }

  if (assertion.getKind() != kind::OR || assertion.getNumChildren() != 2)
  {
    return;
  }

  TNode lhs = assertion[0];
  TNode rhs = assertion[1];
  addImplications(lhs, rhs);
  const bool bothEqualities =
      lhs.getKind() == kind::EQUAL && rhs.getKind() == kind::EQUAL;
  if (bothEqualities && lhs[0].getType().isInteger()
      && rhs[0].getType().isInteger())
  {
    d_orBinEqs.push_back(assertion);
  }
}
348
349
/**
 * Looks for a literal x such that (ite x tb fb) is justified by the recorded
 * implications:
 *
 *   (or (not x) tb), (or x fb)   i.e.   (not tb) => (not x), (not fb) => x
 *
 * @return x, or the null node if no such literal was recorded.
 */
Node ArithIteUtils::findIteCnd(TNode tb, TNode fb) const
{
  const Node notTb = tb.negate();
  const Node notFb = fb.negate();
  ImpMap::const_iterator tbEntry = d_implies.find(notTb);
  ImpMap::const_iterator fbEntry = d_implies.find(notFb);
  if (tbEntry == d_implies.end() || fbEntry == d_implies.end())
  {
    return Node::null();
  }

  const std::set<Node>& impliedByNotTb = tbEntry->second;
  const std::set<Node>& impliedByNotFb = fbEntry->second;
  for (const Node& implied : impliedByNotTb)
  {
    // implied is some (not x); we need x itself to follow from (not fb).
    Node candidate = implied.negate();
    if (impliedByNotFb.find(candidate) != impliedByNotFb.end())
    {
      return candidate;  // implies tb
    }
  }
  return Node::null();
}
380
381
/**
 * Tries to turn (or (= sel a) (= sel b)) over integers into the substitution
 * sel -> (ite sk a b), where sk is a fresh Boolean skolem.
 *
 * Requirements checked here:
 *  - after the current substitutions (and rewriting, if anything changed) the
 *    term is still a binary OR of equalities over integers;
 *  - both equalities share a side sel that is a non-skolem variable;
 *  - the other sides (seen through selectForCmp) differ by a constant.
 * The skolem is mapped to the condition found by findIteCnd on the original
 * disjuncts (possibly null).
 *
 * @return true iff a substitution was added.
 */
bool ArithIteUtils::solveBinOr(TNode binor)
{
  Assert(binor.getKind() == kind::OR);
  Assert(binor.getNumChildren() == 2);
  Assert(binor[0].getKind() == kind::EQUAL);
  Assert(binor[1].getKind() == kind::EQUAL);

  Node n = applySubstitutions(binor);
  if (n != binor)
  {
    n = Rewriter::rewrite(n);
    const bool stillBinEqOr = n.getKind() == kind::OR
                              && n.getNumChildren() == 2
                              && n[0].getKind() == kind::EQUAL
                              && n[1].getKind() == kind::EQUAL;
    if (!stillBinEqOr)
    {
      return false;
    }
  }

  Assert(n.getKind() == kind::OR);
  Assert(n.getNumChildren() == 2);
  TNode l = n[0];
  TNode r = n[1];
  Assert(l.getKind() == kind::EQUAL);
  Assert(r.getKind() == kind::EQUAL);

  Debug("arith::ite") << "bin or " << n << endl;

  const bool lArithEq = l.getKind() == kind::EQUAL && l[0].getType().isInteger();
  const bool rArithEq = r.getKind() == kind::EQUAL && r[0].getType().isInteger();
  if (!lArithEq || !rArithEq)
  {
    return false;
  }

  // Find a side shared by both equalities, trying the pairings in the order
  // (l0,r0), (l0,r1), (l1,r0), (l1,r1); the first match wins.
  TNode sel = TNode::null();
  TNode otherL = TNode::null();
  TNode otherR = TNode::null();
  for (int li = 0; li < 2 && sel.isNull(); ++li)
  {
    for (int ri = 0; ri < 2 && sel.isNull(); ++ri)
    {
      if (l[li] == r[ri])
      {
        sel = l[li];
        otherL = l[1 - li];
        otherR = r[1 - ri];
      }
    }
  }

  Debug("arith::ite") << "selected " << sel << endl;
  if (!sel.isVar() || sel.getKind() == kind::SKOLEM)
  {
    return false;
  }

  Debug("arith::ite") << "others l:" << otherL << " r " << otherR << endl;
  Node useForCmpL = selectForCmp(otherL);
  Node useForCmpR = selectForCmp(otherR);

  Assert(Polynomial::isMember(sel));
  Assert(Polynomial::isMember(useForCmpL));
  Assert(Polynomial::isMember(useForCmpR));
  Polynomial lside = Polynomial::parsePolynomial(useForCmpL);
  Polynomial rside = Polynomial::parsePolynomial(useForCmpR);
  Polynomial diff = lside - rside;

  Debug("arith::ite") << "diff: " << diff.getNode() << endl;
  if (!diff.isConstant())
  {
    return false;
  }

  // a: (sel = otherL) or (sel = otherR), otherL-otherR = c
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();

  Node cnd = findIteCnd(binor[0], binor[1]);

  Node sk = sm->mkDummySkolem("deor", nm->booleanType());
  Node ite = sk.iteNode(otherL, otherR);
  d_skolems.insert(sk, cnd);
  addSubstitution(sel, ite);
  return true;
}
459
460
}  // namespace arith
461
}  // namespace theory
462
28191
}  // namespace cvc5