GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/extended_rewrite.cpp Lines: 763 841 90.7 %
Date: 2021-03-23 Branches: 1840 3800 48.4 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file extended_rewrite.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Andres Noetzli, Mathias Preiner
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Implementation of extended rewriting techniques
13
 **/
14
15
#include "theory/quantifiers/extended_rewrite.h"
16
17
#include "theory/arith/arith_msum.h"
18
#include "theory/bv/theory_bv_utils.h"
19
#include "theory/datatypes/datatypes_rewriter.h"
20
#include "theory/quantifiers/term_util.h"
21
#include "theory/rewriter.h"
22
#include "theory/strings/sequences_rewriter.h"
23
#include "theory/theory.h"
24
25
using namespace CVC4::kind;
26
using namespace std;
27
28
namespace CVC4 {
29
namespace theory {
30
namespace quantifiers {
31
32
struct ExtRewriteAttributeId
33
{
34
};
35
typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
36
37
struct ExtRewriteAggAttributeId
38
{
39
};
40
typedef expr::Attribute<ExtRewriteAggAttributeId, Node> ExtRewriteAggAttribute;
41
42
11305
/**
 * Construct an extended rewriter.
 *
 * @param aggr whether aggressive rewrites are enabled; this also selects which
 * cache attribute (aggressive or non-aggressive) results are stored in.
 */
ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr)
{
  // Keep the Boolean constants around; they are used by the rewrites below.
  NodeManager* nodeManager = NodeManager::currentNM();
  d_true = nodeManager->mkConst(true);
  d_false = nodeManager->mkConst(false);
}
47
48
293756
/**
 * Record ret as the extended rewrite of n in the cache of the current mode.
 * Aggressive and non-aggressive results are kept under distinct attributes,
 * so the two modes never see each other's entries.
 */
void ExtendedRewriter::setCache(Node n, Node ret)
{
  if (!d_aggr)
  {
    n.setAttribute(ExtRewriteAttribute(), ret);
    return;
  }
  n.setAttribute(ExtRewriteAggAttribute(), ret);
}
61
62
706413
/**
 * Look up the cached extended rewrite of n for the current mode.
 *
 * @return the cached result, or the null node if n has no entry yet.
 */
Node ExtendedRewriter::getCache(Node n)
{
  if (d_aggr)
  {
    return n.hasAttribute(ExtRewriteAggAttribute())
               ? n.getAttribute(ExtRewriteAggAttribute())
               : Node::null();
  }
  return n.hasAttribute(ExtRewriteAttribute())
             ? n.getAttribute(ExtRewriteAttribute())
             : Node::null();
}
80
81
292011
bool ExtendedRewriter::addToChildren(Node nc,
82
                                     std::vector<Node>& children,
83
                                     bool dropDup)
84
{
85
  // If the operator is non-additive, do not consider duplicates
86
292011
  if (dropDup
87
292011
      && std::find(children.begin(), children.end(), nc) != children.end())
88
  {
89
162
    return false;
90
  }
91
291849
  children.push_back(nc);
92
291849
  return true;
93
}
94
95
706413
/**
 * Return the extended rewritten form of n.
 *
 * n is first rewritten with Rewriter::rewrite and looked up in the cache of
 * the current mode. If it is not cached, the following steps run in order:
 * 1. (aggressive only) pre-rewrites: IMPLIES elimination, XOR elimination,
 *    and pushing a top-level NOT one level inward (extendedRewriteNnf). If one
 *    applies, its result is extended-rewritten, cached for n and returned.
 * 2. The children are extended-rewritten. Children with the same associative
 *    kind as n are flattened into n. For non-additive kinds, duplicate
 *    children are dropped. For commutative kinds, the children are sorted
 *    (always in aggressive mode, otherwise only with at most 5 children).
 *    The rebuilt term is then rewritten.
 * 3. Theory-independent post-rewrites: ITE, AND/OR and EQUAL handling,
 *    followed by ITE pulling for non-ITE terms.
 * 4. If nothing applied, theory-specific post-rewrites (strings only).
 * 5. If nothing applied and in aggressive mode, extendedRewriteAggr.
 * If a step in 3-5 produced a new term, that term is extended-rewritten
 * recursively. The final result is cached for n.
 */
Node ExtendedRewriter::extendedRewrite(Node n)
{
  n = Rewriter::rewrite(n);

  // has it already been computed?
  Node ncache = getCache(n);
  if (!ncache.isNull())
  {
    return ncache;
  }

  Node ret = n;
  NodeManager* nm = NodeManager::currentNM();

  //--------------------pre-rewrite
  if (d_aggr)
  {
    Node pre_new_ret;
    if (ret.getKind() == IMPLIES)
    {
      // a => b ---> ~a V b
      pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
      debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
    }
    else if (ret.getKind() == XOR)
    {
      // a xor b ---> ~a = b
      pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
      debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
    }
    else if (ret.getKind() == NOT)
    {
      // push the negation one level down (null if not applicable)
      pre_new_ret = extendedRewriteNnf(ret);
      debugExtendedRewrite(ret, pre_new_ret, "NNF");
    }
    if (!pre_new_ret.isNull())
    {
      ret = extendedRewrite(pre_new_ret);

      Trace("q-ext-rewrite-debug")
          << "...ext-pre-rewrite : " << n << " -> " << pre_new_ret << std::endl;
      setCache(n, ret);
      return ret;
    }
  }
  //--------------------end pre-rewrite

  //--------------------rewrite children
  if (n.getNumChildren() > 0)
  {
    std::vector<Node> children;
    if (n.getMetaKind() == metakind::PARAMETERIZED)
    {
      // parameterized kinds carry their operator as the first mkNode argument
      children.push_back(n.getOperator());
    }
    Kind k = n.getKind();
    bool childChanged = false;
    bool isNonAdditive = TermUtil::isNonAdditive(k);
    // We flatten associative operators below, which requires k to be n-ary.
    bool isAssoc = TermUtil::isAssoc(k, true);
    for (unsigned i = 0; i < n.getNumChildren(); i++)
    {
      Node nc = extendedRewrite(n[i]);
      childChanged = nc != n[i] || childChanged;
      if (isAssoc && nc.getKind() == n.getKind())
      {
        // flatten: splice the grandchildren directly into children
        for (const Node& ncc : nc)
        {
          if (!addToChildren(ncc, children, isNonAdditive))
          {
            childChanged = true;
          }
        }
      }
      else if (!addToChildren(nc, children, isNonAdditive))
      {
        // a duplicate was dropped, so the term must be rebuilt
        childChanged = true;
      }
    }
    Assert(!children.empty());
    // Some commutative operators have rewriters that are agnostic to order,
    // thus, we sort here.
    if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
    {
      childChanged = true;
      std::sort(children.begin(), children.end());
    }
    if (childChanged)
    {
      if (isNonAdditive && children.size() == 1)
      {
        // we may have subsumed children down to one
        ret = children[0];
      }
      else if (isAssoc
               && children.size() > kind::metakind::getMaxArityForKind(k))
      {
        Assert(kind::metakind::getMaxArityForKind(k) >= 2);
        // kind may require binary construction
        ret = children[0];
        for (unsigned i = 1, nchild = children.size(); i < nchild; i++)
        {
          ret = nm->mkNode(k, ret, children[i]);
        }
      }
      else
      {
        ret = nm->mkNode(k, children);
      }
    }
  }
  ret = Rewriter::rewrite(ret);
  //--------------------end rewrite children

  // now, do extended rewrite
  Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
                               << " (from " << n << ")" << std::endl;
  // the result of a post-rewrite below; null means "no rewrite applied"
  Node new_ret;

  //---------------------- theory-independent post-rewriting
  if (ret.getKind() == ITE)
  {
    new_ret = extendedRewriteIte(ITE, ret);
  }
  else if (ret.getKind() == AND || ret.getKind() == OR)
  {
    new_ret = extendedRewriteAndOr(ret);
  }
  else if (ret.getKind() == EQUAL)
  {
    new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
    debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
  }
  Assert(new_ret.isNull() || new_ret != ret);
  if (new_ret.isNull() && ret.getKind() != ITE)
  {
    // simple ITE pulling
    new_ret = extendedRewritePullIte(ITE, ret);
  }
  //----------------------end theory-independent post-rewriting

  //----------------------theory-specific post-rewriting
  if (new_ret.isNull())
  {
    TheoryId tid;
    if (ret.getKind() == ITE)
    {
      // an ITE is attributed to the theory of its type
      tid = Theory::theoryOf(ret.getType());
    }
    else
    {
      tid = Theory::theoryOf(ret);
    }
    Trace("q-ext-rewrite-debug") << "theoryOf( " << ret << " )= " << tid
                                 << std::endl;
    if (tid == THEORY_STRINGS)
    {
      new_ret = extendedRewriteStrings(ret);
    }
  }
  //----------------------end theory-specific post-rewriting

  //----------------------aggressive rewrites
  if (new_ret.isNull() && d_aggr)
  {
    new_ret = extendedRewriteAggr(ret);
  }
  //----------------------end aggressive rewrites

  // Cache the intermediate result for n before recursing on new_ret;
  // presumably so that a recursive call that reaches n again stops at the
  // cache instead of looping.
  setCache(n, ret);
  if (!new_ret.isNull())
  {
    ret = extendedRewrite(new_ret);
  }
  Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
                               << std::endl;
  if (Trace.isOn("q-ext-rewrite-nf"))
  {
    if (n == ret)
    {
      Trace("q-ext-rewrite-nf") << "ext-rew normal form : " << n << std::endl;
    }
  }
  // overwrite the intermediate entry with the final result
  setCache(n, ret);
  return ret;
}
279
280
130614
/**
 * Aggressive extended rewrites for n. In the code visible here, this is only
 * invoked from extendedRewrite when d_aggr is set.
 *
 * Handles arithmetic literals only. If n is a (possibly negated) GEQ, or an
 * EQUAL between real-typed terms, its monomial sum is computed. Each ITE
 * monomial is then isolated, and the first solved form that
 * extendedRewritePullIte can simplify is returned, re-negated if n was
 * negated. This performs ITE term removal in polynomials, e.g.
 *   ite( x=0, x, y ) = x+1 ---> ( ~x=0 ^ y = x+1 )
 *
 * @return the rewritten term, or the null node if nothing applies.
 */
Node ExtendedRewriter::extendedRewriteAggr(Node n)
{
  Trace("q-ext-rewrite-debug2")
      << "Do aggressive rewrites on " << n << std::endl;
  const bool isNegated = (n.getKind() == NOT);
  Node atom = isNegated ? n[0] : n;
  const bool isArithLiteral =
      (atom.getKind() == EQUAL && atom[0].getType().isReal())
      || atom.getKind() == GEQ;
  if (!isArithLiteral)
  {
    // TODO (#1706) : conditional rewriting, condition merging
    return Node::null();
  }

  Trace("q-ext-rewrite-debug2")
      << "Compute monomial sum " << atom << std::endl;
  std::map<Node, Node> msum;
  if (!ArithMSum::getMonomialSumLit(atom, msum))
  {
    Trace("q-ext-rewrite-debug")
        << "  failed to get monomial sum of " << n << std::endl;
    return Node::null();
  }

  for (const std::pair<const Node, Node>& monomial : msum)
  {
    Node v = monomial.first;
    Trace("q-ext-rewrite-debug2")
        << monomial.first << " * " << monomial.second << std::endl;
    if (v.getKind() != ITE)
    {
      continue;
    }
    // solve the literal for the ITE monomial v
    Node veq;
    int res = ArithMSum::isolate(v, msum, veq, atom.getKind());
    if (res == 0)
    {
      Trace("q-ext-rewrite-debug")
          << "  failed to isolate " << v << " in " << n << std::endl;
      continue;
    }
    Trace("q-ext-rewrite-debug")
        << "  have ITE relation, solved form : " << veq << std::endl;
    // try pulling ITE
    Node pulled = extendedRewritePullIte(ITE, veq);
    if (!pulled.isNull())
    {
      return isNegated ? pulled.negate() : pulled;
    }
  }
  return Node::null();
}
340
341
20060
/**
 * Extended rewrite for an ITE term n of kind itek.
 *
 * Rewrites are tried in the order below, and the first one that applies
 * wins:
 * - condition flip: ite( ~C, s, t ) ---> ite( C, t, s ). An OR condition is
 *   replaced by its simple negation, with the branches swapped.
 * - Boolean constant branches, e.g. ite( A, true, B ) ---> A V B.
 * - simple invariance: ite( x = y ^ C, y, x ) ---> x.
 * - branch merging: ite( C1, ite( C2, t1, t2 ), t1 ) ---> ite( C1 ^ ~C2, t2, t1 ).
 * - (aggressive only) rewrites based on substitutions inferred from the
 *   condition, and on substituting C -> false into the else branch.
 *
 * @param itek the ITE kind of n
 * @param n the ITE term; its branches must differ
 * @param full if false, debug traces are suppressed, and substitutions into a
 *   branch are only kept when they yield a constant or make the branches
 *   equal (see the comment on the duplicated term u below)
 * @return the rewritten term, or the null node if no rewrite applies
 */
Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
{
  Assert(n.getKind() == itek);
  Assert(n[1] != n[2]);

  NodeManager* nm = NodeManager::currentNM();

  Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;

  Node flip_cond;
  if (n[0].getKind() == NOT)
  {
    flip_cond = n[0][0];
  }
  else if (n[0].getKind() == OR)
  {
    // a | b ---> ~( ~a & ~b )
    flip_cond = TermUtil::simpleNegate(n[0]);
  }
  if (!flip_cond.isNull())
  {
    Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
    // only print debug trace if full=true
    if (full)
    {
      debugExtendedRewrite(n, new_ret, "ITE flip");
    }
    return new_ret;
  }
  // Boolean true/false return
  TypeNode tn = n.getType();
  if (tn.isBoolean())
  {
    for (unsigned i = 1; i <= 2; i++)
    {
      if (n[i].isConst())
      {
        Node cond = i == 1 ? n[0] : n[0].negate();
        Node other = n[i == 1 ? 2 : 1];
        Kind retk = AND;
        if (n[i].getConst<bool>())
        {
          retk = OR;
        }
        else
        {
          cond = cond.negate();
        }
        Node new_ret = nm->mkNode(retk, cond, other);
        if (full)
        {
          // ite( A, true, B ) ---> A V B
          // ite( A, false, B ) ---> ~A /\ B
          // ite( A, B,  true ) ---> ~A V B
          // ite( A, B, false ) ---> A /\ B
          debugExtendedRewrite(n, new_ret, "ITE const return");
        }
        return new_ret;
      }
    }
  }

  // get entailed equalities in the condition
  std::vector<Node> eq_conds;
  Kind ck = n[0].getKind();
  if (ck == EQUAL)
  {
    eq_conds.push_back(n[0]);
  }
  else if (ck == AND)
  {
    for (const Node& cn : n[0])
    {
      if (cn.getKind() == EQUAL)
      {
        eq_conds.push_back(cn);
      }
    }
  }

  Node new_ret;
  Node t1 = n[1];
  Node t2 = n[2];
  // human-readable name of the rewrite that produced new_ret (for tracing)
  std::stringstream ss_reason;

  for (const Node& eq : eq_conds)
  {
    // simple invariant ITE
    for (unsigned i = 0; i <= 1; i++)
    {
      // ite( x = y ^ C, y, x ) ---> x
      // this is subsumed by the rewrites below
      if (t2 == eq[i] && t1 == eq[1 - i])
      {
        new_ret = t2;
        ss_reason << "ITE simple rev subs";
        break;
      }
    }
    if (!new_ret.isNull())
    {
      break;
    }
  }
  if (new_ret.isNull())
  {
    // merging branches
    for (unsigned i = 1; i <= 2; i++)
    {
      if (n[i].getKind() == ITE)
      {
        Node no = n[3 - i];
        for (unsigned j = 1; j <= 2; j++)
        {
          if (n[i][j] == no)
          {
            // e.g.
            // ite( C1, ite( C2, t1, t2 ), t1 ) ----> ite( C1 ^ ~C2, t2, t1 )
            Node nc1 = i == 2 ? n[0].negate() : n[0];
            Node nc2 = j == 1 ? n[i][0].negate() : n[i][0];
            Node new_cond = nm->mkNode(AND, nc1, nc2);
            new_ret = nm->mkNode(ITE, new_cond, n[i][3 - j], no);
            ss_reason << "ITE merge branch";
            break;
          }
        }
      }
      if (!new_ret.isNull())
      {
        break;
      }
    }
  }

  if (new_ret.isNull() && d_aggr)
  {
    // If x is less than t based on an ordering, then we use { x -> t } as a
    // substitution to the children of ite( x = t ^ C, s, t ) below.
    std::vector<Node> vars;
    std::vector<Node> subs;
    inferSubstitution(n[0], vars, subs, true);

    if (!vars.empty())
    {
      // reverse substitution to opposite child
      // r{ x -> t } = s  implies  ite( x=t ^ C, s, r ) ---> r
      // We can use ordinary substitute since the result of the substitution
      // is not being returned. In other words, nn is only being used to query
      // whether the second branch is a generalization of the first.
      Node nn =
          t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
      if (nn != t2)
      {
        nn = Rewriter::rewrite(nn);
        if (nn == t1)
        {
          new_ret = t2;
          ss_reason << "ITE rev subs";
        }
      }

      // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
      // Only attempted if the reverse substitution above did not succeed;
      // otherwise its (simpler) result would be overwritten here.
      if (new_ret.isNull())
      {
        // must use partial substitute here, to avoid substitution into witness
        std::map<Kind, bool> rkinds;
        nn = partialSubstitute(t1, vars, subs, rkinds);
        if (nn != t1)
        {
          // If full=false, then we've duplicated a term u in the children of
          // n. For example, when ITE pulling, we have n is of the form:
          //   ite( C, f( u, t1 ), f( u, t2 ) )
          // We must show that at least one copy of u disappears in this case.
          nn = Rewriter::rewrite(nn);
          if (nn == t2)
          {
            new_ret = nn;
            ss_reason << "ITE subs invariant";
          }
          else if (full || nn.isConst())
          {
            new_ret = nm->mkNode(itek, n[0], nn, t2);
            ss_reason << "ITE subs";
          }
        }
      }
    }
    if (new_ret.isNull())
    {
      // ite( C, t, s ) ----> ite( C, t, s { C -> false } )
      // use partial substitute to avoid substitution into witness
      std::map<Node, Node> assign;
      assign[n[0]] = d_false;
      std::map<Kind, bool> rkinds;
      Node nn = partialSubstitute(t2, assign, rkinds);
      if (nn != t2)
      {
        nn = Rewriter::rewrite(nn);
        if (nn == t1)
        {
          new_ret = nn;
          ss_reason << "ITE subs invariant false";
        }
        else if (full || nn.isConst())
        {
          new_ret = nm->mkNode(itek, n[0], t1, nn);
          ss_reason << "ITE subs false";
        }
      }
    }
  }

  // only print debug trace if full=true
  if (!new_ret.isNull() && full)
  {
    debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
  }

  return new_ret;
}
560
561
13374
/**
 * Extended rewrite for a Boolean AND/OR term n. All rewrites here are
 * aggressive, so this returns the null node unless d_aggr is set.
 *
 * Tries, in order: Boolean constraint propagation, factoring, and equality
 * resolution. The first non-null result is returned.
 */
Node ExtendedRewriter::extendedRewriteAndOr(Node n)
{
  if (!d_aggr)
  {
    return Node::null();
  }
  // An empty kind map lets substitutions recurse over any kind, except
  // WITNESS, which is managed by partialSubstitute.
  std::map<Kind, bool> bcp_kinds;

  Node propagated = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, n);
  if (!propagated.isNull())
  {
    debugExtendedRewrite(n, propagated, "Bool bcp");
    return propagated;
  }

  Node factored = extendedRewriteFactoring(AND, OR, NOT, n);
  if (!factored.isNull())
  {
    debugExtendedRewrite(n, factored, "Bool factoring");
    return factored;
  }

  Node resolved =
      extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, n, false);
  debugExtendedRewrite(n, resolved, "Bool eq res");
  return resolved;
}
591
592
140306
/**
 * Try to simplify n, which is not itself an ITE, by pulling its ITE children
 * of kind itek outward.
 *
 * For each child n[i] of kind itek, the two "pulled" terms are built and
 * rewritten: f( ..n[i][1].. ) and f( ..n[i][2].. ). Outside aggressive mode,
 * n[i] is only considered when neither of its branches is an ITE. The
 * rewrites tried are:
 * - dual invariance: both pulled terms are equal to some t; return t.
 * - (aggressive) a binary Boolean predicate whose other argument is a
 *   non-Boolean variable or constant: return ite( A, P( x, t1 ), P( x, t2 ) ).
 * - (aggressive) single child elimination: one pulled term is a constant or
 *   the corresponding branch itself.
 * - (aggressive) build the pulled ITE, rewrite it, and try extendedRewriteIte
 *   on it with full=false (or return it if it is no longer an ITE).
 *
 * @param itek the ITE kind to pull
 * @param n the term to rewrite; must not have kind ITE
 * @return the rewritten term, or the null node if no rewrite applies
 *   (including when n is a closure)
 */
Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
{
  Assert(n.getKind() != ITE);
  if (n.isClosure())
  {
    // Don't pull ITE out of quantifiers (or other binders). Report "no
    // rewrite" with the null node, like every other failure path of this
    // method. Returning n itself would be a non-null result equal to the
    // input, and callers treat any non-null result as a successful rewrite.
    return Node::null();
  }
  NodeManager* nm = NodeManager::currentNM();
  TypeNode tn = n.getType();
  std::vector<Node> children;
  bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
  if (hasOp)
  {
    children.push_back(n.getOperator());
  }
  unsigned nchildren = n.getNumChildren();
  for (unsigned i = 0; i < nchildren; i++)
  {
    children.push_back(n[i]);
  }
  // ite_c[i][j] : rewritten form of n with child i replaced by branch j+1
  std::map<unsigned, std::map<unsigned, Node> > ite_c;
  for (unsigned i = 0; i < nchildren; i++)
  {
    // only pull ITEs apart if we are aggressive
    if (n[i].getKind() == itek
        && (d_aggr || (n[i][1].getKind() != ITE && n[i][2].getKind() != ITE)))
    {
      // position of child i in children (shifted by the operator, if any)
      unsigned ii = hasOp ? i + 1 : i;
      for (unsigned j = 0; j < 2; j++)
      {
        children[ii] = n[i][j + 1];
        Node pull = nm->mkNode(n.getKind(), children);
        Node pullr = Rewriter::rewrite(pull);
        children[ii] = n[i];
        ite_c[i][j] = pullr;
      }
      if (ite_c[i][0] == ite_c[i][1])
      {
        // ITE dual invariance
        // f( t1..s1..tn ) ---> t  and  f( t1..s2..tn ) ---> t implies
        // f( t1..ite( A, s1, s2 )..tn ) ---> t
        debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
        return ite_c[i][0];
      }
      if (d_aggr)
      {
        if (nchildren == 2 && (n[1 - i].isVar() || n[1 - i].isConst())
            && !n[1 - i].getType().isBoolean() && tn.isBoolean())
        {
          // always pull variable or constant with binary (theory) predicate
          // e.g. P( x, ite( A, t1, t2 ) ) ---> ite( A, P( x, t1 ), P( x, t2 ) )
          Node new_ret = nm->mkNode(ITE, n[i][0], ite_c[i][0], ite_c[i][1]);
          debugExtendedRewrite(n, new_ret, "ITE pull var predicate");
          return new_ret;
        }
        for (unsigned j = 0; j < 2; j++)
        {
          Node pullr = ite_c[i][j];
          if (pullr.isConst() || pullr == n[i][j + 1])
          {
            // ITE single child elimination
            // f( t1..s1..tn ) ---> t  where t is a constant or s1 itself
            // implies
            // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
            Node new_ret;
            if (tn.isBoolean() && pullr.isConst())
            {
              // remove false/true child immediately
              bool pol = pullr.getConst<bool>();
              std::vector<Node> new_children;
              new_children.push_back((j == 0) == pol ? n[i][0]
                                                     : n[i][0].negate());
              new_children.push_back(ite_c[i][1 - j]);
              new_ret = nm->mkNode(pol ? OR : AND, new_children);
              debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
            }
            else
            {
              new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
              debugExtendedRewrite(n, new_ret, "ITE single elim");
            }
            return new_ret;
          }
        }
      }
    }
  }
  if (d_aggr)
  {
    for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
    {
      Node nite = n[ip.first];
      Assert(nite.getKind() == itek);
      // now, simply pull the ITE and try ITE rewrites
      Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
      pull_ite = Rewriter::rewrite(pull_ite);
      if (pull_ite.getKind() == ITE)
      {
        Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
        if (!new_pull_ite.isNull())
        {
          debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
          return new_pull_ite;
        }
      }
      else
      {
        // A general rewrite could eliminate the ITE by pulling.
        // An example is:
        //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
        //   ite( C, ~~x, ite( C, y, x ) ) --->
        //   x
        // where ~ is bitvector negation.
        debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
        return pull_ite;
      }
    }
  }

  return Node::null();
}
714
715
13920
/**
 * Push the top-level negation of ret one level inward.
 *
 *   ~( a ^ b )        ---> ~a V ~b
 *   ~( a V b )        ---> ~a ^ ~b
 *   ~( a => b )       ---> a ^ ~b
 *   ~ite( c, a, b )   ---> ite( c, ~a, ~b )
 *   ~( a xor b )      ---> a = b
 *   ~( a = b )        ---> ~a = b        (a, b Boolean)
 *
 * @param ret a term of kind NOT
 * @return the rewritten term, or the null node if the negated term has none
 *   of the kinds above
 */
Node ExtendedRewriter::extendedRewriteNnf(Node ret)
{
  Assert(ret.getKind() == NOT);

  Node body = ret[0];
  Kind resultKind = body.getKind();
  // negateAll: negate every child; negateFirst: toggle negation of child 0
  bool negateAll = false;
  bool negateFirst = false;
  switch (resultKind)
  {
    case AND:
      negateAll = true;
      resultKind = OR;
      break;
    case OR:
      negateAll = true;
      resultKind = AND;
      break;
    case IMPLIES:
      negateAll = true;
      negateFirst = true;
      resultKind = AND;
      break;
    case ITE:
      // the condition keeps its polarity, both branches are negated
      negateAll = true;
      negateFirst = true;
      break;
    case XOR:
      resultKind = EQUAL;
      break;
    case EQUAL:
      if (!body[0].getType().isBoolean())
      {
        return Node::null();
      }
      negateFirst = true;
      break;
    default:
      return Node::null();
  }

  std::vector<Node> newChildren;
  const unsigned numChildren = body.getNumChildren();
  for (unsigned idx = 0; idx < numChildren; idx++)
  {
    const bool flip = (idx == 0 && negateFirst) != negateAll;
    Node child = body[idx];
    newChildren.push_back(flip ? child.negate() : child);
  }
  return NodeManager::currentNM()->mkNode(resultKind, newChildren);
}
760
761
13356
/**
 * Boolean constraint propagation on ret, whose kind is andk or ork.
 *
 * 1. The children of ret, flattened through nested same-connective children,
 *    are treated as literals. Each atom is assigned the value that makes its
 *    literal hold under the top-level connective; for ork, all values are
 *    inverted (see gpol).
 * 2. An atom that has children, and whose kind is in bcp_kinds (when
 *    bcp_kinds is non-empty), is also kept as a clause.
 * 3. The assignment is substituted into the children of each clause with
 *    partialSubstitute. Each clause that changes is rebuilt, rewritten and
 *    processed again.
 * 4. Steps 1-3 repeat until no new terms are produced.
 *
 * @param andk the conjunction kind
 * @param ork the disjunction kind
 * @param notk the negation kind
 * @param bcp_kinds restricts which atom kinds are treated as clauses (empty
 *   means no restriction); also forwarded to partialSubstitute
 * @param ret the term to simplify
 * @return on a conflicting assignment, the zero value of ret's type (for
 *   andk) or its maximal value (for ork); if propagation changed some
 *   clause, the term rebuilt from the assignment; otherwise the null node
 */
Node ExtendedRewriter::extendedRewriteBcp(
    Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node ret)
{
  Kind k = ret.getKind();
  Assert(k == andk || k == ork);
  Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;

  NodeManager* nm = NodeManager::currentNM();

  TypeNode tn = ret.getType();
  Node truen = TermUtil::mkTypeMaxValue(tn);
  Node falsen = TermUtil::mkTypeValue(tn, 0);

  // terms to process
  std::vector<Node> to_process;
  for (const Node& cn : ret)
  {
    to_process.push_back(cn);
  }
  // literals whose atoms are candidates for substitution (see step 2 above)
  std::vector<Node> clauses;
  // atoms of clauses that propagation has rewritten; excluded from the result
  std::unordered_set<Node, NodeHashFunction> prop_clauses;
  // the assignment: atom -> truen/falsen
  std::map<Node, Node> assign;

  Kind ok = k == andk ? ork : andk;
  // global polarity : when k=ork, everything is negated
  bool gpol = k == andk;

  do
  {
    // process the current nodes
    while (!to_process.empty())
    {
      std::vector<Node> new_to_process;
      for (const Node& cn : to_process)
      {
        Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
        Kind cnk = cn.getKind();
        bool pol = cnk != notk;
        Node cln = cnk == notk ? cn[0] : cn;
        Assert(cln.getKind() != notk);
        if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
        {
          // flatten
          for (const Node& ccln : cln)
          {
            Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
            new_to_process.push_back(lccln);
          }
        }
        else
        {
          // add it to the assignment
          Node val = gpol == pol ? truen : falsen;
          std::map<Node, Node>::iterator it = assign.find(cln);
          Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
                               << std::endl;
          if (it != assign.end())
          {
            if (val != it->second)
            {
              Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
              // a conflicting assignment: we are done
              return gpol ? falsen : truen;
            }
          }
          else
          {
            assign[cln] = val;
          }

          // also, treat it as clause if possible
          if (cln.getNumChildren() > 0
              && (bcp_kinds.empty()
                  || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
          {
            // NOTE(review): prop_clauses stores atoms (ca below), but this
            // looks up the literal cn, so a negated literal whose atom was
            // already propagated is not recognized here -- confirm intent.
            if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
                && prop_clauses.find(cn) == prop_clauses.end())
            {
              Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
              clauses.push_back(cn);
            }
          }
        }
      }
      to_process.clear();
      to_process.insert(
          to_process.end(), new_to_process.begin(), new_to_process.end());
    }

    // apply substitution to all subterms of clauses
    std::vector<Node> new_clauses;
    for (const Node& c : clauses)
    {
      bool cpol = c.getKind() != notk;
      Node ca = c.getKind() == notk ? c[0] : c;
      bool childChanged = false;
      std::vector<Node> ccs_children;
      for (const Node& cc : ca)
      {
        // always use partial substitute, to avoid substitution in witness
        Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
        // substitution is only applicable to compatible kinds in bcp_kinds
        Node ccs = partialSubstitute(cc, assign, bcp_kinds);
        childChanged = childChanged || ccs != cc;
        ccs_children.push_back(ccs);
      }
      if (childChanged)
      {
        if (ca.getMetaKind() == metakind::PARAMETERIZED)
        {
          ccs_children.insert(ccs_children.begin(), ca.getOperator());
        }
        Node ccs = nm->mkNode(ca.getKind(), ccs_children);
        ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
        Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
                             << std::endl;
        ccs = Rewriter::rewrite(ccs);
        Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
        to_process.push_back(ccs);
        // Store the atom of c as a node that propagation touched. This marks
        // the atom so that its literal is not included in the final
        // construction; the propagated form ccs replaces it.
        prop_clauses.insert(ca);
      }
      else
      {
        new_clauses.push_back(c);
      }
    }
    clauses.clear();
    clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
  } while (!to_process.empty());

  // remake the node
  if (!prop_clauses.empty())
  {
    std::vector<Node> children;
    for (std::pair<const Node, Node>& l : assign)
    {
      Node a = l.first;
      // if propagation did not touch a
      if (prop_clauses.find(a) == prop_clauses.end())
      {
        Assert(l.second == truen || l.second == falsen);
        Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
        children.push_back(ln);
      }
    }
    Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
    Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
    return new_ret;
  }

  return Node::null();
}
922
923
12664
Node ExtendedRewriter::extendedRewriteFactoring(Kind andk,
924
                                                Kind ork,
925
                                                Kind notk,
926
                                                Node n)
927
{
928
12664
  Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl;
929
12664
  NodeManager* nm = NodeManager::currentNM();
930
931
12664
  Kind nk = n.getKind();
932
12664
  Assert(nk == andk || nk == ork);
933
12664
  Kind onk = nk == andk ? ork : andk;
934
  // count the number of times atoms occur
935
25328
  std::map<Node, std::vector<Node> > lit_to_cl;
936
25328
  std::map<Node, std::vector<Node> > cl_to_lits;
937
44656
  for (const Node& nc : n)
938
  {
939
31992
    Kind nck = nc.getKind();
940
31992
    if (nck == onk)
941
    {
942
8249
      for (const Node& ncl : nc)
943
      {
944
17808
        if (std::find(lit_to_cl[ncl].begin(), lit_to_cl[ncl].end(), nc)
945
17808
            == lit_to_cl[ncl].end())
946
        {
947
5936
          lit_to_cl[ncl].push_back(nc);
948
5936
          cl_to_lits[nc].push_back(ncl);
949
        }
950
      }
951
    }
952
    else
953
    {
954
29679
      lit_to_cl[nc].push_back(nc);
955
29679
      cl_to_lits[nc].push_back(nc);
956
    }
957
  }
958
  // get the maximum shared literal to factor
959
12664
  unsigned max_size = 0;
960
25328
  Node flit;
961
48029
  for (const std::pair<const Node, std::vector<Node> >& ltc : lit_to_cl)
962
  {
963
35365
    if (ltc.second.size() > max_size)
964
    {
965
12698
      max_size = ltc.second.size();
966
12698
      flit = ltc.first;
967
    }
968
  }
969
12664
  if (max_size > 1)
970
  {
971
    // do the factoring
972
76
    std::vector<Node> children;
973
76
    std::vector<Node> fchildren;
974
38
    std::map<Node, std::vector<Node> >::iterator itl = lit_to_cl.find(flit);
975
38
    std::vector<Node>& cls = itl->second;
976
510
    for (const Node& nc : n)
977
    {
978
472
      if (std::find(cls.begin(), cls.end(), nc) == cls.end())
979
      {
980
262
        children.push_back(nc);
981
      }
982
      else
983
      {
984
        // rebuild
985
210
        std::vector<Node>& lits = cl_to_lits[nc];
986
        std::vector<Node>::iterator itlfl =
987
210
            std::find(lits.begin(), lits.end(), flit);
988
210
        Assert(itlfl != lits.end());
989
210
        lits.erase(itlfl);
990
        // rebuild
991
210
        if (!lits.empty())
992
        {
993
420
          Node new_cl = lits.size() == 1 ? lits[0] : nm->mkNode(onk, lits);
994
210
          fchildren.push_back(new_cl);
995
        }
996
      }
997
    }
998
    // rebuild the factored children
999
38
    Assert(!fchildren.empty());
1000
76
    Node fcn = fchildren.size() == 1 ? fchildren[0] : nm->mkNode(nk, fchildren);
1001
38
    children.push_back(nm->mkNode(onk, flit, fcn));
1002
76
    Node ret = children.size() == 1 ? children[0] : nm->mkNode(nk, children);
1003
38
    Trace("ext-rew-factoring") << "Factoring: *** OUTPUT: " << ret << std::endl;
1004
38
    return ret;
1005
  }
1006
  else
1007
  {
1008
12626
    Trace("ext-rew-factoring") << "Factoring: no change" << std::endl;
1009
  }
1010
12626
  return Node::null();
1011
}
1012
1013
12626
Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
1014
                                            Kind ork,
1015
                                            Kind eqk,
1016
                                            Kind notk,
1017
                                            std::map<Kind, bool>& bcp_kinds,
1018
                                            Node n,
1019
                                            bool isXor)
1020
{
1021
12626
  Assert(n.getKind() == andk || n.getKind() == ork);
1022
12626
  Trace("ext-rew-eqres") << "Eq res: **** INPUT: " << n << std::endl;
1023
1024
12626
  NodeManager* nm = NodeManager::currentNM();
1025
12626
  Kind nk = n.getKind();
1026
12626
  bool gpol = (nk == andk);
1027
43368
  for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
1028
  {
1029
61820
    Node lit = n[i];
1030
31078
    if (lit.getKind() == eqk)
1031
    {
1032
      // eq is the equality we are basing a substitution on
1033
11854
      Node eq;
1034
6095
      if (gpol == isXor)
1035
      {
1036
        // can only turn disequality into equality if types are the same
1037
2507
        if (lit[1].getType() == lit.getType())
1038
        {
1039
          // t != s ---> ~t = s
1040
34
          if (lit[1].getKind() == notk && lit[0].getKind() != notk)
1041
          {
1042
24
            eq = nm->mkNode(EQUAL, lit[0], TermUtil::mkNegate(notk, lit[1]));
1043
          }
1044
          else
1045
          {
1046
10
            eq = nm->mkNode(EQUAL, TermUtil::mkNegate(notk, lit[0]), lit[1]);
1047
          }
1048
        }
1049
      }
1050
      else
1051
      {
1052
3588
        eq = eqk == EQUAL ? lit : nm->mkNode(EQUAL, lit[0], lit[1]);
1053
      }
1054
6095
      if (!eq.isNull())
1055
      {
1056
        // see if it corresponds to a substitution
1057
6908
        std::vector<Node> vars;
1058
6908
        std::vector<Node> subs;
1059
3622
        if (inferSubstitution(eq, vars, subs))
1060
        {
1061
3120
          Assert(vars.size() == 1);
1062
5904
          std::vector<Node> children;
1063
3120
          bool childrenChanged = false;
1064
          // apply to all other children
1065
17405
          for (unsigned j = 0; j < nchild; j++)
1066
          {
1067
28570
            Node ccs = n[j];
1068
14285
            if (i != j)
1069
            {
1070
              // Substitution is only applicable to compatible kinds. We always
1071
              // use the partialSubstitute method to avoid substitution into
1072
              // witness terms.
1073
11165
              ccs = partialSubstitute(ccs, vars, subs, bcp_kinds);
1074
11165
              childrenChanged = childrenChanged || n[j] != ccs;
1075
            }
1076
14285
            children.push_back(ccs);
1077
          }
1078
3120
          if (childrenChanged)
1079
          {
1080
336
            return nm->mkNode(nk, children);
1081
          }
1082
        }
1083
      }
1084
    }
1085
  }
1086
1087
12290
  return Node::null();
1088
}
1089
1090
/** sort pairs by their second (unsigned) argument */
1091
41148
static bool sortPairSecond(const std::pair<Node, unsigned>& a,
1092
                           const std::pair<Node, unsigned>& b)
1093
{
1094
41148
  return (a.second < b.second);
1095
}
1096
1097
/** A simple subsumption trie used to compute pairwise list subsets */
1098
123904
class SimpSubsumeTrie
1099
{
1100
 public:
1101
  /** the children of this node */
1102
  std::map<Node, SimpSubsumeTrie> d_children;
1103
  /** the term at this node */
1104
  Node d_data;
1105
  /** add term to the trie
1106
   *
1107
   * This adds term c to this trie, whose atom list is alist. This adds terms
1108
   * s to subsumes such that the atom list of s is a subset of the atom list
1109
   * of c. For example, say:
1110
   *   c1.alist = { A }
1111
   *   c2.alist = { C }
1112
   *   c3.alist = { B, C }
1113
   *   c4.alist = { A, B, D }
1114
   *   c5.alist = { A, B, C }
1115
   * If these terms are added in the order c1, c2, c3, c4, c5, then:
1116
   *   addTerm c1 results in subsumes = {}
1117
   *   addTerm c2 results in subsumes = {}
1118
   *   addTerm c3 results in subsumes = { c2 }
1119
   *   addTerm c4 results in subsumes = { c1 }
1120
   *   addTerm c5 results in subsumes = { c1, c2, c3 }
1121
   * Notice that the intended use case of this trie is to add term t before t'
1122
   * only when size( t.alist ) <= size( t'.alist ).
1123
   *
1124
   * The last two arguments describe the state of the path [t0...tn] we
1125
   * have followed in the trie during the recursive call.
1126
   * If doAdd = true,
1127
   *   then n+1 = index and alist[1]...alist[n] = t1...tn. If index=alist.size()
1128
   *   we add c as the current node of this trie.
1129
   * If doAdd = false,
1130
   *   then t1...tn occur in alist.
1131
   */
1132
82521
  void addTerm(Node c,
1133
               std::vector<Node>& alist,
1134
               std::vector<Node>& subsumes,
1135
               unsigned index = 0,
1136
               bool doAdd = true)
1137
  {
1138
82521
    if (!d_data.isNull())
1139
    {
1140
6
      subsumes.push_back(d_data);
1141
    }
1142
82521
    if (doAdd)
1143
    {
1144
82509
      if (index == alist.size())
1145
      {
1146
40992
        d_data = c;
1147
40992
        return;
1148
      }
1149
    }
1150
    // try all children where we have this atom
1151
62335
    for (std::pair<const Node, SimpSubsumeTrie>& cp : d_children)
1152
    {
1153
20806
      if (std::find(alist.begin(), alist.end(), cp.first) != alist.end())
1154
      {
1155
12
        cp.second.addTerm(c, alist, subsumes, 0, false);
1156
      }
1157
    }
1158
41529
    if (doAdd)
1159
    {
1160
41517
      d_children[alist[index]].addTerm(c, alist, subsumes, index + 1, doAdd);
1161
    }
1162
  }
1163
};
1164
1165
20539
Node ExtendedRewriter::extendedRewriteEqChain(
1166
    Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor)
1167
{
1168
20539
  Assert(ret.getKind() == eqk);
1169
1170
  // this rewrite is aggressive; it in fact has the precondition that other
1171
  // aggressive rewrites (including BCP) have been applied.
1172
20539
  if (!d_aggr)
1173
  {
1174
94
    return Node::null();
1175
  }
1176
1177
20445
  NodeManager* nm = NodeManager::currentNM();
1178
1179
40890
  TypeNode tn = ret[0].getType();
1180
1181
  // sort/cancelling for Boolean EQUAL/XOR-chains
1182
20445
  Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;
1183
1184
  // get the children on either side
1185
20445
  bool gpol = true;
1186
40890
  std::vector<Node> children;
1187
61335
  for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++)
1188
  {
1189
81780
    Node curr = ret[r];
1190
    // assume, if necessary, right associative
1191
41326
    while (curr.getKind() == eqk && curr[0].getType() == tn)
1192
    {
1193
218
      children.push_back(curr[0]);
1194
218
      curr = curr[1];
1195
    }
1196
40890
    children.push_back(curr);
1197
  }
1198
1199
40890
  std::map<Node, bool> cstatus;
1200
  // add children to status
1201
61553
  for (const Node& c : children)
1202
  {
1203
82216
    Node a = c;
1204
41108
    if (a.getKind() == notk)
1205
    {
1206
361
      gpol = !gpol;
1207
361
      a = a[0];
1208
    }
1209
41108
    Trace("ext-rew-eqchain") << "...child : " << a << std::endl;
1210
41108
    std::map<Node, bool>::iterator itc = cstatus.find(a);
1211
41108
    if (itc == cstatus.end())
1212
    {
1213
41062
      cstatus[a] = true;
1214
    }
1215
    else
1216
    {
1217
      // cancels
1218
46
      cstatus.erase(a);
1219
46
      if (isXor)
1220
      {
1221
        gpol = !gpol;
1222
      }
1223
    }
1224
  }
1225
20445
  Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;
1226
1227
20445
  if (cstatus.empty())
1228
  {
1229
2
    return TermUtil::mkTypeConst(tn, gpol);
1230
  }
1231
1232
20443
  children.clear();
1233
1234
  // compute the atoms of each child
1235
20443
  Trace("ext-rew-eqchain") << "eqchain-simplify: begin\n";
1236
20443
  Trace("ext-rew-eqchain") << "  eqchain-simplify: get atoms...\n";
1237
40886
  std::map<Node, std::map<Node, bool> > atoms;
1238
40886
  std::map<Node, std::vector<Node> > alist;
1239
40886
  std::vector<std::pair<Node, unsigned> > atom_count;
1240
61459
  for (std::pair<const Node, bool>& cp : cstatus)
1241
  {
1242
41016
    if (!cp.second)
1243
    {
1244
      // already eliminated
1245
      continue;
1246
    }
1247
82032
    Node c = cp.first;
1248
41016
    Kind ck = c.getKind();
1249
41016
    Trace("ext-rew-eqchain") << "  process c = " << c << std::endl;
1250
41016
    if (ck == andk || ck == ork)
1251
    {
1252
815
      for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++)
1253
      {
1254
1340
        Node cl = c[j];
1255
676
        bool pol = cl.getKind() != notk;
1256
1340
        Node ca = pol ? cl : cl[0];
1257
676
        bool newVal = (ck == andk ? !pol : pol);
1258
1352
        Trace("ext-rew-eqchain")
1259
676
            << "  atoms(" << c << ", " << ca << ") = " << newVal << std::endl;
1260
676
        Assert(atoms[c].find(ca) == atoms[c].end());
1261
        // polarity is flipped when we are AND
1262
676
        atoms[c][ca] = newVal;
1263
676
        alist[c].push_back(ca);
1264
1265
        // if this already exists as a child of the equality chain, eliminate.
1266
        // this catches cases like ( x & y ) = ( ( x & y ) | z ), where we
1267
        // consider ( x & y ) a unit, whereas below it is expanded to
1268
        // ~( ~x | ~y ).
1269
676
        std::map<Node, bool>::iterator itc = cstatus.find(ca);
1270
676
        if (itc != cstatus.end() && itc->second)
1271
        {
1272
          // cancel it
1273
12
          cstatus[ca] = false;
1274
12
          cstatus[c] = false;
1275
          // make new child
1276
          // x = ( y | ~x ) ---> y & x
1277
          // x = ( y | x ) ---> ~y | x
1278
          // x = ( y & x ) ---> y | ~x
1279
          // x = ( y & ~x ) ---> ~y & ~x
1280
24
          std::vector<Node> new_children;
1281
36
          for (unsigned k = 0, nchildc = c.getNumChildren(); k < nchildc; k++)
1282
          {
1283
24
            if (j != k)
1284
            {
1285
12
              new_children.push_back(c[k]);
1286
            }
1287
          }
1288
24
          Node nc[2];
1289
12
          nc[0] = c[j];
1290
12
          nc[1] = new_children.size() == 1 ? new_children[0]
1291
                                           : nm->mkNode(ck, new_children);
1292
          // negate the proper child
1293
12
          unsigned nindex = (ck == andk) == pol ? 0 : 1;
1294
12
          nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]);
1295
12
          Kind nk = pol ? ork : andk;
1296
          // store as new child
1297
12
          children.push_back(nm->mkNode(nk, nc[0], nc[1]));
1298
12
          if (isXor)
1299
          {
1300
            gpol = !gpol;
1301
          }
1302
12
          break;
1303
        }
1304
151
      }
1305
    }
1306
    else
1307
    {
1308
40865
      bool pol = ck != notk;
1309
81730
      Node ca = pol ? c : c[0];
1310
40865
      atoms[c][ca] = pol;
1311
81730
      Trace("ext-rew-eqchain")
1312
40865
          << "  atoms(" << c << ", " << ca << ") = " << pol << std::endl;
1313
40865
      alist[c].push_back(ca);
1314
    }
1315
41016
    atom_count.push_back(std::pair<Node, unsigned>(c, alist[c].size()));
1316
  }
1317
  // sort the atoms in each atom list
1318
61459
  for (std::map<Node, std::vector<Node> >::iterator it = alist.begin();
1319
61459
       it != alist.end();
1320
       ++it)
1321
  {
1322
41016
    std::sort(it->second.begin(), it->second.end());
1323
  }
1324
  // check subsumptions
1325
  // sort by #atoms
1326
20443
  std::sort(atom_count.begin(), atom_count.end(), sortPairSecond);
1327
20443
  if (Trace.isOn("ext-rew-eqchain"))
1328
  {
1329
    for (const std::pair<Node, unsigned>& ac : atom_count)
1330
    {
1331
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << ac.first << " has "
1332
                               << ac.second << " atoms." << std::endl;
1333
    }
1334
    Trace("ext-rew-eqchain") << "  eqchain-simplify: compute subsumptions...\n";
1335
  }
1336
40886
  SimpSubsumeTrie sst;
1337
61459
  for (std::pair<const Node, bool>& cp : cstatus)
1338
  {
1339
41016
    if (!cp.second)
1340
    {
1341
      // already eliminated
1342
24
      continue;
1343
    }
1344
81984
    Node c = cp.first;
1345
40992
    std::map<Node, std::map<Node, bool> >::iterator itc = atoms.find(c);
1346
40992
    Assert(itc != atoms.end());
1347
81984
    Trace("ext-rew-eqchain") << "  - add term " << c << " with atom list "
1348
40992
                             << alist[c] << "...\n";
1349
81984
    std::vector<Node> subsumes;
1350
40992
    sst.addTerm(c, alist[c], subsumes);
1351
40992
    for (const Node& cc : subsumes)
1352
    {
1353
2
      if (!cstatus[cc])
1354
      {
1355
        // subsumes a child that was already eliminated
1356
        continue;
1357
      }
1358
4
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << c << " subsumes "
1359
2
                               << cc << std::endl;
1360
      // for each of the atoms in cc
1361
2
      std::map<Node, std::map<Node, bool> >::iterator itcc = atoms.find(cc);
1362
2
      Assert(itcc != atoms.end());
1363
2
      std::vector<Node> common_children;
1364
2
      std::vector<Node> diff_children;
1365
6
      for (const std::pair<const Node, bool>& ap : itcc->second)
1366
      {
1367
        // compare the polarity
1368
8
        Node a = ap.first;
1369
4
        bool polcc = ap.second;
1370
4
        Assert(itc->second.find(a) != itc->second.end());
1371
4
        bool polc = itc->second[a];
1372
8
        Trace("ext-rew-eqchain") << "    eqchain-simplify: atom " << a
1373
4
                                 << " has polarities : " << polc << " " << polcc
1374
4
                                 << "\n";
1375
8
        Node lit = polc ? a : TermUtil::mkNegate(notk, a);
1376
4
        if (polc != polcc)
1377
        {
1378
4
          diff_children.push_back(lit);
1379
        }
1380
        else
1381
        {
1382
          common_children.push_back(lit);
1383
        }
1384
      }
1385
2
      std::vector<Node> rem_children;
1386
6
      for (const std::pair<const Node, bool>& ap : itc->second)
1387
      {
1388
8
        Node a = ap.first;
1389
4
        if (atoms[cc].find(a) == atoms[cc].end())
1390
        {
1391
          bool polc = ap.second;
1392
          rem_children.push_back(polc ? a : TermUtil::mkNegate(notk, a));
1393
        }
1394
      }
1395
4
      Trace("ext-rew-eqchain")
1396
4
          << "    #common/diff/rem: " << common_children.size() << "/"
1397
2
          << diff_children.size() << "/" << rem_children.size() << "\n";
1398
2
      bool do_rewrite = false;
1399
6
      if (common_children.empty() && itc->second.size() == itcc->second.size()
1400
4
          && itcc->second.size() == 2)
1401
      {
1402
        // x | y = ~x | ~y ---> ~( x = y )
1403
2
        do_rewrite = true;
1404
2
        children.push_back(diff_children[0]);
1405
2
        children.push_back(diff_children[1]);
1406
2
        gpol = !gpol;
1407
2
        Trace("ext-rew-eqchain") << "    apply 2-child all-diff\n";
1408
      }
1409
      else if (common_children.empty() && diff_children.size() == 1)
1410
      {
1411
        do_rewrite = true;
1412
        // x = ( ~x | y ) ---> ~( ~x | ~y )
1413
        Node remn = rem_children.size() == 1 ? rem_children[0]
1414
                                             : nm->mkNode(ork, rem_children);
1415
        remn = TermUtil::mkNegate(notk, remn);
1416
        children.push_back(nm->mkNode(ork, diff_children[0], remn));
1417
        if (!isXor)
1418
        {
1419
          gpol = !gpol;
1420
        }
1421
        Trace("ext-rew-eqchain") << "    apply unit resolution\n";
1422
      }
1423
      else if (diff_children.size() == 1
1424
               && itc->second.size() == itcc->second.size())
1425
      {
1426
        // ( x | y | z ) = ( x | ~y | z ) ---> ( x | z )
1427
        do_rewrite = true;
1428
        Assert(!common_children.empty());
1429
        Node comn = common_children.size() == 1
1430
                        ? common_children[0]
1431
                        : nm->mkNode(ork, common_children);
1432
        children.push_back(comn);
1433
        if (isXor)
1434
        {
1435
          gpol = !gpol;
1436
        }
1437
        Trace("ext-rew-eqchain") << "    apply resolution\n";
1438
      }
1439
      else if (diff_children.empty())
1440
      {
1441
        do_rewrite = true;
1442
        if (rem_children.empty())
1443
        {
1444
          // x | y = x | y ---> true
1445
          // this can happen if we have ( ~x & ~y ) = ( x | y )
1446
          children.push_back(TermUtil::mkTypeMaxValue(tn));
1447
          if (isXor)
1448
          {
1449
            gpol = !gpol;
1450
          }
1451
          Trace("ext-rew-eqchain") << "    apply cancel\n";
1452
        }
1453
        else
1454
        {
1455
          // x | y = ( x | y | z ) ---> ( x | y | ~z )
1456
          Node remn = rem_children.size() == 1 ? rem_children[0]
1457
                                               : nm->mkNode(ork, rem_children);
1458
          remn = TermUtil::mkNegate(notk, remn);
1459
          Node comn = common_children.size() == 1
1460
                          ? common_children[0]
1461
                          : nm->mkNode(ork, common_children);
1462
          children.push_back(nm->mkNode(ork, comn, remn));
1463
          if (isXor)
1464
          {
1465
            gpol = !gpol;
1466
          }
1467
          Trace("ext-rew-eqchain") << "    apply subsume\n";
1468
        }
1469
      }
1470
2
      if (do_rewrite)
1471
      {
1472
        // eliminate the children, reverse polarity as needed
1473
6
        for (unsigned r = 0; r < 2; r++)
1474
        {
1475
8
          Node c_rem = r == 0 ? c : cc;
1476
4
          cstatus[c_rem] = false;
1477
4
          if (c_rem.getKind() == andk)
1478
          {
1479
            gpol = !gpol;
1480
          }
1481
        }
1482
2
        break;
1483
      }
1484
    }
1485
  }
1486
20443
  Trace("ext-rew-eqchain") << "eqchain-simplify: finish" << std::endl;
1487
1488
  // sorted right associative chain
1489
20443
  bool has_nvar = false;
1490
20443
  unsigned nvar_index = 0;
1491
61459
  for (std::pair<const Node, bool>& cp : cstatus)
1492
  {
1493
41016
    if (cp.second)
1494
    {
1495
40988
      if (!cp.first.isVar())
1496
      {
1497
28888
        has_nvar = true;
1498
28888
        nvar_index = children.size();
1499
      }
1500
40988
      children.push_back(cp.first);
1501
    }
1502
  }
1503
20443
  std::sort(children.begin(), children.end());
1504
1505
40886
  Node new_ret;
1506
20443
  if (!gpol)
1507
  {
1508
    // negate the constant child if it exists
1509
301
    unsigned nindex = has_nvar ? nvar_index : 0;
1510
301
    children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
1511
  }
1512
20443
  new_ret = children.back();
1513
20443
  unsigned index = children.size() - 1;
1514
61565
  while (index > 0)
1515
  {
1516
20561
    index--;
1517
20561
    new_ret = nm->mkNode(eqk, children[index], new_ret);
1518
  }
1519
20443
  new_ret = Rewriter::rewrite(new_ret);
1520
20443
  if (new_ret != ret)
1521
  {
1522
267
    return new_ret;
1523
  }
1524
20176
  return Node::null();
1525
}
1526
1527
125794
Node ExtendedRewriter::partialSubstitute(Node n,
1528
                                         const std::map<Node, Node>& assign,
1529
                                         const std::map<Kind, bool>& rkinds)
1530
{
1531
251588
  std::unordered_map<TNode, Node, TNodeHashFunction> visited;
1532
125794
  std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
1533
125794
  std::map<Node, Node>::const_iterator ita;
1534
251588
  std::vector<TNode> visit;
1535
251588
  TNode cur;
1536
125794
  visit.push_back(n);
1537
870251
  do
1538
  {
1539
996045
    cur = visit.back();
1540
996045
    visit.pop_back();
1541
996045
    it = visited.find(cur);
1542
1543
996045
    if (it == visited.end())
1544
    {
1545
461497
      ita = assign.find(cur);
1546
461497
      if (ita != assign.end())
1547
      {
1548
2994
        visited[cur] = ita->second;
1549
      }
1550
      else
1551
      {
1552
        // If rkinds is non-empty, then can only recurse on its kinds.
1553
        // We also always disallow substitution into witness. Notice that
1554
        // we disallow witness here, due to unsoundness when applying contextual
1555
        // substitutions over witness terms (see #4620).
1556
458503
        Kind k = cur.getKind();
1557
458503
        if (k != WITNESS && (rkinds.empty() || rkinds.find(k) != rkinds.end()))
1558
        {
1559
458503
          visited[cur] = Node::null();
1560
458503
          visit.push_back(cur);
1561
870251
          for (const Node& cn : cur)
1562
          {
1563
411748
            visit.push_back(cn);
1564
          }
1565
        }
1566
        else
1567
        {
1568
          visited[cur] = cur;
1569
        }
1570
      }
1571
    }
1572
534548
    else if (it->second.isNull())
1573
    {
1574
917006
      Node ret = cur;
1575
458503
      bool childChanged = false;
1576
917006
      std::vector<Node> children;
1577
458503
      if (cur.getMetaKind() == metakind::PARAMETERIZED)
1578
      {
1579
7543
        children.push_back(cur.getOperator());
1580
      }
1581
870251
      for (const Node& cn : cur)
1582
      {
1583
411748
        it = visited.find(cn);
1584
411748
        Assert(it != visited.end());
1585
411748
        Assert(!it->second.isNull());
1586
411748
        childChanged = childChanged || cn != it->second;
1587
411748
        children.push_back(it->second);
1588
      }
1589
458503
      if (childChanged)
1590
      {
1591
6349
        ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
1592
      }
1593
458503
      visited[cur] = ret;
1594
    }
1595
996045
  } while (!visit.empty());
1596
125794
  Assert(visited.find(n) != visited.end());
1597
125794
  Assert(!visited.find(n)->second.isNull());
1598
251588
  return visited[n];
1599
}
1600
1601
30941
Node ExtendedRewriter::partialSubstitute(Node n,
1602
                                         const std::vector<Node>& vars,
1603
                                         const std::vector<Node>& subs,
1604
                                         const std::map<Kind, bool>& rkinds)
1605
{
1606
30941
  Assert(vars.size() == subs.size());
1607
61882
  std::map<Node, Node> assign;
1608
62295
  for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
1609
  {
1610
31354
    assign[vars[i]] = subs[i];
1611
  }
1612
61882
  return partialSubstitute(n, assign, rkinds);
1613
}
1614
1615
10479
Node ExtendedRewriter::solveEquality(Node n)
1616
{
1617
  // TODO (#1706) : implement
1618
10479
  Assert(n.getKind() == EQUAL);
1619
1620
10479
  return Node::null();
1621
}
1622
1623
24149
bool ExtendedRewriter::inferSubstitution(Node n,
1624
                                         std::vector<Node>& vars,
1625
                                         std::vector<Node>& subs,
1626
                                         bool usePred)
1627
{
1628
24149
  if (n.getKind() == AND)
1629
  {
1630
338
    bool ret = false;
1631
1089
    for (const Node& nc : n)
1632
    {
1633
751
      bool cret = inferSubstitution(nc, vars, subs, usePred);
1634
751
      ret = ret || cret;
1635
    }
1636
338
    return ret;
1637
  }
1638
23811
  if (n.getKind() == EQUAL)
1639
  {
1640
    // see if it can be put into form x = y
1641
11208
    Node slv_eq = solveEquality(n);
1642
10479
    if (!slv_eq.isNull())
1643
    {
1644
      n = slv_eq;
1645
    }
1646
31437
    Node v[2];
1647
24024
    for (unsigned i = 0; i < 2; i++)
1648
    {
1649
20865
      if (n[i].isConst())
1650
      {
1651
7320
        vars.push_back(n[1 - i]);
1652
7320
        subs.push_back(n[i]);
1653
7320
        return true;
1654
      }
1655
13545
      if (n[i].isVar())
1656
      {
1657
8764
        v[i] = n[i];
1658
      }
1659
4781
      else if (TermUtil::isNegate(n[i].getKind()) && n[i][0].isVar())
1660
      {
1661
23
        v[i] = n[i][0];
1662
      }
1663
    }
1664
4657
    for (unsigned i = 0; i < 2; i++)
1665
    {
1666
5426
      TNode r1 = v[i];
1667
5426
      Node r2 = v[1 - i];
1668
3928
      if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
1669
      {
1670
2430
        r2 = n[1 - i];
1671
2430
        if (v[i] != n[i])
1672
        {
1673
15
          Assert(TermUtil::isNegate(n[i].getKind()));
1674
15
          r2 = TermUtil::mkNegate(n[i].getKind(), r2);
1675
        }
1676
        // TODO (#1706) : union find
1677
2430
        if (std::find(vars.begin(), vars.end(), r1) == vars.end())
1678
        {
1679
2430
          vars.push_back(r1);
1680
2430
          subs.push_back(r2);
1681
2430
          return true;
1682
        }
1683
      }
1684
    }
1685
  }
1686
14061
  if (usePred)
1687
  {
1688
13559
    bool negated = n.getKind() == NOT;
1689
13559
    vars.push_back(negated ? n[0] : n);
1690
13559
    subs.push_back(negated ? d_false : d_true);
1691
13559
    return true;
1692
  }
1693
502
  return false;
1694
}
1695
1696
2003
Node ExtendedRewriter::extendedRewriteStrings(Node ret)
1697
{
1698
2003
  Node new_ret;
1699
4006
  Trace("q-ext-rewrite-debug")
1700
2003
      << "Extended rewrite strings : " << ret << std::endl;
1701
1702
2003
  if (ret.getKind() == EQUAL)
1703
  {
1704
270
    new_ret = strings::SequencesRewriter(nullptr).rewriteEqualityExt(ret);
1705
  }
1706
1707
2003
  return new_ret;
1708
}
1709
1710
62294
void ExtendedRewriter::debugExtendedRewrite(Node n,
1711
                                            Node ret,
1712
                                            const char* c) const
1713
{
1714
62294
  if (Trace.isOn("q-ext-rewrite"))
1715
  {
1716
    if (!ret.isNull())
1717
    {
1718
      Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
1719
      Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
1720
                             << " rewrites to " << ret << std::endl;
1721
    }
1722
  }
1723
62294
}
1724
1725
} /* CVC4::theory::quantifiers namespace */
1726
} /* CVC4::theory namespace */
1727
26685
} /* CVC4 namespace */