GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_ite_utils.cpp Lines: 1 276 0.4 %
Date: 2021-09-07 Branches: 2 1330 0.2 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Tim King, Aina Niemetz, Piotr Trojanek
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Simplification utilities for arithmetic ITE terms.
14
 *
15
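 * Helpers that split arithmetic ITEs into shared variable parts plus ITEs
 * over constants, factor a common GCD out of constant ITE trees, and learn
 * substitutions from binary disjunctions of integer equalities.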
16
 * \todo document this file
17
 */
18
19
#include "theory/arith/arith_ite_utils.h"
20
21
#include <ostream>
22
23
#include "base/output.h"
24
#include "expr/skolem_manager.h"
25
#include "options/base_options.h"
26
#include "options/smt_options.h"
27
#include "preprocessing/util/ite_utilities.h"
28
#include "theory/arith/arith_utilities.h"
29
#include "theory/arith/normal_form.h"
30
#include "theory/rewriter.h"
31
#include "theory/substitutions.h"
32
#include "theory/theory_model.h"
33
34
using namespace std;
35
36
namespace cvc5 {
37
namespace theory {
38
namespace arith {
39
40
/**
 * Rebuilds `n` with each child replaced by reduceVariablesInItes(child).
 *
 * The kind of `n` (and, for parameterized kinds, its operator) is kept;
 * only the children change.
 */
Node ArithIteUtils::applyReduceVariablesInItes(Node n)
{
  NodeBuilder builder(n.getKind());
  const bool hasOperator = n.getMetaKind() == kind::metakind::PARAMETERIZED;
  if (hasOperator)
  {
    builder << n.getOperator();
  }
  for (const Node& child : n)
  {
    builder << reduceVariablesInItes(child);
  }
  Node rebuilt = builder;
  return rebuilt;
}
51
52
/**
 * Pushes shared variable parts out of arithmetic ites.
 *
 * An arithmetic (ite c t e) whose branches have the same non-null variable
 * part v is rewritten to (+ v (ite c' kt ke)). Here kt and ke are the
 * branches' constant parts and c' is the reduced condition. For example,
 * (ite c (+ k1 v) (+ k2 v)) becomes (+ v (ite c' k1 k2)).
 *
 * Tables maintained here, keyed by the original term:
 *  - d_reduceVar: memoized result. A null entry means "unchanged"; only the
 *    non-arithmetic ite case ever stores null.
 *  - d_constants / d_varParts: constant part and variable part, filled for
 *    arithmetic ites and polynomial terms.
 *
 * Some results are deliberately not memoized in d_reduceVar: constant
 * polynomials, and terms that need no rebuilding.
 */
Node ArithIteUtils::reduceVariablesInItes(Node n)
{
  auto memo = d_reduceVar.find(n);
  if (memo != d_reduceVar.end())
  {
    Node stored = memo->second;
    return stored.isNull() ? n : stored;
  }

  if (n.getKind() == kind::ITE)
  {
    if (!n.getType().isReal())
    {
      // Non-arithmetic ite: only descend when it contains a term ite.
      if (!d_contains.containsTermITE(n))
      {
        return n;  // not memoized
      }
      Node rebuilt = applyReduceVariablesInItes(n);
      d_reduceVar[n] = (rebuilt == n) ? Node::null() : rebuilt;
      return rebuilt;
    }

    Node cond = n[0];
    Node thenTerm = n[1];
    Node elseTerm = n[2];
    // Reduce condition and both branches first. This may populate
    // d_varParts / d_constants for the branches.
    Node reducedCond = reduceVariablesInItes(cond);
    Node reducedThen = reduceVariablesInItes(thenTerm);
    Node reducedElse = reduceVariablesInItes(elseTerm);

    // operator[] is used on purpose: a missing entry reads as null.
    Node thenVar = d_varParts[thenTerm];
    Node elseVar = d_varParts[elseTerm];
    if (thenVar.isNull() || thenVar != elseVar)
    {
      // No common variable part: keep the reduced ite and treat it as an
      // opaque variable whose constant part is 0.
      Node opaque = reducedCond.iteNode(reducedThen, reducedElse);
      d_reduceVar[n] = opaque;
      d_constants[n] = mkRationalNode(Rational(0));
      d_varParts[n] = opaque;
      return opaque;
    }

    // Shared variable part: factor it out and keep an ite over constants.
    Node constIte =
        reducedCond.iteNode(d_constants[thenTerm], d_constants[elseTerm]);
    Node split =
        NodeManager::currentNM()->mkNode(kind::PLUS, thenVar, constIte);
    d_reduceVar[n] = split;
    d_constants[n] = constIte;
    d_varParts[n] = thenVar;
    return split;
  }

  const bool isArithPolynomial =
      n.getType().isReal() && Polynomial::isMember(n);
  if (!isArithPolynomial)
  {
    // Other terms: rebuild only non-leaves that contain a term ite.
    if (!d_contains.containsTermITE(n) || n.getNumChildren() == 0)
    {
      return n;  // not memoized
    }
    Node rebuilt = applyReduceVariablesInItes(n);
    d_reduceVar[n] = rebuilt;
    return rebuilt;
  }

  // Polynomial: reduce the children when needed and renormalize, then split
  // the result into a constant part and a variable part.
  Node normal = n;
  if (d_contains.containsTermITE(n) && n.getNumChildren() > 0)
  {
    normal = Rewriter::rewrite(applyReduceVariablesInItes(n));
    Assert(Polynomial::isMember(normal));
  }

  Polynomial poly = Polynomial::parsePolynomial(normal);
  if (poly.isConstant())
  {
    // Pure constant: the variable part is 0. Not memoized in d_reduceVar.
    d_constants[n] = normal;
    d_varParts[n] = mkRationalNode(Rational(0));
    return normal;
  }
  if (!poly.containsConstant())
  {
    // No constant monomial: the constant part is 0.
    d_constants[n] = mkRationalNode(Rational(0));
    d_varParts[n] = normal;
    d_reduceVar[n] = poly.getNode();
    return poly.getNode();
  }
  // Both parts present. The constant comes from the head monomial's
  // coefficient; the variable part is the tail polynomial.
  Monomial head = poly.getHead();
  d_constants[n] = head.getConstant().getNode();
  d_varParts[n] = poly.getTail().getNode();
  d_reduceVar[n] = normal;
  return normal;
}
145
146
/**
 * @param contains visitor queried (via containsTermITE) to decide whether a
 *        term has a term ite below it
 * @param uc context used to construct the context-dependent members
 *        d_subcount and d_skolems
 * @param subs substitution map that addSubstitution() extends and
 *        applySubstitutions() applies
 */
ArithIteUtils::ArithIteUtils(
    preprocessing::util::ContainsTermITEVisitor& contains,
    context::Context* uc,
    SubstitutionMap& subs)
    : d_contains(contains),
      d_subs(subs),
      d_one(1),
      d_subcount(uc, 0),
      d_skolems(uc)
{
  // d_implies and d_orBinEqs are default-constructed (empty).
}
159
160
/** No work of its own; member destructors run as usual. */
ArithIteUtils::~ArithIteUtils() = default;
162
163
void ArithIteUtils::clear(){
164
  d_reduceVar.clear();
165
  d_constants.clear();
166
  d_varParts.clear();
167
}
168
169
/**
 * Returns the gcd of the integer leaves of a constant arithmetic ite tree.
 *
 * - An integral constant yields its numerator (memoized).
 * - A non-integral constant, or any other non-ite term, yields d_one
 *   (not memoized).
 * - An arithmetic ite yields gcd(then, else). The else branch is skipped
 *   once the then branch already gives 1 (memoized).
 *
 * The returned reference points into d_gcds or at d_one.
 */
const Integer& ArithIteUtils::gcdIte(Node n)
{
  auto memo = d_gcds.find(n);
  if (memo != d_gcds.end())
  {
    return memo->second;
  }

  if (n.getKind() == kind::CONST_RATIONAL)
  {
    const Rational& value = n.getConst<Rational>();
    if (!value.isIntegral())
    {
      return d_one;
    }
    return d_gcds.insert_or_assign(n, value.getNumerator()).first->second;
  }

  if (n.getKind() != kind::ITE || !n.getType().isReal())
  {
    return d_one;
  }

  const Integer& thenGcd = gcdIte(n[1]);
  if (thenGcd.isOne())
  {
    d_gcds.insert_or_assign(n, d_one);
    return d_one;
  }
  const Integer& elseGcd = gcdIte(n[2]);
  return d_gcds.insert_or_assign(n, thenGcd.gcd(elseGcd)).first->second;
}
195
196
/**
 * Multiplies every constant leaf of an integer ite tree by `q`.
 *
 * Conditions are simplified with reduceConstantIteByGCD(). Leaves are
 * expected to be CONST_RATIONAL and inner nodes integer-typed ITEs.
 */
Node ArithIteUtils::reduceIteConstantIteByGCD_rec(Node n, const Rational& q)
{
  if (!n.isConst())
  {
    Assert(n.getKind() == kind::ITE);
    Assert(n.getType().isInteger());
    Node cond = reduceConstantIteByGCD(n[0]);
    Node scaledThen = reduceIteConstantIteByGCD_rec(n[1], q);
    Node scaledElse = reduceIteConstantIteByGCD_rec(n[2], q);
    return cond.iteNode(scaledThen, scaledElse);
  }
  Assert(n.getKind() == kind::CONST_RATIONAL);
  const Rational scaled = n.getConst<Rational>() * q;
  return mkRationalNode(scaled);
}
209
210
/**
 * Factors the common gcd of an arithmetic ite's constant leaves out.
 *
 * - gcd 1: only the condition is simplified; the branches are kept.
 * - gcd 0: the result is the constant 0.
 * - otherwise: the result is (* g (ite ... leaf/g ...)).
 *
 * The result is memoized in d_reduceGcd.
 */
Node ArithIteUtils::reduceIteConstantIteByGCD(Node n)
{
  Assert(n.getKind() == kind::ITE);
  Assert(n.getType().isReal());
  const Integer& gcd = gcdIte(n);

  Node reduced;
  if (gcd.isOne())
  {
    reduced = reduceConstantIteByGCD(n[0]).iteNode(n[1], n[2]);
  }
  else if (gcd.isZero())
  {
    reduced = mkRationalNode(Rational(0));
  }
  else
  {
    const Rational inverse(Integer(1), gcd);
    Node quotientIte = reduceIteConstantIteByGCD_rec(n, inverse);
    Node factor = mkRationalNode(Rational(gcd));
    reduced =
        NodeManager::currentNM()->mkNode(kind::MULT, factor, quotientIte);
  }
  d_reduceGcd[n] = reduced;
  return reduced;
}
231
232
/**
 * Applies reduceIteConstantIteByGCD() to every arithmetic ite inside `n`.
 *
 * Results for non-leaf terms are memoized in d_reduceGcd. When no child
 * changes, `n` itself is returned and memoized. Leaves are returned as-is
 * and are not memoized.
 */
Node ArithIteUtils::reduceConstantIteByGCD(Node n)
{
  auto memo = d_reduceGcd.find(n);
  if (memo != d_reduceGcd.end())
  {
    return memo->second;
  }
  if (n.getKind() == kind::ITE && n.getType().isReal())
  {
    return reduceIteConstantIteByGCD(n);
  }
  if (n.getNumChildren() == 0)
  {
    return n;
  }

  NodeBuilder builder(n.getKind());
  if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
  {
    builder << n.getOperator();
  }
  bool changed = false;
  for (const Node& child : n)
  {
    Node reducedChild = reduceConstantIteByGCD(child);
    if (!changed && reducedChild != child)
    {
      changed = true;
    }
    builder << reducedChild;
  }

  if (!changed)
  {
    d_reduceGcd[n] = n;
    return n;
  }
  Node rebuilt = builder;
  d_reduceGcd[n] = rebuilt;
  return rebuilt;
}
264
265
unsigned ArithIteUtils::getSubCount() const{
266
  return d_subcount;
267
}
268
269
/** Records the substitution f -> t in d_subs and bumps d_subcount. */
void ArithIteUtils::addSubstitution(TNode f, TNode t)
{
  Debug("arith::ite") << "adding " << f << " -> " << t << std::endl;
  const unsigned bumped = d_subcount.get() + 1;
  d_subcount = bumped;
  d_subs.addSubstitution(f, t);
}
274
275
/**
 * Applies every substitution learned so far to `f`.
 *
 * Incremental solving must be off; this is checked with AlwaysAssert.
 */
Node ArithIteUtils::applySubstitutions(TNode f)
{
  AlwaysAssert(!options::incrementalSolving());
  Node substituted = d_subs.apply(f);
  return substituted;
}
279
280
/**
 * Follows then-branches of ites whose condition is a key in d_skolems.
 *
 * Returns the first term reached that is not such an ite.
 */
Node ArithIteUtils::selectForCmp(Node n) const
{
  Node current = n;
  while (current.getKind() == kind::ITE
         && d_skolems.find(current[0]) != d_skolems.end())
  {
    current = current[1];
  }
  return current;
}
288
289
void ArithIteUtils::learnSubstitutions(const std::vector<Node>& assertions){
290
  AlwaysAssert(!options::incrementalSolving());
291
  for(size_t i=0, N=assertions.size(); i < N; ++i){
292
    collectAssertions(assertions[i]);
293
  }
294
  bool solvedSomething;
295
  do{
296
    solvedSomething = false;
297
    size_t readPos = 0, writePos = 0, N = d_orBinEqs.size();
298
    for(; readPos < N; readPos++){
299
      Node curr = d_orBinEqs[readPos];
300
      bool solved = solveBinOr(curr);
301
      if(solved){
302
        solvedSomething = true;
303
      }else{
304
        // didn't solve, push back
305
        d_orBinEqs[writePos] = curr;
306
        writePos++;
307
      }
308
    }
309
    Assert(writePos <= N);
310
    d_orBinEqs.resize(writePos);
311
  }while(solvedSomething);
312
313
  d_implies.clear();
314
  d_orBinEqs.clear();
315
}
316
317
/**
 * Records the clause (or x y) in d_implies as two implications:
 *   (not x) => y   and   (not y) => x
 * keyed by the antecedent.
 */
void ArithIteUtils::addImplications(Node x, Node y)
{
  auto record = [this](const Node& premise, const Node& conclusion) {
    d_implies[premise].insert(conclusion);
  };
  const Node notX = x.negate();
  const Node notY = y.negate();
  record(notX, y);
  record(notY, x);
}
327
328
/**
 * Scans an assertion for binary clauses.
 *
 * Conjunctions are descended into. Every binary OR has its implications
 * recorded via addImplications(). A binary OR of two equalities whose
 * left-hand sides are integer-typed is also queued in d_orBinEqs for
 * solveBinOr(). Anything else is ignored.
 */
void ArithIteUtils::collectAssertions(TNode assertion)
{
  switch (assertion.getKind())
  {
    case kind::AND:
      for (TNode conjunct : assertion)
      {
        collectAssertions(conjunct);
      }
      break;
    case kind::OR:
    {
      if (assertion.getNumChildren() != 2)
      {
        break;
      }
      TNode lhs = assertion[0];
      TNode rhs = assertion[1];
      addImplications(lhs, rhs);
      const bool bothEqualities =
          lhs.getKind() == kind::EQUAL && rhs.getKind() == kind::EQUAL;
      if (bothEqualities && lhs[0].getType().isInteger()
          && rhs[0].getType().isInteger())
      {
        d_orBinEqs.push_back(assertion);
      }
      break;
    }
    default: break;
  }
}
345
346
/**
 * Looks for a condition x such that (ite x tb fb) follows from d_implies.
 *
 * For each z with (not tb) => z, i.e. the clause (or tb z), it checks
 * whether (not fb) => (not z), i.e. (or fb (not z)). If so, with
 * x = (not z) we have x => tb and (not x) => fb, and x is returned.
 *
 * Returns the null node when no such condition is recorded.
 */
Node ArithIteUtils::findIteCnd(TNode tb, TNode fb) const
{
  Node notTb = tb.negate();
  Node notFb = fb.negate();
  ImpMap::const_iterator fromNotTb = d_implies.find(notTb);
  ImpMap::const_iterator fromNotFb = d_implies.find(notFb);
  if (fromNotTb == d_implies.end() || fromNotFb == d_implies.end())
  {
    return Node::null();
  }

  const std::set<Node>& consequencesOfNotFb = fromNotFb->second;
  for (const Node& z : fromNotTb->second)
  {
    Node candidate = z.negate();
    if (consequencesOfNotFb.count(candidate) > 0)
    {
      return candidate;
    }
  }
  return Node::null();
}
377
378
/**
 * Tries to eliminate a variable using a clause (or (= a b) (= c d)).
 *
 * First the current substitutions are applied. If that changes the clause,
 * it is rewritten and must still be a binary OR of equalities. Both sides
 * must be integer equalities sharing a term `sel`. `sel` must be a variable
 * that is not a SKOLEM. The two other sides, as seen through
 * selectForCmp(), must differ by a constant. When all of this holds, `sel`
 * is replaced by (ite sk otherL otherR) for a fresh Boolean skolem `sk`.
 *
 * @return true iff a substitution was added.
 */
bool ArithIteUtils::solveBinOr(TNode binor)
{
  Assert(binor.getKind() == kind::OR);
  Assert(binor.getNumChildren() == 2);
  Assert(binor[0].getKind() == kind::EQUAL);
  Assert(binor[1].getKind() == kind::EQUAL);

  Node n = applySubstitutions(binor);
  if (n != binor)
  {
    n = Rewriter::rewrite(n);
    const bool stillBinaryEqOr = n.getKind() == kind::OR
                                 && n.getNumChildren() == 2
                                 && n[0].getKind() == kind::EQUAL
                                 && n[1].getKind() == kind::EQUAL;
    if (!stillBinaryEqOr)
    {
      return false;
    }
  }

  Assert(n.getKind() == kind::OR);
  Assert(n.getNumChildren() == 2);
  TNode lhsEq = n[0];
  TNode rhsEq = n[1];
  Assert(lhsEq.getKind() == kind::EQUAL);
  Assert(rhsEq.getKind() == kind::EQUAL);

  Debug("arith::ite") << "bin or " << n << std::endl;

  const bool lhsIsIntEq =
      lhsEq.getKind() == kind::EQUAL && lhsEq[0].getType().isInteger();
  const bool rhsIsIntEq =
      rhsEq.getKind() == kind::EQUAL && rhsEq[0].getType().isInteger();
  if (!(lhsIsIntEq && rhsIsIntEq))
  {
    return false;
  }

  // Find a term shared by both equalities, trying the pairs (0,0), (0,1),
  // (1,0), (1,1) in that order: (sel = otherL) or (sel = otherR).
  TNode sel = Node::null();
  TNode otherL = Node::null();
  TNode otherR = Node::null();
  for (int i = 0; i < 2 && sel.isNull(); ++i)
  {
    for (int j = 0; j < 2 && sel.isNull(); ++j)
    {
      if (lhsEq[i] == rhsEq[j])
      {
        sel = lhsEq[i];
        otherL = lhsEq[1 - i];
        otherR = rhsEq[1 - j];
      }
    }
  }
  Debug("arith::ite") << "selected " << sel << std::endl;
  if (!sel.isVar() || sel.getKind() == kind::SKOLEM)
  {
    return false;
  }

  Debug("arith::ite") << "others l:" << otherL << " r " << otherR << std::endl;
  Node cmpL = selectForCmp(otherL);
  Node cmpR = selectForCmp(otherR);

  Assert(Polynomial::isMember(sel));
  Assert(Polynomial::isMember(cmpL));
  Assert(Polynomial::isMember(cmpR));
  // Kept as separate statements so both sides are parsed in a fixed order.
  Polynomial polyL = Polynomial::parsePolynomial(cmpL);
  Polynomial polyR = Polynomial::parsePolynomial(cmpR);
  Polynomial diff = polyL - polyR;

  Debug("arith::ite") << "diff: " << diff.getNode() << std::endl;
  if (!diff.isConstant())
  {
    return false;
  }

  // The two candidate values differ by a constant. Introduce a fresh
  // Boolean skolem choosing between them, and record it in d_skolems
  // together with a (possibly null) condition from findIteCnd().
  // NOTE(review): findIteCnd is given the original clause `binor`, not the
  // substituted/rewritten `n`.
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();
  Node cnd = findIteCnd(binor[0], binor[1]);
  Node sk = sm->mkDummySkolem("deor", nm->booleanType());
  Node ite = sk.iteNode(otherL, otherR);
  d_skolems.insert(sk, cnd);
  addSubstitution(sel, ite);
  return true;
}
456
457
}  // namespace arith
458
}  // namespace theory
459
29502
}  // namespace cvc5