GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/extended_rewrite.cpp Lines: 723 841 86.0 %
Date: 2021-08-16 Branches: 1764 3794 46.5 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of extended rewriting techniques.
14
 */
15
16
#include "theory/quantifiers/extended_rewrite.h"
17
18
#include <sstream>
19
20
#include "theory/arith/arith_msum.h"
21
#include "theory/bv/theory_bv_utils.h"
22
#include "theory/datatypes/datatypes_rewriter.h"
23
#include "theory/quantifiers/term_util.h"
24
#include "theory/rewriter.h"
25
#include "theory/strings/sequences_rewriter.h"
26
#include "theory/theory.h"
27
28
using namespace cvc5::kind;
29
using namespace std;
30
31
namespace cvc5 {
32
namespace theory {
33
namespace quantifiers {
34
35
struct ExtRewriteAttributeId
36
{
37
};
38
typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
39
40
struct ExtRewriteAggAttributeId
41
{
42
};
43
typedef expr::Attribute<ExtRewriteAggAttributeId, Node> ExtRewriteAggAttribute;
44
45
11438
ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr)
46
{
47
11438
  d_true = NodeManager::currentNM()->mkConst(true);
48
11438
  d_false = NodeManager::currentNM()->mkConst(false);
49
11438
}
50
51
222984
void ExtendedRewriter::setCache(Node n, Node ret) const
52
{
53
222984
  if (d_aggr)
54
  {
55
    ExtRewriteAggAttribute erga;
56
222336
    n.setAttribute(erga, ret);
57
  }
58
  else
59
  {
60
    ExtRewriteAttribute era;
61
648
    n.setAttribute(era, ret);
62
  }
63
222984
}
64
65
575331
Node ExtendedRewriter::getCache(Node n) const
66
{
67
575331
  if (d_aggr)
68
  {
69
574819
    if (n.hasAttribute(ExtRewriteAggAttribute()))
70
    {
71
462813
      return n.getAttribute(ExtRewriteAggAttribute());
72
    }
73
  }
74
  else
75
  {
76
512
    if (n.hasAttribute(ExtRewriteAttribute()))
77
    {
78
188
      return n.getAttribute(ExtRewriteAttribute());
79
    }
80
  }
81
112330
  return Node::null();
82
}
83
84
224757
bool ExtendedRewriter::addToChildren(Node nc,
85
                                     std::vector<Node>& children,
86
                                     bool dropDup) const
87
{
88
  // If the operator is non-additive, do not consider duplicates
89
224757
  if (dropDup
90
224757
      && std::find(children.begin(), children.end(), nc) != children.end())
91
  {
92
112
    return false;
93
  }
94
224645
  children.push_back(nc);
95
224645
  return true;
96
}
97
98
575331
Node ExtendedRewriter::extendedRewrite(Node n) const
99
{
100
575331
  n = Rewriter::rewrite(n);
101
102
  // has it already been computed?
103
1150662
  Node ncache = getCache(n);
104
575331
  if (!ncache.isNull())
105
  {
106
463001
    return ncache;
107
  }
108
109
224660
  Node ret = n;
110
112330
  NodeManager* nm = NodeManager::currentNM();
111
112
  //--------------------pre-rewrite
113
112330
  if (d_aggr)
114
  {
115
222336
    Node pre_new_ret;
116
112006
    if (ret.getKind() == IMPLIES)
117
    {
118
22
      pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
119
22
      debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
120
    }
121
111984
    else if (ret.getKind() == XOR)
122
    {
123
107
      pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
124
107
      debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
125
    }
126
111877
    else if (ret.getKind() == NOT)
127
    {
128
11182
      pre_new_ret = extendedRewriteNnf(ret);
129
11182
      debugExtendedRewrite(ret, pre_new_ret, "NNF");
130
    }
131
112006
    if (!pre_new_ret.isNull())
132
    {
133
1676
      ret = extendedRewrite(pre_new_ret);
134
135
3352
      Trace("q-ext-rewrite-debug")
136
1676
          << "...ext-pre-rewrite : " << n << " -> " << pre_new_ret << std::endl;
137
1676
      setCache(n, ret);
138
1676
      return ret;
139
    }
140
  }
141
  //--------------------end pre-rewrite
142
143
  //--------------------rewrite children
144
110654
  if (n.getNumChildren() > 0)
145
  {
146
210888
    std::vector<Node> children;
147
105444
    if (n.getMetaKind() == metakind::PARAMETERIZED)
148
    {
149
4600
      children.push_back(n.getOperator());
150
    }
151
105444
    Kind k = n.getKind();
152
105444
    bool childChanged = false;
153
105444
    bool isNonAdditive = TermUtil::isNonAdditive(k);
154
    // We flatten associative operators below, which requires k to be n-ary.
155
105444
    bool isAssoc = TermUtil::isAssoc(k, true);
156
329877
    for (unsigned i = 0; i < n.getNumChildren(); i++)
157
    {
158
448866
      Node nc = extendedRewrite(n[i]);
159
224433
      childChanged = nc != n[i] || childChanged;
160
224433
      if (isAssoc && nc.getKind() == n.getKind())
161
      {
162
924
        for (const Node& ncc : nc)
163
        {
164
624
          if (!addToChildren(ncc, children, isNonAdditive))
165
          {
166
8
            childChanged = true;
167
          }
168
        }
169
      }
170
224133
      else if (!addToChildren(nc, children, isNonAdditive))
171
      {
172
104
        childChanged = true;
173
      }
174
    }
175
105444
    Assert(!children.empty());
176
    // Some commutative operators have rewriters that are agnostic to order,
177
    // thus, we sort here.
178
105444
    if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
179
    {
180
56819
      childChanged = true;
181
56819
      std::sort(children.begin(), children.end());
182
    }
183
105444
    if (childChanged)
184
    {
185
65379
      if (isNonAdditive && children.size() == 1)
186
      {
187
        // we may have subsumed children down to one
188
24
        ret = children[0];
189
      }
190
65355
      else if (isAssoc
191
65355
               && children.size() > kind::metakind::getMaxArityForKind(k))
192
      {
193
2
        Assert(kind::metakind::getMaxArityForKind(k) >= 2);
194
        // kind may require binary construction
195
2
        ret = children[0];
196
6
        for (unsigned i = 1, nchild = children.size(); i < nchild; i++)
197
        {
198
4
          ret = nm->mkNode(k, ret, children[i]);
199
        }
200
      }
201
      else
202
      {
203
65353
        ret = nm->mkNode(k, children);
204
      }
205
    }
206
  }
207
110654
  ret = Rewriter::rewrite(ret);
208
  //--------------------end rewrite children
209
210
  // now, do extended rewrite
211
221308
  Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
212
110654
                               << " (from " << n << ")" << std::endl;
213
221308
  Node new_ret;
214
215
  //---------------------- theory-independent post-rewriting
216
110654
  if (ret.getKind() == ITE)
217
  {
218
10860
    new_ret = extendedRewriteIte(ITE, ret);
219
  }
220
99794
  else if (ret.getKind() == AND || ret.getKind() == OR)
221
  {
222
13497
    new_ret = extendedRewriteAndOr(ret);
223
  }
224
86297
  else if (ret.getKind() == EQUAL)
225
  {
226
14039
    new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
227
14039
    debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
228
  }
229
110654
  Assert(new_ret.isNull() || new_ret != ret);
230
110654
  if (new_ret.isNull() && ret.getKind() != ITE)
231
  {
232
    // simple ITE pulling
233
98002
    new_ret = extendedRewritePullIte(ITE, ret);
234
  }
235
  //----------------------end theory-independent post-rewriting
236
237
  //----------------------theory-specific post-rewriting
238
110654
  if (new_ret.isNull())
239
  {
240
    TheoryId tid;
241
96713
    if (ret.getKind() == ITE)
242
    {
243
9894
      tid = Theory::theoryOf(ret.getType());
244
    }
245
    else
246
    {
247
86819
      tid = Theory::theoryOf(ret);
248
    }
249
193426
    Trace("q-ext-rewrite-debug") << "theoryOf( " << ret << " )= " << tid
250
96713
                                 << std::endl;
251
96713
    if (tid == THEORY_STRINGS)
252
    {
253
1315
      new_ret = extendedRewriteStrings(ret);
254
    }
255
  }
256
  //----------------------end theory-specific post-rewriting
257
258
  //----------------------aggressive rewrites
259
110654
  if (new_ret.isNull() && d_aggr)
260
  {
261
96192
    new_ret = extendedRewriteAggr(ret);
262
  }
263
  //----------------------end aggressive rewrites
264
265
110654
  setCache(n, ret);
266
110654
  if (!new_ret.isNull())
267
  {
268
15169
    ret = extendedRewrite(new_ret);
269
  }
270
221308
  Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
271
110654
                               << std::endl;
272
110654
  if (Trace.isOn("q-ext-rewrite-nf"))
273
  {
274
    if (n == ret)
275
    {
276
      Trace("q-ext-rewrite-nf") << "ext-rew normal form : " << n << std::endl;
277
    }
278
  }
279
110654
  setCache(n, ret);
280
110654
  return ret;
281
}
282
283
96192
Node ExtendedRewriter::extendedRewriteAggr(Node n) const
284
{
285
96192
  Node new_ret;
286
192384
  Trace("q-ext-rewrite-debug2")
287
96192
      << "Do aggressive rewrites on " << n << std::endl;
288
96192
  bool polarity = n.getKind() != NOT;
289
192384
  Node ret_atom = n.getKind() == NOT ? n[0] : n;
290
201637
  if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
291
283444
      || ret_atom.getKind() == GEQ)
292
  {
293
    // ITE term removal in polynomials
294
    // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 )
295
39228
    Trace("q-ext-rewrite-debug2")
296
19614
        << "Compute monomial sum " << ret_atom << std::endl;
297
    // compute monomial sum
298
39228
    std::map<Node, Node> msum;
299
19614
    if (ArithMSum::getMonomialSumLit(ret_atom, msum))
300
    {
301
65806
      for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end();
302
           ++itm)
303
      {
304
93409
        Node v = itm->first;
305
94434
        Trace("q-ext-rewrite-debug2")
306
47217
            << itm->first << " * " << itm->second << std::endl;
307
47217
        if (v.getKind() == ITE)
308
        {
309
11955
          Node veq;
310
6490
          int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind());
311
6490
          if (res != 0)
312
          {
313
12980
            Trace("q-ext-rewrite-debug")
314
6490
                << "  have ITE relation, solved form : " << veq << std::endl;
315
            // try pulling ITE
316
6490
            new_ret = extendedRewritePullIte(ITE, veq);
317
6490
            if (!new_ret.isNull())
318
            {
319
1025
              if (!polarity)
320
              {
321
                new_ret = new_ret.negate();
322
              }
323
1025
              break;
324
            }
325
          }
326
          else
327
          {
328
            Trace("q-ext-rewrite-debug")
329
                << "  failed to isolate " << v << " in " << n << std::endl;
330
          }
331
        }
332
      }
333
    }
334
    else
335
    {
336
      Trace("q-ext-rewrite-debug")
337
          << "  failed to get monomial sum of " << n << std::endl;
338
    }
339
  }
340
  // TODO (#1706) : conditional rewriting, condition merging
341
192384
  return new_ret;
342
}
343
344
20123
Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const
345
{
346
20123
  Assert(n.getKind() == itek);
347
20123
  Assert(n[1] != n[2]);
348
349
20123
  NodeManager* nm = NodeManager::currentNM();
350
351
20123
  Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;
352
353
40246
  Node flip_cond;
354
20123
  if (n[0].getKind() == NOT)
355
  {
356
    flip_cond = n[0][0];
357
  }
358
20123
  else if (n[0].getKind() == OR)
359
  {
360
    // a | b ---> ~( ~a & ~b )
361
57
    flip_cond = TermUtil::simpleNegate(n[0]);
362
  }
363
20123
  if (!flip_cond.isNull())
364
  {
365
114
    Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
366
    // only print debug trace if full=true
367
57
    if (full)
368
    {
369
57
      debugExtendedRewrite(n, new_ret, "ITE flip");
370
    }
371
57
    return new_ret;
372
  }
373
  // Boolean true/false return
374
40132
  TypeNode tn = n.getType();
375
20066
  if (tn.isBoolean())
376
  {
377
29088
    for (unsigned i = 1; i <= 2; i++)
378
    {
379
19392
      if (n[i].isConst())
380
      {
381
        Node cond = i == 1 ? n[0] : n[0].negate();
382
        Node other = n[i == 1 ? 2 : 1];
383
        Kind retk = AND;
384
        if (n[i].getConst<bool>())
385
        {
386
          retk = OR;
387
        }
388
        else
389
        {
390
          cond = cond.negate();
391
        }
392
        Node new_ret = nm->mkNode(retk, cond, other);
393
        if (full)
394
        {
395
          // ite( A, true, B ) ---> A V B
396
          // ite( A, false, B ) ---> ~A /\ B
397
          // ite( A, B,  true ) ---> ~A V B
398
          // ite( A, B, false ) ---> A /\ B
399
          debugExtendedRewrite(n, new_ret, "ITE const return");
400
        }
401
        return new_ret;
402
      }
403
    }
404
  }
405
406
  // get entailed equalities in the condition
407
40132
  std::vector<Node> eq_conds;
408
20066
  Kind ck = n[0].getKind();
409
20066
  if (ck == EQUAL)
410
  {
411
6894
    eq_conds.push_back(n[0]);
412
  }
413
13172
  else if (ck == AND)
414
  {
415
772
    for (const Node& cn : n[0])
416
    {
417
528
      if (cn.getKind() == EQUAL)
418
      {
419
66
        eq_conds.push_back(cn);
420
      }
421
    }
422
  }
423
424
40132
  Node new_ret;
425
40132
  Node b;
426
40132
  Node e;
427
40132
  Node t1 = n[1];
428
40132
  Node t2 = n[2];
429
40132
  std::stringstream ss_reason;
430
431
26974
  for (const Node& eq : eq_conds)
432
  {
433
    // simple invariant ITE
434
20801
    for (unsigned i = 0; i <= 1; i++)
435
    {
436
      // ite( x = y ^ C, y, x ) ---> x
437
      // this is subsumed by the rewrites below
438
13893
      if (t2 == eq[i] && t1 == eq[1 - i])
439
      {
440
52
        new_ret = t2;
441
52
        ss_reason << "ITE simple rev subs";
442
52
        break;
443
      }
444
    }
445
6960
    if (!new_ret.isNull())
446
    {
447
52
      break;
448
    }
449
  }
450
20066
  if (new_ret.isNull())
451
  {
452
    // merging branches
453
59844
    for (unsigned i = 1; i <= 2; i++)
454
    {
455
39969
      if (n[i].getKind() == ITE)
456
      {
457
1100
        Node no = n[3 - i];
458
1450
        for (unsigned j = 1; j <= 2; j++)
459
        {
460
1039
          if (n[i][j] == no)
461
          {
462
            // e.g.
463
            // ite( C1, ite( C2, t1, t2 ), t1 ) ----> ite( C1 ^ ~C2, t2, t1 )
464
278
            Node nc1 = i == 2 ? n[0].negate() : n[0];
465
278
            Node nc2 = j == 1 ? n[i][0].negate() : n[i][0];
466
278
            Node new_cond = nm->mkNode(AND, nc1, nc2);
467
139
            new_ret = nm->mkNode(ITE, new_cond, n[i][3 - j], no);
468
139
            ss_reason << "ITE merge branch";
469
139
            break;
470
          }
471
        }
472
      }
473
39969
      if (!new_ret.isNull())
474
      {
475
139
        break;
476
      }
477
    }
478
  }
479
480
20066
  if (new_ret.isNull() && d_aggr)
481
  {
482
    // If x is less than t based on an ordering, then we use { x -> t } as a
483
    // substitution to the children of ite( x = t ^ C, s, t ) below.
484
39742
    std::vector<Node> vars;
485
39742
    std::vector<Node> subs;
486
19871
    inferSubstitution(n[0], vars, subs, true);
487
488
19871
    if (!vars.empty())
489
    {
490
      // reverse substitution to opposite child
491
      // r{ x -> t } = s  implies  ite( x=t ^ C, s, r ) ---> r
492
      // We can use ordinary substitute since the result of the substitution
493
      // is not being returned. In other words, nn is only being used to query
494
      // whether the second branch is a generalization of the first.
495
      Node nn =
496
39742
          t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
497
19871
      if (nn != t2)
498
      {
499
2416
        nn = Rewriter::rewrite(nn);
500
2416
        if (nn == t1)
501
        {
502
8
          new_ret = t2;
503
8
          ss_reason << "ITE rev subs";
504
        }
505
      }
506
507
      // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
508
      // must use partial substitute here, to avoid substitution into witness
509
39742
      std::map<Kind, bool> rkinds;
510
19871
      nn = partialSubstitute(t1, vars, subs, rkinds);
511
19871
      nn = Rewriter::rewrite(nn);
512
19871
      if (nn != t1)
513
      {
514
        // If full=false, then we've duplicated a term u in the children of n.
515
        // For example, when ITE pulling, we have n is of the form:
516
        //   ite( C, f( u, t1 ), f( u, t2 ) )
517
        // We must show that at least one copy of u dissappears in this case.
518
1312
        if (nn == t2)
519
        {
520
2
          new_ret = nn;
521
2
          ss_reason << "ITE subs invariant";
522
        }
523
1310
        else if (full || nn.isConst())
524
        {
525
681
          new_ret = nm->mkNode(itek, n[0], nn, t2);
526
681
          ss_reason << "ITE subs";
527
        }
528
      }
529
    }
530
19871
    if (new_ret.isNull())
531
    {
532
      // ite( C, t, s ) ----> ite( C, t, s { C -> false } )
533
      // use partial substitute to avoid substitution into witness
534
38360
      std::map<Node, Node> assign;
535
19180
      assign[n[0]] = d_false;
536
38360
      std::map<Kind, bool> rkinds;
537
38360
      Node nn = partialSubstitute(t2, assign, rkinds);
538
19180
      if (nn != t2)
539
      {
540
187
        nn = Rewriter::rewrite(nn);
541
187
        if (nn == t1)
542
        {
543
2
          new_ret = nn;
544
2
          ss_reason << "ITE subs invariant false";
545
        }
546
185
        else if (full || nn.isConst())
547
        {
548
178
          new_ret = nm->mkNode(itek, n[0], t1, nn);
549
178
          ss_reason << "ITE subs false";
550
        }
551
      }
552
    }
553
  }
554
555
  // only print debug trace if full=true
556
20066
  if (!new_ret.isNull() && full)
557
  {
558
909
    debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
559
  }
560
561
20066
  return new_ret;
562
}
563
564
13497
Node ExtendedRewriter::extendedRewriteAndOr(Node n) const
565
{
566
  // all the below rewrites are aggressive
567
13497
  if (!d_aggr)
568
  {
569
18
    return Node::null();
570
  }
571
26958
  Node new_ret;
572
  // we allow substitutions to recurse over any kind, except WITNESS which is
573
  // managed by partialSubstitute.
574
26958
  std::map<Kind, bool> bcp_kinds;
575
13479
  new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, n);
576
13479
  if (!new_ret.isNull())
577
  {
578
1310
    debugExtendedRewrite(n, new_ret, "Bool bcp");
579
1310
    return new_ret;
580
  }
581
  // factoring
582
12169
  new_ret = extendedRewriteFactoring(AND, OR, NOT, n);
583
12169
  if (!new_ret.isNull())
584
  {
585
69
    debugExtendedRewrite(n, new_ret, "Bool factoring");
586
69
    return new_ret;
587
  }
588
589
  // equality resolution
590
12100
  new_ret = extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, n, false);
591
12100
  debugExtendedRewrite(n, new_ret, "Bool eq res");
592
12100
  return new_ret;
593
}
594
595
104492
Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) const
596
{
597
104492
  Assert(n.getKind() != ITE);
598
104492
  if (n.isClosure())
599
  {
600
    // don't pull ITE out of quantifiers
601
95
    return n;
602
  }
603
104397
  NodeManager* nm = NodeManager::currentNM();
604
208794
  TypeNode tn = n.getType();
605
208794
  std::vector<Node> children;
606
104397
  bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
607
104397
  if (hasOp)
608
  {
609
4625
    children.push_back(n.getOperator());
610
  }
611
104397
  unsigned nchildren = n.getNumChildren();
612
303218
  for (unsigned i = 0; i < nchildren; i++)
613
  {
614
198821
    children.push_back(n[i]);
615
  }
616
208794
  std::map<unsigned, std::map<unsigned, Node> > ite_c;
617
286996
  for (unsigned i = 0; i < nchildren; i++)
618
  {
619
    // only pull ITEs apart if we are aggressive
620
583197
    if (n[i].getKind() == itek
621
583197
        && (d_aggr || (n[i][1].getKind() != ITE && n[i][2].getKind() != ITE)))
622
    {
623
21230
      unsigned ii = hasOp ? i + 1 : i;
624
63690
      for (unsigned j = 0; j < 2; j++)
625
      {
626
42460
        children[ii] = n[i][j + 1];
627
84920
        Node pull = nm->mkNode(n.getKind(), children);
628
84920
        Node pullr = Rewriter::rewrite(pull);
629
42460
        children[ii] = n[i];
630
42460
        ite_c[i][j] = pullr;
631
      }
632
21230
      if (ite_c[i][0] == ite_c[i][1])
633
      {
634
        // ITE dual invariance
635
        // f( t1..s1..tn ) ---> t  and  f( t1..s2..tn ) ---> t implies
636
        // f( t1..ite( A, s1, s2 )..tn ) ---> t
637
450
        debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
638
450
        return ite_c[i][0];
639
      }
640
20780
      if (d_aggr)
641
      {
642
40825
        if (nchildren == 2 && (n[1 - i].isVar() || n[1 - i].isConst())
643
33926
            && !n[1 - i].getType().isBoolean() && tn.isBoolean())
644
        {
645
          // always pull variable or constant with binary (theory) predicate
646
          // e.g. P( x, ite( A, t1, t2 ) ) ---> ite( A, P( x, t1 ), P( x, t2 ) )
647
16150
          Node new_ret = nm->mkNode(ITE, n[i][0], ite_c[i][0], ite_c[i][1]);
648
8075
          debugExtendedRewrite(n, new_ret, "ITE pull var predicate");
649
8075
          return new_ret;
650
        }
651
32985
        for (unsigned j = 0; j < 2; j++)
652
        {
653
43839
          Node pullr = ite_c[i][j];
654
23557
          if (pullr.isConst() || pullr == n[i][j + 1])
655
          {
656
            // ITE single child elimination
657
            // f( t1..s1..tn ) ---> t  where t is a constant or s1 itself
658
            // implies
659
            // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
660
6550
            Node new_ret;
661
3275
            if (tn.isBoolean() && pullr.isConst())
662
            {
663
              // remove false/true child immediately
664
824
              bool pol = pullr.getConst<bool>();
665
1648
              std::vector<Node> new_children;
666
824
              new_children.push_back((j == 0) == pol ? n[i][0]
667
                                                     : n[i][0].negate());
668
824
              new_children.push_back(ite_c[i][1 - j]);
669
824
              new_ret = nm->mkNode(pol ? OR : AND, new_children);
670
824
              debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
671
            }
672
            else
673
            {
674
2451
              new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
675
2451
              debugExtendedRewrite(n, new_ret, "ITE single elim");
676
            }
677
3275
            return new_ret;
678
          }
679
        }
680
      }
681
    }
682
  }
683
92597
  if (d_aggr)
684
  {
685
101391
    for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
686
    {
687
18533
      Node nite = n[ip.first];
688
9423
      Assert(nite.getKind() == itek);
689
      // now, simply pull the ITE and try ITE rewrites
690
18533
      Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
691
9423
      pull_ite = Rewriter::rewrite(pull_ite);
692
9423
      if (pull_ite.getKind() == ITE)
693
      {
694
18373
        Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
695
9263
        if (!new_pull_ite.isNull())
696
        {
697
153
          debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
698
153
          return new_pull_ite;
699
        }
700
      }
701
      else
702
      {
703
        // A general rewrite could eliminate the ITE by pulling.
704
        // An example is:
705
        //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
706
        //   ite( C, ~~x, ite( C, y, x ) ) --->
707
        //   x
708
        // where ~ is bitvector negation.
709
160
        debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
710
160
        return pull_ite;
711
      }
712
    }
713
  }
714
715
92284
  return Node::null();
716
}
717
718
11182
Node ExtendedRewriter::extendedRewriteNnf(Node ret) const
719
{
720
11182
  Assert(ret.getKind() == NOT);
721
722
11182
  Kind nk = ret[0].getKind();
723
11182
  bool neg_ch = false;
724
11182
  bool neg_ch_1 = false;
725
11182
  if (nk == AND || nk == OR)
726
  {
727
1453
    neg_ch = true;
728
1453
    nk = nk == AND ? OR : AND;
729
  }
730
9729
  else if (nk == IMPLIES)
731
  {
732
    neg_ch = true;
733
    neg_ch_1 = true;
734
    nk = AND;
735
  }
736
9729
  else if (nk == ITE)
737
  {
738
83
    neg_ch = true;
739
83
    neg_ch_1 = true;
740
  }
741
9646
  else if (nk == XOR)
742
  {
743
7
    nk = EQUAL;
744
  }
745
9639
  else if (nk == EQUAL && ret[0][0].getType().isBoolean())
746
  {
747
4
    neg_ch_1 = true;
748
  }
749
  else
750
  {
751
9635
    return Node::null();
752
  }
753
754
3094
  std::vector<Node> new_children;
755
5229
  for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++)
756
  {
757
7364
    Node c = ret[0][i];
758
3682
    c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c;
759
3682
    new_children.push_back(c);
760
  }
761
1547
  return NodeManager::currentNM()->mkNode(nk, new_children);
762
}
763
764
13479
Node ExtendedRewriter::extendedRewriteBcp(Kind andk,
765
                                          Kind ork,
766
                                          Kind notk,
767
                                          std::map<Kind, bool>& bcp_kinds,
768
                                          Node ret) const
769
{
770
13479
  Kind k = ret.getKind();
771
13479
  Assert(k == andk || k == ork);
772
13479
  Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;
773
774
13479
  NodeManager* nm = NodeManager::currentNM();
775
776
26958
  TypeNode tn = ret.getType();
777
26958
  Node truen = TermUtil::mkTypeMaxValue(tn);
778
26958
  Node falsen = TermUtil::mkTypeValue(tn, 0);
779
780
  // terms to process
781
26958
  std::vector<Node> to_process;
782
47196
  for (const Node& cn : ret)
783
  {
784
33717
    to_process.push_back(cn);
785
  }
786
  // the processing terms
787
26958
  std::vector<Node> clauses;
788
  // the terms we have propagated information to
789
26958
  std::unordered_set<Node> prop_clauses;
790
  // the assignment
791
26958
  std::map<Node, Node> assign;
792
26958
  std::vector<Node> avars;
793
26958
  std::vector<Node> asubs;
794
795
13479
  Kind ok = k == andk ? ork : andk;
796
  // global polarity : when k=ork, everything is negated
797
13479
  bool gpol = k == andk;
798
799
1172
  do
800
  {
801
    // process the current nodes
802
43781
    while (!to_process.empty())
803
    {
804
29296
      std::vector<Node> new_to_process;
805
49551
      for (const Node& cn : to_process)
806
      {
807
34986
        Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
808
34986
        Kind cnk = cn.getKind();
809
34986
        bool pol = cnk != notk;
810
69806
        Node cln = cnk == notk ? cn[0] : cn;
811
34986
        Assert(cln.getKind() != notk);
812
34986
        if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
813
        {
814
          // flatten
815
240
          for (const Node& ccln : cln)
816
          {
817
320
            Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
818
160
            new_to_process.push_back(lccln);
819
          }
820
        }
821
        else
822
        {
823
          // add it to the assignment
824
69646
          Node val = gpol == pol ? truen : falsen;
825
34906
          std::map<Node, Node>::iterator it = assign.find(cln);
826
69812
          Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
827
34906
                               << std::endl;
828
34906
          if (it != assign.end())
829
          {
830
168
            if (val != it->second)
831
            {
832
166
              Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
833
              // a conflicting assignment: we are done
834
166
              return gpol ? falsen : truen;
835
            }
836
          }
837
          else
838
          {
839
34738
            assign[cln] = val;
840
34738
            avars.push_back(cln);
841
34738
            asubs.push_back(val);
842
          }
843
844
          // also, treat it as clause if possible
845
69480
          if (cln.getNumChildren() > 0
846
60174
              && (bcp_kinds.empty()
847
34740
                  || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
848
          {
849
50868
            if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
850
25434
                && prop_clauses.find(cn) == prop_clauses.end())
851
            {
852
25434
              Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
853
25434
              clauses.push_back(cn);
854
            }
855
          }
856
        }
857
      }
858
14565
      to_process.clear();
859
14565
      to_process.insert(
860
29130
          to_process.end(), new_to_process.begin(), new_to_process.end());
861
    }
862
863
    // apply substitution to all subterms of clauses
864
28970
    std::vector<Node> new_clauses;
865
40750
    for (const Node& c : clauses)
866
    {
867
26265
      bool cpol = c.getKind() != notk;
868
52530
      Node ca = c.getKind() == notk ? c[0] : c;
869
26265
      bool childChanged = false;
870
52530
      std::vector<Node> ccs_children;
871
81903
      for (const Node& cc : ca)
872
      {
873
        // always use partial substitute, to avoid substitution in witness
874
55638
        Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
875
        // substitution is only applicable to compatible kinds in bcp_kinds
876
111276
        Node ccs = partialSubstitute(cc, assign, bcp_kinds);
877
55638
        childChanged = childChanged || ccs != cc;
878
55638
        ccs_children.push_back(ccs);
879
      }
880
26265
      if (childChanged)
881
      {
882
1183
        if (ca.getMetaKind() == metakind::PARAMETERIZED)
883
        {
884
          ccs_children.insert(ccs_children.begin(), ca.getOperator());
885
        }
886
2366
        Node ccs = nm->mkNode(ca.getKind(), ccs_children);
887
1183
        ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
888
2366
        Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
889
1183
                             << std::endl;
890
1183
        ccs = Rewriter::rewrite(ccs);
891
1183
        Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
892
1183
        to_process.push_back(ccs);
893
        // store this as a node that propagation touched. This marks c so that
894
        // it will not be included in the final construction.
895
1183
        prop_clauses.insert(ca);
896
      }
897
      else
898
      {
899
25082
        new_clauses.push_back(c);
900
      }
901
    }
902
14485
    clauses.clear();
903
14485
    clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
904
14485
  } while (!to_process.empty());
905
906
  // remake the node
907
13313
  if (!prop_clauses.empty())
908
  {
909
2288
    std::vector<Node> children;
910
5665
    for (std::pair<const Node, Node>& l : assign)
911
    {
912
9042
      Node a = l.first;
913
      // if propagation did not touch a
914
4521
      if (prop_clauses.find(a) == prop_clauses.end())
915
      {
916
3338
        Assert(l.second == truen || l.second == falsen);
917
6676
        Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
918
3338
        children.push_back(ln);
919
      }
920
    }
921
2288
    Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
922
1144
    Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
923
1144
    return new_ret;
924
  }
925
926
12169
  return Node::null();
927
}
928
929
12169
Node ExtendedRewriter::extendedRewriteFactoring(Kind andk,
930
                                                Kind ork,
931
                                                Kind notk,
932
                                                Node n) const
933
{
934
12169
  Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl;
935
12169
  NodeManager* nm = NodeManager::currentNM();
936
937
12169
  Kind nk = n.getKind();
938
12169
  Assert(nk == andk || nk == ork);
939
12169
  Kind onk = nk == andk ? ork : andk;
940
  // count the number of times atoms occur
941
24338
  std::map<Node, std::vector<Node> > lit_to_cl;
942
24338
  std::map<Node, std::vector<Node> > cl_to_lits;
943
42050
  for (const Node& nc : n)
944
  {
945
29881
    Kind nck = nc.getKind();
946
29881
    if (nck == onk)
947
    {
948
13168
      for (const Node& ncl : nc)
949
      {
950
27399
        if (std::find(lit_to_cl[ncl].begin(), lit_to_cl[ncl].end(), nc)
951
27399
            == lit_to_cl[ncl].end())
952
        {
953
9133
          lit_to_cl[ncl].push_back(nc);
954
9133
          cl_to_lits[nc].push_back(ncl);
955
        }
956
      }
957
    }
958
    else
959
    {
960
25846
      lit_to_cl[nc].push_back(nc);
961
25846
      cl_to_lits[nc].push_back(nc);
962
    }
963
  }
964
  // get the maximum shared literal to factor
965
12169
  unsigned max_size = 0;
966
24338
  Node flit;
967
46975
  for (const std::pair<const Node, std::vector<Node> >& ltc : lit_to_cl)
968
  {
969
34806
    if (ltc.second.size() > max_size)
970
    {
971
12195
      max_size = ltc.second.size();
972
12195
      flit = ltc.first;
973
    }
974
  }
975
12169
  if (max_size > 1)
976
  {
977
    // do the factoring
978
138
    std::vector<Node> children;
979
138
    std::vector<Node> fchildren;
980
69
    std::map<Node, std::vector<Node> >::iterator itl = lit_to_cl.find(flit);
981
69
    std::vector<Node>& cls = itl->second;
982
419
    for (const Node& nc : n)
983
    {
984
350
      if (std::find(cls.begin(), cls.end(), nc) == cls.end())
985
      {
986
145
        children.push_back(nc);
987
      }
988
      else
989
      {
990
        // rebuild
991
205
        std::vector<Node>& lits = cl_to_lits[nc];
992
        std::vector<Node>::iterator itlfl =
993
205
            std::find(lits.begin(), lits.end(), flit);
994
205
        Assert(itlfl != lits.end());
995
205
        lits.erase(itlfl);
996
        // rebuild
997
205
        if (!lits.empty())
998
        {
999
410
          Node new_cl = lits.size() == 1 ? lits[0] : nm->mkNode(onk, lits);
1000
205
          fchildren.push_back(new_cl);
1001
        }
1002
      }
1003
    }
1004
    // rebuild the factored children
1005
69
    Assert(!fchildren.empty());
1006
138
    Node fcn = fchildren.size() == 1 ? fchildren[0] : nm->mkNode(nk, fchildren);
1007
69
    children.push_back(nm->mkNode(onk, flit, fcn));
1008
138
    Node ret = children.size() == 1 ? children[0] : nm->mkNode(nk, children);
1009
69
    Trace("ext-rew-factoring") << "Factoring: *** OUTPUT: " << ret << std::endl;
1010
69
    return ret;
1011
  }
1012
  else
1013
  {
1014
12100
    Trace("ext-rew-factoring") << "Factoring: no change" << std::endl;
1015
  }
1016
12100
  return Node::null();
1017
}
1018
1019
12100
Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
1020
                                            Kind ork,
1021
                                            Kind eqk,
1022
                                            Kind notk,
1023
                                            std::map<Kind, bool>& bcp_kinds,
1024
                                            Node n,
1025
                                            bool isXor) const
1026
{
1027
12100
  Assert(n.getKind() == andk || n.getKind() == ork);
1028
12100
  Trace("ext-rew-eqres") << "Eq res: **** INPUT: " << n << std::endl;
1029
1030
12100
  NodeManager* nm = NodeManager::currentNM();
1031
12100
  Kind nk = n.getKind();
1032
12100
  bool gpol = (nk == andk);
1033
40921
  for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
1034
  {
1035
57941
    Node lit = n[i];
1036
29120
    if (lit.getKind() == eqk)
1037
    {
1038
      // eq is the equality we are basing a substitution on
1039
7967
      Node eq;
1040
4133
      if (gpol == isXor)
1041
      {
1042
        // can only turn disequality into equality if types are the same
1043
2184
        if (lit[1].getType() == lit.getType())
1044
        {
1045
          // t != s ---> ~t = s
1046
41
          if (lit[1].getKind() == notk && lit[0].getKind() != notk)
1047
          {
1048
32
            eq = nm->mkNode(EQUAL, lit[0], TermUtil::mkNegate(notk, lit[1]));
1049
          }
1050
          else
1051
          {
1052
9
            eq = nm->mkNode(EQUAL, TermUtil::mkNegate(notk, lit[0]), lit[1]);
1053
          }
1054
        }
1055
      }
1056
      else
1057
      {
1058
1949
        eq = eqk == EQUAL ? lit : nm->mkNode(EQUAL, lit[0], lit[1]);
1059
      }
1060
4133
      if (!eq.isNull())
1061
      {
1062
        // see if it corresponds to a substitution
1063
3681
        std::vector<Node> vars;
1064
3681
        std::vector<Node> subs;
1065
1990
        if (inferSubstitution(eq, vars, subs))
1066
        {
1067
1632
          Assert(vars.size() == 1);
1068
2965
          std::vector<Node> children;
1069
1632
          bool childrenChanged = false;
1070
          // apply to all other children
1071
9015
          for (unsigned j = 0; j < nchild; j++)
1072
          {
1073
14766
            Node ccs = n[j];
1074
7383
            if (i != j)
1075
            {
1076
              // Substitution is only applicable to compatible kinds. We always
1077
              // use the partialSubstitute method to avoid substitution into
1078
              // witness terms.
1079
5751
              ccs = partialSubstitute(ccs, vars, subs, bcp_kinds);
1080
5751
              childrenChanged = childrenChanged || n[j] != ccs;
1081
            }
1082
7383
            children.push_back(ccs);
1083
          }
1084
1632
          if (childrenChanged)
1085
          {
1086
299
            return nm->mkNode(nk, children);
1087
          }
1088
        }
1089
      }
1090
    }
1091
  }
1092
1093
11801
  return Node::null();
1094
}
1095
1096
/** sort pairs by their second (unsigned) argument */
1097
27976
static bool sortPairSecond(const std::pair<Node, unsigned>& a,
1098
                           const std::pair<Node, unsigned>& b)
1099
{
1100
27976
  return (a.second < b.second);
1101
}
1102
1103
/** A simple subsumption trie used to compute pairwise list subsets */
1104
84580
class SimpSubsumeTrie
1105
{
1106
 public:
1107
  /** the children of this node */
1108
  std::map<Node, SimpSubsumeTrie> d_children;
1109
  /** the term at this node */
1110
  Node d_data;
1111
  /** add term to the trie
1112
   *
1113
   * This adds term c to this trie, whose atom list is alist. This adds terms
1114
   * s to subsumes such that the atom list of s is a subset of the atom list
1115
   * of c. For example, say:
1116
   *   c1.alist = { A }
1117
   *   c2.alist = { C }
1118
   *   c3.alist = { B, C }
1119
   *   c4.alist = { A, B, D }
1120
   *   c5.alist = { A, B, C }
1121
   * If these terms are added in the order c1, c2, c3, c4, c5, then:
1122
   *   addTerm c1 results in subsumes = {}
1123
   *   addTerm c2 results in subsumes = {}
1124
   *   addTerm c3 results in subsumes = { c2 }
1125
   *   addTerm c4 results in subsumes = { c1 }
1126
   *   addTerm c5 results in subsumes = { c1, c2, c3 }
1127
   * Notice that the intended use case of this trie is to add term t before t'
1128
   * only when size( t.alist ) <= size( t'.alist ).
1129
   *
1130
   * The last two arguments describe the state of the path [t0...tn] we
1131
   * have followed in the trie during the recursive call.
1132
   * If doAdd = true,
1133
   *   then n+1 = index and alist[1]...alist[n] = t1...tn. If index=alist.size()
1134
   *   we add c as the current node of this trie.
1135
   * If doAdd = false,
1136
   *   then t1...tn occur in alist.
1137
   */
1138
56269
  void addTerm(Node c,
1139
               std::vector<Node>& alist,
1140
               std::vector<Node>& subsumes,
1141
               unsigned index = 0,
1142
               bool doAdd = true)
1143
  {
1144
56269
    if (!d_data.isNull())
1145
    {
1146
      subsumes.push_back(d_data);
1147
    }
1148
56269
    if (doAdd)
1149
    {
1150
56269
      if (index == alist.size())
1151
      {
1152
27920
        d_data = c;
1153
27920
        return;
1154
      }
1155
    }
1156
    // try all children where we have this atom
1157
42409
    for (std::pair<const Node, SimpSubsumeTrie>& cp : d_children)
1158
    {
1159
14060
      if (std::find(alist.begin(), alist.end(), cp.first) != alist.end())
1160
      {
1161
        cp.second.addTerm(c, alist, subsumes, 0, false);
1162
      }
1163
    }
1164
28349
    if (doAdd)
1165
    {
1166
28349
      d_children[alist[index]].addTerm(c, alist, subsumes, index + 1, doAdd);
1167
    }
1168
  }
1169
};
1170
1171
14039
Node ExtendedRewriter::extendedRewriteEqChain(
1172
    Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor) const
1173
{
1174
14039
  Assert(ret.getKind() == eqk);
1175
1176
  // this rewrite is aggressive; it in fact has the precondition that other
1177
  // aggressive rewrites (including BCP) have been applied.
1178
14039
  if (!d_aggr)
1179
  {
1180
96
    return Node::null();
1181
  }
1182
1183
13943
  NodeManager* nm = NodeManager::currentNM();
1184
1185
27886
  TypeNode tn = ret[0].getType();
1186
1187
  // sort/cancelling for Boolean EQUAL/XOR-chains
1188
13943
  Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;
1189
1190
  // get the children on either side
1191
13943
  bool gpol = true;
1192
27886
  std::vector<Node> children;
1193
41829
  for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++)
1194
  {
1195
55772
    Node curr = ret[r];
1196
    // assume, if necessary, right associative
1197
28082
    while (curr.getKind() == eqk && curr[0].getType() == tn)
1198
    {
1199
98
      children.push_back(curr[0]);
1200
98
      curr = curr[1];
1201
    }
1202
27886
    children.push_back(curr);
1203
  }
1204
1205
27886
  std::map<Node, bool> cstatus;
1206
  // add children to status
1207
41927
  for (const Node& c : children)
1208
  {
1209
55968
    Node a = c;
1210
27984
    if (a.getKind() == notk)
1211
    {
1212
209
      gpol = !gpol;
1213
209
      a = a[0];
1214
    }
1215
27984
    Trace("ext-rew-eqchain") << "...child : " << a << std::endl;
1216
27984
    std::map<Node, bool>::iterator itc = cstatus.find(a);
1217
27984
    if (itc == cstatus.end())
1218
    {
1219
27956
      cstatus[a] = true;
1220
    }
1221
    else
1222
    {
1223
      // cancels
1224
28
      cstatus.erase(a);
1225
28
      if (isXor)
1226
      {
1227
        gpol = !gpol;
1228
      }
1229
    }
1230
  }
1231
13943
  Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;
1232
1233
13943
  if (cstatus.empty())
1234
  {
1235
2
    return TermUtil::mkTypeConst(tn, gpol);
1236
  }
1237
1238
13941
  children.clear();
1239
1240
  // compute the atoms of each child
1241
13941
  Trace("ext-rew-eqchain") << "eqchain-simplify: begin\n";
1242
13941
  Trace("ext-rew-eqchain") << "  eqchain-simplify: get atoms...\n";
1243
27882
  std::map<Node, std::map<Node, bool> > atoms;
1244
27882
  std::map<Node, std::vector<Node> > alist;
1245
27882
  std::vector<std::pair<Node, unsigned> > atom_count;
1246
41869
  for (std::pair<const Node, bool>& cp : cstatus)
1247
  {
1248
27928
    if (!cp.second)
1249
    {
1250
      // already eliminated
1251
      continue;
1252
    }
1253
55856
    Node c = cp.first;
1254
27928
    Kind ck = c.getKind();
1255
27928
    Trace("ext-rew-eqchain") << "  process c = " << c << std::endl;
1256
27928
    if (ck == andk || ck == ork)
1257
    {
1258
596
      for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++)
1259
      {
1260
1028
        Node cl = c[j];
1261
516
        bool pol = cl.getKind() != notk;
1262
1028
        Node ca = pol ? cl : cl[0];
1263
516
        bool newVal = (ck == andk ? !pol : pol);
1264
1032
        Trace("ext-rew-eqchain")
1265
516
            << "  atoms(" << c << ", " << ca << ") = " << newVal << std::endl;
1266
516
        Assert(atoms[c].find(ca) == atoms[c].end());
1267
        // polarity is flipped when we are AND
1268
516
        atoms[c][ca] = newVal;
1269
516
        alist[c].push_back(ca);
1270
1271
        // if this already exists as a child of the equality chain, eliminate.
1272
        // this catches cases like ( x & y ) = ( ( x & y ) | z ), where we
1273
        // consider ( x & y ) a unit, whereas below it is expanded to
1274
        // ~( ~x | ~y ).
1275
516
        std::map<Node, bool>::iterator itc = cstatus.find(ca);
1276
516
        if (itc != cstatus.end() && itc->second)
1277
        {
1278
          // cancel it
1279
4
          cstatus[ca] = false;
1280
4
          cstatus[c] = false;
1281
          // make new child
1282
          // x = ( y | ~x ) ---> y & x
1283
          // x = ( y | x ) ---> ~y | x
1284
          // x = ( y & x ) ---> y | ~x
1285
          // x = ( y & ~x ) ---> ~y & ~x
1286
8
          std::vector<Node> new_children;
1287
12
          for (unsigned k = 0, nchildc = c.getNumChildren(); k < nchildc; k++)
1288
          {
1289
8
            if (j != k)
1290
            {
1291
4
              new_children.push_back(c[k]);
1292
            }
1293
          }
1294
8
          Node nc[2];
1295
4
          nc[0] = c[j];
1296
4
          nc[1] = new_children.size() == 1 ? new_children[0]
1297
                                           : nm->mkNode(ck, new_children);
1298
          // negate the proper child
1299
4
          unsigned nindex = (ck == andk) == pol ? 0 : 1;
1300
4
          nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]);
1301
4
          Kind nk = pol ? ork : andk;
1302
          // store as new child
1303
4
          children.push_back(nm->mkNode(nk, nc[0], nc[1]));
1304
4
          if (isXor)
1305
          {
1306
            gpol = !gpol;
1307
          }
1308
4
          break;
1309
        }
1310
84
      }
1311
    }
1312
    else
1313
    {
1314
27844
      bool pol = ck != notk;
1315
55688
      Node ca = pol ? c : c[0];
1316
27844
      atoms[c][ca] = pol;
1317
55688
      Trace("ext-rew-eqchain")
1318
27844
          << "  atoms(" << c << ", " << ca << ") = " << pol << std::endl;
1319
27844
      alist[c].push_back(ca);
1320
    }
1321
27928
    atom_count.push_back(std::pair<Node, unsigned>(c, alist[c].size()));
1322
  }
1323
  // sort the atoms in each atom list
1324
41869
  for (std::map<Node, std::vector<Node> >::iterator it = alist.begin();
1325
41869
       it != alist.end();
1326
       ++it)
1327
  {
1328
27928
    std::sort(it->second.begin(), it->second.end());
1329
  }
1330
  // check subsumptions
1331
  // sort by #atoms
1332
13941
  std::sort(atom_count.begin(), atom_count.end(), sortPairSecond);
1333
13941
  if (Trace.isOn("ext-rew-eqchain"))
1334
  {
1335
    for (const std::pair<Node, unsigned>& ac : atom_count)
1336
    {
1337
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << ac.first << " has "
1338
                               << ac.second << " atoms." << std::endl;
1339
    }
1340
    Trace("ext-rew-eqchain") << "  eqchain-simplify: compute subsumptions...\n";
1341
  }
1342
27882
  SimpSubsumeTrie sst;
1343
41869
  for (std::pair<const Node, bool>& cp : cstatus)
1344
  {
1345
27928
    if (!cp.second)
1346
    {
1347
      // already eliminated
1348
8
      continue;
1349
    }
1350
55840
    Node c = cp.first;
1351
27920
    std::map<Node, std::map<Node, bool> >::iterator itc = atoms.find(c);
1352
27920
    Assert(itc != atoms.end());
1353
55840
    Trace("ext-rew-eqchain") << "  - add term " << c << " with atom list "
1354
27920
                             << alist[c] << "...\n";
1355
55840
    std::vector<Node> subsumes;
1356
27920
    sst.addTerm(c, alist[c], subsumes);
1357
27920
    for (const Node& cc : subsumes)
1358
    {
1359
      if (!cstatus[cc])
1360
      {
1361
        // subsumes a child that was already eliminated
1362
        continue;
1363
      }
1364
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << c << " subsumes "
1365
                               << cc << std::endl;
1366
      // for each of the atoms in cc
1367
      std::map<Node, std::map<Node, bool> >::iterator itcc = atoms.find(cc);
1368
      Assert(itcc != atoms.end());
1369
      std::vector<Node> common_children;
1370
      std::vector<Node> diff_children;
1371
      for (const std::pair<const Node, bool>& ap : itcc->second)
1372
      {
1373
        // compare the polarity
1374
        Node a = ap.first;
1375
        bool polcc = ap.second;
1376
        Assert(itc->second.find(a) != itc->second.end());
1377
        bool polc = itc->second[a];
1378
        Trace("ext-rew-eqchain") << "    eqchain-simplify: atom " << a
1379
                                 << " has polarities : " << polc << " " << polcc
1380
                                 << "\n";
1381
        Node lit = polc ? a : TermUtil::mkNegate(notk, a);
1382
        if (polc != polcc)
1383
        {
1384
          diff_children.push_back(lit);
1385
        }
1386
        else
1387
        {
1388
          common_children.push_back(lit);
1389
        }
1390
      }
1391
      std::vector<Node> rem_children;
1392
      for (const std::pair<const Node, bool>& ap : itc->second)
1393
      {
1394
        Node a = ap.first;
1395
        if (atoms[cc].find(a) == atoms[cc].end())
1396
        {
1397
          bool polc = ap.second;
1398
          rem_children.push_back(polc ? a : TermUtil::mkNegate(notk, a));
1399
        }
1400
      }
1401
      Trace("ext-rew-eqchain")
1402
          << "    #common/diff/rem: " << common_children.size() << "/"
1403
          << diff_children.size() << "/" << rem_children.size() << "\n";
1404
      bool do_rewrite = false;
1405
      if (common_children.empty() && itc->second.size() == itcc->second.size()
1406
          && itcc->second.size() == 2)
1407
      {
1408
        // x | y = ~x | ~y ---> ~( x = y )
1409
        do_rewrite = true;
1410
        children.push_back(diff_children[0]);
1411
        children.push_back(diff_children[1]);
1412
        gpol = !gpol;
1413
        Trace("ext-rew-eqchain") << "    apply 2-child all-diff\n";
1414
      }
1415
      else if (common_children.empty() && diff_children.size() == 1)
1416
      {
1417
        do_rewrite = true;
1418
        // x = ( ~x | y ) ---> ~( ~x | ~y )
1419
        Node remn = rem_children.size() == 1 ? rem_children[0]
1420
                                             : nm->mkNode(ork, rem_children);
1421
        remn = TermUtil::mkNegate(notk, remn);
1422
        children.push_back(nm->mkNode(ork, diff_children[0], remn));
1423
        if (!isXor)
1424
        {
1425
          gpol = !gpol;
1426
        }
1427
        Trace("ext-rew-eqchain") << "    apply unit resolution\n";
1428
      }
1429
      else if (diff_children.size() == 1
1430
               && itc->second.size() == itcc->second.size())
1431
      {
1432
        // ( x | y | z ) = ( x | ~y | z ) ---> ( x | z )
1433
        do_rewrite = true;
1434
        Assert(!common_children.empty());
1435
        Node comn = common_children.size() == 1
1436
                        ? common_children[0]
1437
                        : nm->mkNode(ork, common_children);
1438
        children.push_back(comn);
1439
        if (isXor)
1440
        {
1441
          gpol = !gpol;
1442
        }
1443
        Trace("ext-rew-eqchain") << "    apply resolution\n";
1444
      }
1445
      else if (diff_children.empty())
1446
      {
1447
        do_rewrite = true;
1448
        if (rem_children.empty())
1449
        {
1450
          // x | y = x | y ---> true
1451
          // this can happen if we have ( ~x & ~y ) = ( x | y )
1452
          children.push_back(TermUtil::mkTypeMaxValue(tn));
1453
          if (isXor)
1454
          {
1455
            gpol = !gpol;
1456
          }
1457
          Trace("ext-rew-eqchain") << "    apply cancel\n";
1458
        }
1459
        else
1460
        {
1461
          // x | y = ( x | y | z ) ---> ( x | y | ~z )
1462
          Node remn = rem_children.size() == 1 ? rem_children[0]
1463
                                               : nm->mkNode(ork, rem_children);
1464
          remn = TermUtil::mkNegate(notk, remn);
1465
          Node comn = common_children.size() == 1
1466
                          ? common_children[0]
1467
                          : nm->mkNode(ork, common_children);
1468
          children.push_back(nm->mkNode(ork, comn, remn));
1469
          if (isXor)
1470
          {
1471
            gpol = !gpol;
1472
          }
1473
          Trace("ext-rew-eqchain") << "    apply subsume\n";
1474
        }
1475
      }
1476
      if (do_rewrite)
1477
      {
1478
        // eliminate the children, reverse polarity as needed
1479
        for (unsigned r = 0; r < 2; r++)
1480
        {
1481
          Node c_rem = r == 0 ? c : cc;
1482
          cstatus[c_rem] = false;
1483
          if (c_rem.getKind() == andk)
1484
          {
1485
            gpol = !gpol;
1486
          }
1487
        }
1488
        break;
1489
      }
1490
    }
1491
  }
1492
13941
  Trace("ext-rew-eqchain") << "eqchain-simplify: finish" << std::endl;
1493
1494
  // sorted right associative chain
1495
13941
  bool has_nvar = false;
1496
13941
  unsigned nvar_index = 0;
1497
41869
  for (std::pair<const Node, bool>& cp : cstatus)
1498
  {
1499
27928
    if (cp.second)
1500
    {
1501
27920
      if (!cp.first.isVar())
1502
      {
1503
19271
        has_nvar = true;
1504
19271
        nvar_index = children.size();
1505
      }
1506
27920
      children.push_back(cp.first);
1507
    }
1508
  }
1509
13941
  std::sort(children.begin(), children.end());
1510
1511
27882
  Node new_ret;
1512
13941
  if (!gpol)
1513
  {
1514
    // negate the constant child if it exists
1515
167
    unsigned nindex = has_nvar ? nvar_index : 0;
1516
167
    children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
1517
  }
1518
13941
  new_ret = children.back();
1519
13941
  unsigned index = children.size() - 1;
1520
41907
  while (index > 0)
1521
  {
1522
13983
    index--;
1523
13983
    new_ret = nm->mkNode(eqk, children[index], new_ret);
1524
  }
1525
13941
  new_ret = Rewriter::rewrite(new_ret);
1526
13941
  if (new_ret != ret)
1527
  {
1528
112
    return new_ret;
1529
  }
1530
13829
  return Node::null();
1531
}
1532
1533
100440
Node ExtendedRewriter::partialSubstitute(
1534
    Node n,
1535
    const std::map<Node, Node>& assign,
1536
    const std::map<Kind, bool>& rkinds) const
1537
{
1538
200880
  std::unordered_map<TNode, Node> visited;
1539
100440
  std::unordered_map<TNode, Node>::iterator it;
1540
100440
  std::map<Node, Node>::const_iterator ita;
1541
200880
  std::vector<TNode> visit;
1542
200880
  TNode cur;
1543
100440
  visit.push_back(n);
1544
623059
  do
1545
  {
1546
723499
    cur = visit.back();
1547
723499
    visit.pop_back();
1548
723499
    it = visited.find(cur);
1549
1550
723499
    if (it == visited.end())
1551
    {
1552
341532
      ita = assign.find(cur);
1553
341532
      if (ita != assign.end())
1554
      {
1555
3092
        visited[cur] = ita->second;
1556
      }
1557
      else
1558
      {
1559
        // If rkinds is non-empty, then can only recurse on its kinds.
1560
        // We also always disallow substitution into witness. Notice that
1561
        // we disallow witness here, due to unsoundness when applying contextual
1562
        // substitutions over witness terms (see #4620).
1563
338440
        Kind k = cur.getKind();
1564
338440
        if (k != WITNESS && (rkinds.empty() || rkinds.find(k) != rkinds.end()))
1565
        {
1566
338440
          visited[cur] = Node::null();
1567
338440
          visit.push_back(cur);
1568
623059
          for (const Node& cn : cur)
1569
          {
1570
284619
            visit.push_back(cn);
1571
          }
1572
        }
1573
        else
1574
        {
1575
          visited[cur] = cur;
1576
        }
1577
      }
1578
    }
1579
381967
    else if (it->second.isNull())
1580
    {
1581
676880
      Node ret = cur;
1582
338440
      bool childChanged = false;
1583
676880
      std::vector<Node> children;
1584
338440
      if (cur.getMetaKind() == metakind::PARAMETERIZED)
1585
      {
1586
7829
        children.push_back(cur.getOperator());
1587
      }
1588
623059
      for (const Node& cn : cur)
1589
      {
1590
284619
        it = visited.find(cn);
1591
284619
        Assert(it != visited.end());
1592
284619
        Assert(!it->second.isNull());
1593
284619
        childChanged = childChanged || cn != it->second;
1594
284619
        children.push_back(it->second);
1595
      }
1596
338440
      if (childChanged)
1597
      {
1598
4389
        ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
1599
      }
1600
338440
      visited[cur] = ret;
1601
    }
1602
723499
  } while (!visit.empty());
1603
100440
  Assert(visited.find(n) != visited.end());
1604
100440
  Assert(!visited.find(n)->second.isNull());
1605
200880
  return visited[n];
1606
}
1607
1608
25622
Node ExtendedRewriter::partialSubstitute(
1609
    Node n,
1610
    const std::vector<Node>& vars,
1611
    const std::vector<Node>& subs,
1612
    const std::map<Kind, bool>& rkinds) const
1613
{
1614
25622
  Assert(vars.size() == subs.size());
1615
51244
  std::map<Node, Node> assign;
1616
51521
  for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
1617
  {
1618
25899
    assign[vars[i]] = subs[i];
1619
  }
1620
51244
  return partialSubstitute(n, assign, rkinds);
1621
}
1622
1623
8866
Node ExtendedRewriter::solveEquality(Node n) const
1624
{
1625
  // TODO (#1706) : implement
1626
8866
  Assert(n.getKind() == EQUAL);
1627
1628
8866
  return Node::null();
1629
}
1630
1631
22375
bool ExtendedRewriter::inferSubstitution(Node n,
1632
                                         std::vector<Node>& vars,
1633
                                         std::vector<Node>& subs,
1634
                                         bool usePred) const
1635
{
1636
22375
  if (n.getKind() == AND)
1637
  {
1638
237
    bool ret = false;
1639
751
    for (const Node& nc : n)
1640
    {
1641
514
      bool cret = inferSubstitution(nc, vars, subs, usePred);
1642
514
      ret = ret || cret;
1643
    }
1644
237
    return ret;
1645
  }
1646
22138
  if (n.getKind() == EQUAL)
1647
  {
1648
    // see if it can be put into form x = y
1649
9347
    Node slv_eq = solveEquality(n);
1650
8866
    if (!slv_eq.isNull())
1651
    {
1652
      n = slv_eq;
1653
    }
1654
26598
    Node v[2];
1655
20566
    for (unsigned i = 0; i < 2; i++)
1656
    {
1657
17677
      if (n[i].isConst())
1658
      {
1659
5977
        vars.push_back(n[1 - i]);
1660
5977
        subs.push_back(n[i]);
1661
5977
        return true;
1662
      }
1663
11700
      if (n[i].isVar())
1664
      {
1665
8985
        v[i] = n[i];
1666
      }
1667
2715
      else if (TermUtil::isNegate(n[i].getKind()) && n[i][0].isVar())
1668
      {
1669
16
        v[i] = n[i][0];
1670
      }
1671
    }
1672
3893
    for (unsigned i = 0; i < 2; i++)
1673
    {
1674
4416
      TNode r1 = v[i];
1675
4416
      Node r2 = v[1 - i];
1676
3412
      if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
1677
      {
1678
2408
        r2 = n[1 - i];
1679
2408
        if (v[i] != n[i])
1680
        {
1681
10
          Assert(TermUtil::isNegate(n[i].getKind()));
1682
10
          r2 = TermUtil::mkNegate(n[i].getKind(), r2);
1683
        }
1684
        // TODO (#1706) : union find
1685
2408
        if (std::find(vars.begin(), vars.end(), r1) == vars.end())
1686
        {
1687
2408
          vars.push_back(r1);
1688
2408
          subs.push_back(r2);
1689
2408
          return true;
1690
        }
1691
      }
1692
    }
1693
  }
1694
13753
  if (usePred)
1695
  {
1696
13395
    bool negated = n.getKind() == NOT;
1697
13395
    vars.push_back(negated ? n[0] : n);
1698
13395
    subs.push_back(negated ? d_false : d_true);
1699
13395
    return true;
1700
  }
1701
358
  return false;
1702
}
1703
1704
1315
Node ExtendedRewriter::extendedRewriteStrings(Node ret) const
1705
{
1706
1315
  Node new_ret;
1707
2630
  Trace("q-ext-rewrite-debug")
1708
1315
      << "Extended rewrite strings : " << ret << std::endl;
1709
1710
1315
  if (ret.getKind() == EQUAL)
1711
  {
1712
203
    new_ret = strings::SequencesRewriter(nullptr).rewriteEqualityExt(ret);
1713
  }
1714
1715
1315
  return new_ret;
1716
}
1717
1718
51908
void ExtendedRewriter::debugExtendedRewrite(Node n,
1719
                                            Node ret,
1720
                                            const char* c) const
1721
{
1722
51908
  if (Trace.isOn("q-ext-rewrite"))
1723
  {
1724
    if (!ret.isNull())
1725
    {
1726
      Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
1727
      Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
1728
                             << " rewrites to " << ret << std::endl;
1729
    }
1730
  }
1731
51908
}
1732
1733
}  // namespace quantifiers
1734
}  // namespace theory
1735
29340
}  // namespace cvc5