GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/extended_rewrite.cpp Lines: 782 842 92.9 %
Date: 2021-09-15 Branches: 1887 3794 49.7 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of extended rewriting techniques.
14
 */
15
16
#include "theory/quantifiers/extended_rewrite.h"
17
18
#include <sstream>
19
20
#include "theory/arith/arith_msum.h"
21
#include "theory/bv/theory_bv_utils.h"
22
#include "theory/datatypes/datatypes_rewriter.h"
23
#include "theory/quantifiers/term_util.h"
24
#include "theory/rewriter.h"
25
#include "theory/strings/sequences_rewriter.h"
26
#include "theory/theory.h"
27
28
using namespace cvc5::kind;
29
using namespace std;
30
31
namespace cvc5 {
32
namespace theory {
33
namespace quantifiers {
34
35
struct ExtRewriteAttributeId
36
{
37
};
38
typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
39
40
struct ExtRewriteAggAttributeId
41
{
42
};
43
typedef expr::Attribute<ExtRewriteAggAttributeId, Node> ExtRewriteAggAttribute;
44
45
244142
ExtendedRewriter::ExtendedRewriter(Rewriter& rew, bool aggr)
46
244142
    : d_rew(rew), d_aggr(aggr)
47
{
48
244142
  d_true = NodeManager::currentNM()->mkConst(true);
49
244142
  d_false = NodeManager::currentNM()->mkConst(false);
50
244142
}
51
52
407428
void ExtendedRewriter::setCache(Node n, Node ret) const
53
{
54
407428
  if (d_aggr)
55
  {
56
    ExtRewriteAggAttribute erga;
57
406780
    n.setAttribute(erga, ret);
58
  }
59
  else
60
  {
61
    ExtRewriteAttribute era;
62
648
    n.setAttribute(era, ret);
63
  }
64
407428
}
65
66
821881
Node ExtendedRewriter::getCache(Node n) const
67
{
68
821881
  if (d_aggr)
69
  {
70
821369
    if (n.hasAttribute(ExtRewriteAggAttribute()))
71
    {
72
615864
      return n.getAttribute(ExtRewriteAggAttribute());
73
    }
74
  }
75
  else
76
  {
77
512
    if (n.hasAttribute(ExtRewriteAttribute()))
78
    {
79
188
      return n.getAttribute(ExtRewriteAttribute());
80
    }
81
  }
82
205829
  return Node::null();
83
}
84
85
555819
bool ExtendedRewriter::addToChildren(Node nc,
86
                                     std::vector<Node>& children,
87
                                     bool dropDup) const
88
{
89
  // If the operator is non-additive, do not consider duplicates
90
555819
  if (dropDup
91
555819
      && std::find(children.begin(), children.end(), nc) != children.end())
92
  {
93
1077
    return false;
94
  }
95
554742
  children.push_back(nc);
96
554742
  return true;
97
}
98
99
821881
Node ExtendedRewriter::extendedRewrite(Node n) const
100
{
101
821881
  n = d_rew.rewrite(n);
102
103
  // has it already been computed?
104
1643762
  Node ncache = getCache(n);
105
821881
  if (!ncache.isNull())
106
  {
107
616052
    return ncache;
108
  }
109
110
411658
  Node ret = n;
111
205829
  NodeManager* nm = NodeManager::currentNM();
112
113
  //--------------------pre-rewrite
114
205829
  if (d_aggr)
115
  {
116
406780
    Node pre_new_ret;
117
205505
    if (ret.getKind() == IMPLIES)
118
    {
119
160
      pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
120
160
      debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
121
    }
122
205345
    else if (ret.getKind() == XOR)
123
    {
124
232
      pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
125
232
      debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
126
    }
127
205113
    else if (ret.getKind() == NOT)
128
    {
129
17377
      pre_new_ret = extendedRewriteNnf(ret);
130
17377
      debugExtendedRewrite(ret, pre_new_ret, "NNF");
131
    }
132
205505
    if (!pre_new_ret.isNull())
133
    {
134
4230
      ret = extendedRewrite(pre_new_ret);
135
136
8460
      Trace("q-ext-rewrite-debug")
137
4230
          << "...ext-pre-rewrite : " << n << " -> " << pre_new_ret << std::endl;
138
4230
      setCache(n, ret);
139
4230
      return ret;
140
    }
141
  }
142
  //--------------------end pre-rewrite
143
144
  //--------------------rewrite children
145
201599
  if (n.getNumChildren() > 0)
146
  {
147
366146
    std::vector<Node> children;
148
183073
    if (n.getMetaKind() == metakind::PARAMETERIZED)
149
    {
150
3841
      children.push_back(n.getOperator());
151
    }
152
183073
    Kind k = n.getKind();
153
183073
    bool childChanged = false;
154
183073
    bool isNonAdditive = TermUtil::isNonAdditive(k);
155
    // We flatten associative operators below, which requires k to be n-ary.
156
183073
    bool isAssoc = TermUtil::isAssoc(k, true);
157
730936
    for (unsigned i = 0; i < n.getNumChildren(); i++)
158
    {
159
1095726
      Node nc = extendedRewrite(n[i]);
160
547863
      childChanged = nc != n[i] || childChanged;
161
547863
      if (isAssoc && nc.getKind() == n.getKind())
162
      {
163
13856
        for (const Node& ncc : nc)
164
        {
165
10906
          if (!addToChildren(ncc, children, isNonAdditive))
166
          {
167
10
            childChanged = true;
168
          }
169
        }
170
      }
171
544913
      else if (!addToChildren(nc, children, isNonAdditive))
172
      {
173
1067
        childChanged = true;
174
      }
175
    }
176
183073
    Assert(!children.empty());
177
    // Some commutative operators have rewriters that are agnostic to order,
178
    // thus, we sort here.
179
183073
    if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
180
    {
181
118205
      childChanged = true;
182
118205
      std::sort(children.begin(), children.end());
183
    }
184
183073
    if (childChanged)
185
    {
186
130374
      if (isNonAdditive && children.size() == 1)
187
      {
188
        // we may have subsumed children down to one
189
35
        ret = children[0];
190
      }
191
130339
      else if (isAssoc
192
130339
               && children.size() > kind::metakind::getMaxArityForKind(k))
193
      {
194
2
        Assert(kind::metakind::getMaxArityForKind(k) >= 2);
195
        // kind may require binary construction
196
2
        ret = children[0];
197
6
        for (unsigned i = 1, nchild = children.size(); i < nchild; i++)
198
        {
199
4
          ret = nm->mkNode(k, ret, children[i]);
200
        }
201
      }
202
      else
203
      {
204
130337
        ret = nm->mkNode(k, children);
205
      }
206
    }
207
  }
208
201599
  ret = d_rew.rewrite(ret);
209
  //--------------------end rewrite children
210
211
  // now, do extended rewrite
212
403198
  Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
213
201599
                               << " (from " << n << ")" << std::endl;
214
403198
  Node new_ret;
215
216
  //---------------------- theory-independent post-rewriting
217
201599
  if (ret.getKind() == ITE)
218
  {
219
15953
    new_ret = extendedRewriteIte(ITE, ret);
220
  }
221
185646
  else if (ret.getKind() == AND || ret.getKind() == OR)
222
  {
223
53597
    new_ret = extendedRewriteAndOr(ret);
224
  }
225
132049
  else if (ret.getKind() == EQUAL)
226
  {
227
28215
    new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
228
28215
    debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
229
  }
230
201599
  Assert(new_ret.isNull() || new_ret != ret);
231
201599
  if (new_ret.isNull() && ret.getKind() != ITE)
232
  {
233
    // simple ITE pulling
234
175570
    new_ret = extendedRewritePullIte(ITE, ret);
235
  }
236
  //----------------------end theory-independent post-rewriting
237
238
  //----------------------theory-specific post-rewriting
239
201599
  if (new_ret.isNull())
240
  {
241
    TheoryId tid;
242
177346
    if (ret.getKind() == ITE)
243
    {
244
14535
      tid = Theory::theoryOf(ret.getType());
245
    }
246
    else
247
    {
248
162811
      tid = Theory::theoryOf(ret);
249
    }
250
354692
    Trace("q-ext-rewrite-debug") << "theoryOf( " << ret << " )= " << tid
251
177346
                                 << std::endl;
252
177346
    if (tid == THEORY_STRINGS)
253
    {
254
4474
      new_ret = extendedRewriteStrings(ret);
255
    }
256
  }
257
  //----------------------end theory-specific post-rewriting
258
259
  //----------------------aggressive rewrites
260
201599
  if (new_ret.isNull() && d_aggr)
261
  {
262
176718
    new_ret = extendedRewriteAggr(ret);
263
  }
264
  //----------------------end aggressive rewrites
265
266
201599
  setCache(n, ret);
267
201599
  if (!new_ret.isNull())
268
  {
269
25646
    ret = extendedRewrite(new_ret);
270
  }
271
403198
  Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
272
201599
                               << std::endl;
273
201599
  if (Trace.isOn("q-ext-rewrite-nf"))
274
  {
275
    if (n == ret)
276
    {
277
      Trace("q-ext-rewrite-nf") << "ext-rew normal form : " << n << std::endl;
278
    }
279
  }
280
201599
  setCache(n, ret);
281
201599
  return ret;
282
}
283
284
176718
Node ExtendedRewriter::extendedRewriteAggr(Node n) const
285
{
286
176718
  Node new_ret;
287
353436
  Trace("q-ext-rewrite-debug2")
288
176718
      << "Do aggressive rewrites on " << n << std::endl;
289
176718
  bool polarity = n.getKind() != NOT;
290
353436
  Node ret_atom = n.getKind() == NOT ? n[0] : n;
291
375685
  if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
292
517781
      || ret_atom.getKind() == GEQ)
293
  {
294
    // ITE term removal in polynomials
295
    // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 )
296
65396
    Trace("q-ext-rewrite-debug2")
297
32698
        << "Compute monomial sum " << ret_atom << std::endl;
298
    // compute monomial sum
299
65396
    std::map<Node, Node> msum;
300
32698
    if (ArithMSum::getMonomialSumLit(ret_atom, msum))
301
    {
302
109272
      for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end();
303
           ++itm)
304
      {
305
154231
        Node v = itm->first;
306
155314
        Trace("q-ext-rewrite-debug2")
307
77657
            << itm->first << " * " << itm->second << std::endl;
308
77657
        if (v.getKind() == ITE)
309
        {
310
14281
          Node veq;
311
7682
          int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind());
312
7682
          if (res != 0)
313
          {
314
15358
            Trace("q-ext-rewrite-debug")
315
7679
                << "  have ITE relation, solved form : " << veq << std::endl;
316
            // try pulling ITE
317
7679
            new_ret = extendedRewritePullIte(ITE, veq);
318
7679
            if (!new_ret.isNull())
319
            {
320
1083
              if (!polarity)
321
              {
322
                new_ret = new_ret.negate();
323
              }
324
1083
              break;
325
            }
326
          }
327
          else
328
          {
329
6
            Trace("q-ext-rewrite-debug")
330
3
                << "  failed to isolate " << v << " in " << n << std::endl;
331
          }
332
        }
333
      }
334
    }
335
    else
336
    {
337
      Trace("q-ext-rewrite-debug")
338
          << "  failed to get monomial sum of " << n << std::endl;
339
    }
340
  }
341
  // TODO (#1706) : conditional rewriting, condition merging
342
353436
  return new_ret;
343
}
344
345
32314
Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const
346
{
347
32314
  Assert(n.getKind() == itek);
348
32314
  Assert(n[1] != n[2]);
349
350
32314
  NodeManager* nm = NodeManager::currentNM();
351
352
32314
  Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;
353
354
64628
  Node flip_cond;
355
32314
  if (n[0].getKind() == NOT)
356
  {
357
    flip_cond = n[0][0];
358
  }
359
32314
  else if (n[0].getKind() == OR)
360
  {
361
    // a | b ---> ~( ~a & ~b )
362
132
    flip_cond = TermUtil::simpleNegate(n[0]);
363
  }
364
32314
  if (!flip_cond.isNull())
365
  {
366
264
    Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
367
    // only print debug trace if full=true
368
132
    if (full)
369
    {
370
132
      debugExtendedRewrite(n, new_ret, "ITE flip");
371
    }
372
132
    return new_ret;
373
  }
374
  // Boolean true/false return
375
64364
  TypeNode tn = n.getType();
376
32182
  if (tn.isBoolean())
377
  {
378
55227
    for (unsigned i = 1; i <= 2; i++)
379
    {
380
36818
      if (n[i].isConst())
381
      {
382
        Node cond = i == 1 ? n[0] : n[0].negate();
383
        Node other = n[i == 1 ? 2 : 1];
384
        Kind retk = AND;
385
        if (n[i].getConst<bool>())
386
        {
387
          retk = OR;
388
        }
389
        else
390
        {
391
          cond = cond.negate();
392
        }
393
        Node new_ret = nm->mkNode(retk, cond, other);
394
        if (full)
395
        {
396
          // ite( A, true, B ) ---> A V B
397
          // ite( A, false, B ) ---> ~A /\ B
398
          // ite( A, B,  true ) ---> ~A V B
399
          // ite( A, B, false ) ---> A /\ B
400
          debugExtendedRewrite(n, new_ret, "ITE const return");
401
        }
402
        return new_ret;
403
      }
404
    }
405
  }
406
407
  // get entailed equalities in the condition
408
64364
  std::vector<Node> eq_conds;
409
32182
  Kind ck = n[0].getKind();
410
32182
  if (ck == EQUAL)
411
  {
412
11684
    eq_conds.push_back(n[0]);
413
  }
414
20498
  else if (ck == AND)
415
  {
416
26080
    for (const Node& cn : n[0])
417
    {
418
21549
      if (cn.getKind() == EQUAL)
419
      {
420
856
        eq_conds.push_back(cn);
421
      }
422
    }
423
  }
424
425
64364
  Node new_ret;
426
64364
  Node b;
427
64364
  Node e;
428
64364
  Node t1 = n[1];
429
64364
  Node t2 = n[2];
430
64364
  std::stringstream ss_reason;
431
432
44649
  for (const Node& eq : eq_conds)
433
  {
434
    // simple invariant ITE
435
37506
    for (unsigned i = 0; i <= 1; i++)
436
    {
437
      // ite( x = y ^ C, y, x ) ---> x
438
      // this is subsumed by the rewrites below
439
25039
      if (t2 == eq[i] && t1 == eq[1 - i])
440
      {
441
71
        new_ret = t2;
442
71
        ss_reason << "ITE simple rev subs";
443
71
        break;
444
      }
445
    }
446
12538
    if (!new_ret.isNull())
447
    {
448
71
      break;
449
    }
450
  }
451
32182
  if (new_ret.isNull())
452
  {
453
    // merging branches
454
95916
    for (unsigned i = 1; i <= 2; i++)
455
    {
456
64094
      if (n[i].getKind() == ITE)
457
      {
458
7110
        Node no = n[3 - i];
459
10245
        for (unsigned j = 1; j <= 2; j++)
460
        {
461
6979
          if (n[i][j] == no)
462
          {
463
            // e.g.
464
            // ite( C1, ite( C2, t1, t2 ), t1 ) ----> ite( C1 ^ ~C2, t2, t1 )
465
578
            Node nc1 = i == 2 ? n[0].negate() : n[0];
466
578
            Node nc2 = j == 1 ? n[i][0].negate() : n[i][0];
467
578
            Node new_cond = nm->mkNode(AND, nc1, nc2);
468
289
            new_ret = nm->mkNode(ITE, new_cond, n[i][3 - j], no);
469
289
            ss_reason << "ITE merge branch";
470
289
            break;
471
          }
472
        }
473
      }
474
64094
      if (!new_ret.isNull())
475
      {
476
289
        break;
477
      }
478
    }
479
  }
480
481
32182
  if (new_ret.isNull() && d_aggr)
482
  {
483
    // If x is less than t based on an ordering, then we use { x -> t } as a
484
    // substitution to the children of ite( x = t ^ C, s, t ) below.
485
63636
    std::vector<Node> vars;
486
63636
    std::vector<Node> subs;
487
31818
    inferSubstitution(n[0], vars, subs, true);
488
489
31818
    if (!vars.empty())
490
    {
491
      // reverse substitution to opposite child
492
      // r{ x -> t } = s  implies  ite( x=t ^ C, s, r ) ---> r
493
      // We can use ordinary substitute since the result of the substitution
494
      // is not being returned. In other words, nn is only being used to query
495
      // whether the second branch is a generalization of the first.
496
      Node nn =
497
63636
          t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
498
31818
      if (nn != t2)
499
      {
500
8565
        nn = d_rew.rewrite(nn);
501
8565
        if (nn == t1)
502
        {
503
5
          new_ret = t2;
504
5
          ss_reason << "ITE rev subs";
505
        }
506
      }
507
508
      // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
509
      // must use partial substitute here, to avoid substitution into witness
510
63636
      std::map<Kind, bool> rkinds;
511
31818
      nn = partialSubstitute(t1, vars, subs, rkinds);
512
31818
      nn = d_rew.rewrite(nn);
513
31818
      if (nn != t1)
514
      {
515
        // If full=false, then we've duplicated a term u in the children of n.
516
        // For example, when ITE pulling, we have n is of the form:
517
        //   ite( C, f( u, t1 ), f( u, t2 ) )
518
        // We must show that at least one copy of u dissappears in this case.
519
5618
        if (nn == t2)
520
        {
521
31
          new_ret = nn;
522
31
          ss_reason << "ITE subs invariant";
523
        }
524
5587
        else if (full || nn.isConst())
525
        {
526
866
          new_ret = nm->mkNode(itek, n[0], nn, t2);
527
866
          ss_reason << "ITE subs";
528
        }
529
      }
530
    }
531
31818
    if (new_ret.isNull())
532
    {
533
      // ite( C, t, s ) ----> ite( C, t, s { C -> false } )
534
      // use partial substitute to avoid substitution into witness
535
61832
      std::map<Node, Node> assign;
536
30916
      assign[n[0]] = d_false;
537
61832
      std::map<Kind, bool> rkinds;
538
61832
      Node nn = partialSubstitute(t2, assign, rkinds);
539
30916
      if (nn != t2)
540
      {
541
4049
        nn = d_rew.rewrite(nn);
542
4049
        if (nn == t1)
543
        {
544
3
          new_ret = nn;
545
3
          ss_reason << "ITE subs invariant false";
546
        }
547
4046
        else if (full || nn.isConst())
548
        {
549
202
          new_ret = nm->mkNode(itek, n[0], t1, nn);
550
202
          ss_reason << "ITE subs false";
551
        }
552
      }
553
    }
554
  }
555
556
  // only print debug trace if full=true
557
32182
  if (!new_ret.isNull() && full)
558
  {
559
1286
    debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
560
  }
561
562
32182
  return new_ret;
563
}
564
565
53597
Node ExtendedRewriter::extendedRewriteAndOr(Node n) const
566
{
567
  // all the below rewrites are aggressive
568
53597
  if (!d_aggr)
569
  {
570
18
    return Node::null();
571
  }
572
107158
  Node new_ret;
573
  // we allow substitutions to recurse over any kind, except WITNESS which is
574
  // managed by partialSubstitute.
575
107158
  std::map<Kind, bool> bcp_kinds;
576
53579
  new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, n);
577
53579
  if (!new_ret.isNull())
578
  {
579
2287
    debugExtendedRewrite(n, new_ret, "Bool bcp");
580
2287
    return new_ret;
581
  }
582
  // factoring
583
51292
  new_ret = extendedRewriteFactoring(AND, OR, NOT, n);
584
51292
  if (!new_ret.isNull())
585
  {
586
3860
    debugExtendedRewrite(n, new_ret, "Bool factoring");
587
3860
    return new_ret;
588
  }
589
590
  // equality resolution
591
47432
  new_ret = extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, n, false);
592
47432
  debugExtendedRewrite(n, new_ret, "Bool eq res");
593
47432
  return new_ret;
594
}
595
596
183249
Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) const
597
{
598
183249
  Assert(n.getKind() != ITE);
599
183249
  if (n.isClosure())
600
  {
601
    // don't pull ITE out of quantifiers
602
251
    return n;
603
  }
604
182998
  NodeManager* nm = NodeManager::currentNM();
605
365996
  TypeNode tn = n.getType();
606
365996
  std::vector<Node> children;
607
182998
  bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
608
182998
  if (hasOp)
609
  {
610
3875
    children.push_back(n.getOperator());
611
  }
612
182998
  unsigned nchildren = n.getNumChildren();
613
618203
  for (unsigned i = 0; i < nchildren; i++)
614
  {
615
435205
    children.push_back(n[i]);
616
  }
617
365996
  std::map<unsigned, std::map<unsigned, Node> > ite_c;
618
599501
  for (unsigned i = 0; i < nchildren; i++)
619
  {
620
    // only pull ITEs apart if we are aggressive
621
1289229
    if (n[i].getKind() == itek
622
1289229
        && (d_aggr || (n[i][1].getKind() != ITE && n[i][2].getKind() != ITE)))
623
    {
624
29801
      unsigned ii = hasOp ? i + 1 : i;
625
89403
      for (unsigned j = 0; j < 2; j++)
626
      {
627
59602
        children[ii] = n[i][j + 1];
628
119204
        Node pull = nm->mkNode(n.getKind(), children);
629
119204
        Node pullr = d_rew.rewrite(pull);
630
59602
        children[ii] = n[i];
631
59602
        ite_c[i][j] = pullr;
632
      }
633
29801
      if (ite_c[i][0] == ite_c[i][1])
634
      {
635
        // ITE dual invariance
636
        // f( t1..s1..tn ) ---> t  and  f( t1..s2..tn ) ---> t implies
637
        // f( t1..ite( A, s1, s2 )..tn ) ---> t
638
741
        debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
639
741
        return ite_c[i][0];
640
      }
641
29060
      if (d_aggr)
642
      {
643
52777
        if (nchildren == 2 && (n[1 - i].isVar() || n[1 - i].isConst())
644
43739
            && !n[1 - i].getType().isBoolean() && tn.isBoolean())
645
        {
646
          // always pull variable or constant with binary (theory) predicate
647
          // e.g. P( x, ite( A, t1, t2 ) ) ---> ite( A, P( x, t1 ), P( x, t2 ) )
648
20222
          Node new_ret = nm->mkNode(ITE, n[i][0], ite_c[i][0], ite_c[i][1]);
649
10111
          debugExtendedRewrite(n, new_ret, "ITE pull var predicate");
650
10111
          return new_ret;
651
        }
652
53123
        for (unsigned j = 0; j < 2; j++)
653
        {
654
70740
          Node pullr = ite_c[i][j];
655
36564
          if (pullr.isConst() || pullr == n[i][j + 1])
656
          {
657
            // ITE single child elimination
658
            // f( t1..s1..tn ) ---> t  where t is a constant or s1 itself
659
            // implies
660
            // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
661
4776
            Node new_ret;
662
2388
            if (tn.isBoolean() && pullr.isConst())
663
            {
664
              // remove false/true child immediately
665
832
              bool pol = pullr.getConst<bool>();
666
1664
              std::vector<Node> new_children;
667
832
              new_children.push_back((j == 0) == pol ? n[i][0]
668
                                                     : n[i][0].negate());
669
832
              new_children.push_back(ite_c[i][1 - j]);
670
832
              new_ret = nm->mkNode(pol ? OR : AND, new_children);
671
832
              debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
672
            }
673
            else
674
            {
675
1556
              new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
676
1556
              debugExtendedRewrite(n, new_ret, "ITE single elim");
677
            }
678
2388
            return new_ret;
679
          }
680
        }
681
      }
682
    }
683
  }
684
169758
  if (d_aggr)
685
  {
686
185622
    for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
687
    {
688
32711
      Node nite = n[ip.first];
689
16531
      Assert(nite.getKind() == itek);
690
      // now, simply pull the ITE and try ITE rewrites
691
32711
      Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
692
16531
      pull_ite = d_rew.rewrite(pull_ite);
693
16531
      if (pull_ite.getKind() == ITE)
694
      {
695
32541
        Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
696
16361
        if (!new_pull_ite.isNull())
697
        {
698
181
          debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
699
181
          return new_pull_ite;
700
        }
701
      }
702
      else
703
      {
704
        // A general rewrite could eliminate the ITE by pulling.
705
        // An example is:
706
        //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
707
        //   ite( C, ~~x, ite( C, y, x ) ) --->
708
        //   x
709
        // where ~ is bitvector negation.
710
170
        debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
711
170
        return pull_ite;
712
      }
713
    }
714
  }
715
716
169407
  return Node::null();
717
}
718
719
17377
Node ExtendedRewriter::extendedRewriteNnf(Node ret) const
720
{
721
17377
  Assert(ret.getKind() == NOT);
722
723
17377
  Kind nk = ret[0].getKind();
724
17377
  bool neg_ch = false;
725
17377
  bool neg_ch_1 = false;
726
17377
  if (nk == AND || nk == OR)
727
  {
728
3456
    neg_ch = true;
729
3456
    nk = nk == AND ? OR : AND;
730
  }
731
13921
  else if (nk == IMPLIES)
732
  {
733
30
    neg_ch = true;
734
30
    neg_ch_1 = true;
735
30
    nk = AND;
736
  }
737
13891
  else if (nk == ITE)
738
  {
739
112
    neg_ch = true;
740
112
    neg_ch_1 = true;
741
  }
742
13779
  else if (nk == XOR)
743
  {
744
37
    nk = EQUAL;
745
  }
746
13742
  else if (nk == EQUAL && ret[0][0].getType().isBoolean())
747
  {
748
203
    neg_ch_1 = true;
749
  }
750
  else
751
  {
752
13539
    return Node::null();
753
  }
754
755
7676
  std::vector<Node> new_children;
756
16187
  for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++)
757
  {
758
24698
    Node c = ret[0][i];
759
12349
    c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c;
760
12349
    new_children.push_back(c);
761
  }
762
3838
  return NodeManager::currentNM()->mkNode(nk, new_children);
763
}
764
765
53579
Node ExtendedRewriter::extendedRewriteBcp(Kind andk,
766
                                          Kind ork,
767
                                          Kind notk,
768
                                          std::map<Kind, bool>& bcp_kinds,
769
                                          Node ret) const
770
{
771
53579
  Kind k = ret.getKind();
772
53579
  Assert(k == andk || k == ork);
773
53579
  Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;
774
775
53579
  NodeManager* nm = NodeManager::currentNM();
776
777
107158
  TypeNode tn = ret.getType();
778
107158
  Node truen = TermUtil::mkTypeMaxValue(tn);
779
107158
  Node falsen = TermUtil::mkTypeValue(tn, 0);
780
781
  // terms to process
782
107158
  std::vector<Node> to_process;
783
337233
  for (const Node& cn : ret)
784
  {
785
283654
    to_process.push_back(cn);
786
  }
787
  // the processing terms
788
107158
  std::vector<Node> clauses;
789
  // the terms we have propagated information to
790
107158
  std::unordered_set<Node> prop_clauses;
791
  // the assignment
792
107158
  std::map<Node, Node> assign;
793
107158
  std::vector<Node> avars;
794
107158
  std::vector<Node> asubs;
795
796
53579
  Kind ok = k == andk ? ork : andk;
797
  // global polarity : when k=ork, everything is negated
798
53579
  bool gpol = k == andk;
799
800
2026
  do
801
  {
802
    // process the current nodes
803
166553
    while (!to_process.empty())
804
    {
805
111363
      std::vector<Node> new_to_process;
806
342127
      for (const Node& cn : to_process)
807
      {
808
286653
        Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
809
286653
        Kind cnk = cn.getKind();
810
286653
        bool pol = cnk != notk;
811
572891
        Node cln = cnk == notk ? cn[0] : cn;
812
286653
        Assert(cln.getKind() != notk);
813
286653
        if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
814
        {
815
          // flatten
816
930
          for (const Node& ccln : cln)
817
          {
818
1292
            Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
819
646
            new_to_process.push_back(lccln);
820
          }
821
        }
822
        else
823
        {
824
          // add it to the assignment
825
572323
          Node val = gpol == pol ? truen : falsen;
826
286369
          std::map<Node, Node>::iterator it = assign.find(cln);
827
572738
          Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
828
286369
                               << std::endl;
829
286369
          if (it != assign.end())
830
          {
831
462
            if (val != it->second)
832
            {
833
415
              Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
834
              // a conflicting assignment: we are done
835
415
              return gpol ? falsen : truen;
836
            }
837
          }
838
          else
839
          {
840
285907
            assign[cln] = val;
841
285907
            avars.push_back(cln);
842
285907
            asubs.push_back(val);
843
          }
844
845
          // also, treat it as clause if possible
846
571908
          if (cln.getNumChildren() > 0
847
521341
              && (bcp_kinds.empty()
848
285954
                  || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
849
          {
850
470774
            if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
851
235387
                && prop_clauses.find(cn) == prop_clauses.end())
852
            {
853
235381
              Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
854
235381
              clauses.push_back(cn);
855
            }
856
          }
857
        }
858
      }
859
55474
      to_process.clear();
860
55474
      to_process.insert(
861
110948
          to_process.end(), new_to_process.begin(), new_to_process.end());
862
    }
863
864
    // apply substitution to all subterms of clauses
865
110380
    std::vector<Node> new_clauses;
866
300233
    for (const Node& c : clauses)
867
    {
868
245043
      bool cpol = c.getKind() != notk;
869
490086
      Node ca = c.getKind() == notk ? c[0] : c;
870
245043
      bool childChanged = false;
871
490086
      std::vector<Node> ccs_children;
872
865714
      for (const Node& cc : ca)
873
      {
874
        // always use partial substitute, to avoid substitution in witness
875
620671
        Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
876
        // substitution is only applicable to compatible kinds in bcp_kinds
877
1241342
        Node ccs = partialSubstitute(cc, assign, bcp_kinds);
878
620671
        childChanged = childChanged || ccs != cc;
879
620671
        ccs_children.push_back(ccs);
880
      }
881
245043
      if (childChanged)
882
      {
883
2759
        if (ca.getMetaKind() == metakind::PARAMETERIZED)
884
        {
885
4
          ccs_children.insert(ccs_children.begin(), ca.getOperator());
886
        }
887
5518
        Node ccs = nm->mkNode(ca.getKind(), ccs_children);
888
2759
        ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
889
5518
        Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
890
2759
                             << std::endl;
891
2759
        ccs = d_rew.rewrite(ccs);
892
2759
        Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
893
2759
        to_process.push_back(ccs);
894
        // store this as a node that propagation touched. This marks c so that
895
        // it will not be included in the final construction.
896
2759
        prop_clauses.insert(ca);
897
      }
898
      else
899
      {
900
242284
        new_clauses.push_back(c);
901
      }
902
    }
903
55190
    clauses.clear();
904
55190
    clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
905
55190
  } while (!to_process.empty());
906
907
  // remake the node
908
53164
  if (!prop_clauses.empty())
909
  {
910
3744
    std::vector<Node> children;
911
18725
    for (std::pair<const Node, Node>& l : assign)
912
    {
913
33706
      Node a = l.first;
914
      // if propagation did not touch a
915
16853
      if (prop_clauses.find(a) == prop_clauses.end())
916
      {
917
14117
        Assert(l.second == truen || l.second == falsen);
918
28234
        Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
919
14117
        children.push_back(ln);
920
      }
921
    }
922
3744
    Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
923
1872
    Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
924
1872
    return new_ret;
925
  }
926
927
51292
  return Node::null();
928
}
929
930
51292
/**
 * Factors out the literal that is shared by the largest number of children
 * of n, where n is an andk- or ork-node.
 *
 * With nk the kind of n and onk the other of {andk, ork}, a child of kind onk
 * is treated as a clause over its own children, and any other child as a
 * clause containing only itself. For instance, when nk is andk,
 *   (A ork B) andk (A ork C) andk D  --->  D andk (A ork (B andk C)).
 *
 * @param andk the conjunction-like kind
 * @param ork the disjunction-like kind
 * @param notk the negation kind (not used by this method)
 * @param n the node to factor; its kind must be andk or ork
 * @return the factored node, or the null node if no literal occurs in more
 * than one child of n.
 */
Node ExtendedRewriter::extendedRewriteFactoring(Kind andk,
                                                Kind ork,
                                                Kind notk,
                                                Node n) const
{
  Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl;
  NodeManager* nm = NodeManager::currentNM();

  const Kind outerKind = n.getKind();
  Assert(outerKind == andk || outerKind == ork);
  const Kind innerKind = (outerKind == andk) ? ork : andk;

  // Index the children of n both ways: literal -> children containing it,
  // and child -> its (duplicate-free) literals.
  std::map<Node, std::vector<Node> > litToClauses;
  std::map<Node, std::vector<Node> > clauseToLits;
  for (const Node& clause : n)
  {
    if (clause.getKind() != innerKind)
    {
      // not a clause of the inner kind: it acts as its own single literal
      litToClauses[clause].push_back(clause);
      clauseToLits[clause].push_back(clause);
      continue;
    }
    for (const Node& lit : clause)
    {
      std::vector<Node>& owners = litToClauses[lit];
      if (std::find(owners.begin(), owners.end(), clause) == owners.end())
      {
        owners.push_back(clause);
        clauseToLits[clause].push_back(lit);
      }
    }
  }

  // Select the literal occurring in the most children. Only a strictly
  // larger count replaces the current choice, so ties go to the literal
  // visited first in the map's iteration order.
  unsigned bestCount = 0;
  Node factorLit;
  for (const std::pair<const Node, std::vector<Node> >& entry : litToClauses)
  {
    if (entry.second.size() > bestCount)
    {
      bestCount = entry.second.size();
      factorLit = entry.first;
    }
  }

  if (bestCount <= 1)
  {
    Trace("ext-rew-factoring") << "Factoring: no change" << std::endl;
    return Node::null();
  }

  // Split the children into those not mentioning factorLit (copied as is)
  // and those mentioning it (kept as residues with factorLit removed).
  std::vector<Node> resultChildren;
  std::vector<Node> residues;
  const std::vector<Node>& hosts = litToClauses.find(factorLit)->second;
  for (const Node& clause : n)
  {
    if (std::find(hosts.begin(), hosts.end(), clause) == hosts.end())
    {
      resultChildren.push_back(clause);
      continue;
    }
    std::vector<Node>& lits = clauseToLits[clause];
    std::vector<Node>::iterator pos =
        std::find(lits.begin(), lits.end(), factorLit);
    Assert(pos != lits.end());
    lits.erase(pos);
    // A child that consisted of factorLit alone leaves no residue and is
    // dropped. NOTE(review): presumably earlier rewrites guarantee such a
    // child never coexists with a clause containing factorLit -- confirm.
    if (lits.empty())
    {
      continue;
    }
    residues.push_back(lits.size() == 1 ? lits[0]
                                        : nm->mkNode(innerKind, lits));
  }

  // combine: factorLit <innerKind> ( residue_1 <outerKind> ... residue_m )
  Assert(!residues.empty());
  Node residue =
      residues.size() == 1 ? residues[0] : nm->mkNode(outerKind, residues);
  resultChildren.push_back(nm->mkNode(innerKind, factorLit, residue));
  Node ret = resultChildren.size() == 1
                 ? resultChildren[0]
                 : nm->mkNode(outerKind, resultChildren);
  Trace("ext-rew-factoring") << "Factoring: *** OUTPUT: " << ret << std::endl;
  return ret;
}
1019
1020
47432
/**
 * Equality resolution: looks for a child of n (an andk- or ork-node) of kind
 * eqk from which a single substitution can be inferred, and applies that
 * substitution to all the other children of n.
 *
 * @param andk the conjunction-like kind
 * @param ork the disjunction-like kind
 * @param eqk the equality-like kind of the children to inspect
 * @param notk the negation kind
 * @param bcp_kinds the kinds the substitution is allowed to traverse, passed
 * through to partialSubstitute
 * @param n the node to process; its kind must be andk or ork
 * @param isXor whether eqk denotes a disequality (xor-like) operator
 * @return n with the first substitution that changes some other child
 * applied, or the null node if there is none.
 */
Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
                                            Kind ork,
                                            Kind eqk,
                                            Kind notk,
                                            std::map<Kind, bool>& bcp_kinds,
                                            Node n,
                                            bool isXor) const
{
  Assert(n.getKind() == andk || n.getKind() == ork);
  Trace("ext-rew-eqres") << "Eq res: **** INPUT: " << n << std::endl;

  NodeManager* nm = NodeManager::currentNM();
  const Kind topKind = n.getKind();
  const bool isConj = (topKind == andk);
  const unsigned numLits = n.getNumChildren();
  for (unsigned pos = 0; pos < numLits; ++pos)
  {
    Node lit = n[pos];
    if (lit.getKind() != eqk)
    {
      continue;
    }
    // eq is the equality the substitution is based on; it stays null when
    // lit cannot be read as an equality in this context
    Node eq;
    if (isConj != isXor)
    {
      eq = (eqk == EQUAL) ? lit : nm->mkNode(EQUAL, lit[0], lit[1]);
    }
    else if (lit[1].getType() == lit.getType())
    {
      // A disequality can only be turned into an equality if the types are
      // the same:  t != s ---> ~t = s.  The negation is placed on the side
      // that is already negated, when only the right-hand side is.
      const bool negateRhs =
          lit[1].getKind() == notk && lit[0].getKind() != notk;
      if (negateRhs)
      {
        eq = nm->mkNode(EQUAL, lit[0], TermUtil::mkNegate(notk, lit[1]));
      }
      else
      {
        eq = nm->mkNode(EQUAL, TermUtil::mkNegate(notk, lit[0]), lit[1]);
      }
    }
    if (eq.isNull())
    {
      continue;
    }
    // see if it corresponds to a substitution
    std::vector<Node> vars;
    std::vector<Node> subs;
    if (!inferSubstitution(eq, vars, subs))
    {
      continue;
    }
    Assert(vars.size() == 1);
    // Apply it to every child except the one it came from. Substitution is
    // only applicable to compatible kinds; partialSubstitute is always used
    // so that witness terms are never substituted into.
    std::vector<Node> rebuilt;
    bool anyChanged = false;
    for (unsigned other = 0; other < numLits; ++other)
    {
      Node child = n[other];
      if (other != pos)
      {
        child = partialSubstitute(child, vars, subs, bcp_kinds);
        anyChanged = anyChanged || n[other] != child;
      }
      rebuilt.push_back(child);
    }
    if (anyChanged)
    {
      return nm->mkNode(topKind, rebuilt);
    }
  }

  return Node::null();
}
1096
1097
/** sort pairs by their second (unsigned) argument */
1098
57093
static bool sortPairSecond(const std::pair<Node, unsigned>& a,
1099
                           const std::pair<Node, unsigned>& b)
1100
{
1101
57093
  return (a.second < b.second);
1102
}
1103
1104
/** A simple subsumption trie used to compute pairwise list subsets */
1105
172634
class SimpSubsumeTrie
1106
{
1107
 public:
1108
  /** the children of this node */
1109
  std::map<Node, SimpSubsumeTrie> d_children;
1110
  /** the term at this node */
1111
  Node d_data;
1112
  /** add term to the trie
1113
   *
1114
   * This adds term c to this trie, whose atom list is alist. This adds terms
1115
   * s to subsumes such that the atom list of s is a subset of the atom list
1116
   * of c. For example, say:
1117
   *   c1.alist = { A }
1118
   *   c2.alist = { C }
1119
   *   c3.alist = { B, C }
1120
   *   c4.alist = { A, B, D }
1121
   *   c5.alist = { A, B, C }
1122
   * If these terms are added in the order c1, c2, c3, c4, c5, then:
1123
   *   addTerm c1 results in subsumes = {}
1124
   *   addTerm c2 results in subsumes = {}
1125
   *   addTerm c3 results in subsumes = { c2 }
1126
   *   addTerm c4 results in subsumes = { c1 }
1127
   *   addTerm c5 results in subsumes = { c1, c2, c3 }
1128
   * Notice that the intended use case of this trie is to add term t before t'
1129
   * only when size( t.alist ) <= size( t'.alist ).
1130
   *
1131
   * The last two arguments describe the state of the path [t0...tn] we
1132
   * have followed in the trie during the recursive call.
1133
   * If doAdd = true,
1134
   *   then n+1 = index and alist[1]...alist[n] = t1...tn. If index=alist.size()
1135
   *   we add c as the current node of this trie.
1136
   * If doAdd = false,
1137
   *   then t1...tn occur in alist.
1138
   */
1139
114803
  void addTerm(Node c,
1140
               std::vector<Node>& alist,
1141
               std::vector<Node>& subsumes,
1142
               unsigned index = 0,
1143
               bool doAdd = true)
1144
  {
1145
114803
    if (!d_data.isNull())
1146
    {
1147
12
      subsumes.push_back(d_data);
1148
    }
1149
114803
    if (doAdd)
1150
    {
1151
114779
      if (index == alist.size())
1152
      {
1153
56560
        d_data = c;
1154
56560
        return;
1155
      }
1156
    }
1157
    // try all children where we have this atom
1158
87447
    for (std::pair<const Node, SimpSubsumeTrie>& cp : d_children)
1159
    {
1160
29204
      if (std::find(alist.begin(), alist.end(), cp.first) != alist.end())
1161
      {
1162
24
        cp.second.addTerm(c, alist, subsumes, 0, false);
1163
      }
1164
    }
1165
58243
    if (doAdd)
1166
    {
1167
58219
      d_children[alist[index]].addTerm(c, alist, subsumes, index + 1, doAdd);
1168
    }
1169
  }
1170
};
1171
1172
28215
/**
 * Simplifies a chain of eqk applications rooted at ret.
 *
 * The chain is flattened into its children, negations are stripped into a
 * global polarity, children occurring twice cancel, a number of
 * simplifications are applied between andk/ork children (and between such a
 * child and one of its own atoms occurring in the chain), and the result is
 * rebuilt as a sorted right-associative chain which is then rewritten with
 * d_rew.
 *
 * @param eqk the kind of the chain
 * @param andk the conjunction-like kind
 * @param ork the disjunction-like kind
 * @param notk the negation kind
 * @param ret the node to simplify; its kind must be eqk
 * @param isXor whether eqk behaves like xor (as opposed to equality)
 * @return the simplified node if it differs from ret; the null node
 * otherwise, or when this rewriter is not in aggressive mode.
 */
Node ExtendedRewriter::extendedRewriteEqChain(
    Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor) const
{
  Assert(ret.getKind() == eqk);

  // this rewrite is aggressive; it in fact has the precondition that other
  // aggressive rewrites (including BCP) have been applied.
  if (!d_aggr)
  {
    return Node::null();
  }

  NodeManager* nm = NodeManager::currentNM();
  TypeNode tn = ret[0].getType();

  // the sole element of lits, or the ork-node over lits otherwise
  auto mkOrOrSingle = [nm, ork](std::vector<Node>& lits) -> Node {
    return lits.size() == 1 ? lits[0] : nm->mkNode(ork, lits);
  };

  // sort/cancelling for Boolean EQUAL/XOR-chains
  Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;

  // ---- step 1: flatten the chain, assuming right associativity
  bool gpol = true;
  std::vector<Node> children;
  const unsigned numSides = ret.getNumChildren();
  for (unsigned side = 0; side < numSides; ++side)
  {
    Node curr = ret[side];
    for (; curr.getKind() == eqk && curr[0].getType() == tn; curr = curr[1])
    {
      children.push_back(curr[0]);
    }
    children.push_back(curr);
  }

  // ---- step 2: strip negations into gpol and cancel repeated children.
  // cstatus maps each remaining child to whether it is still active.
  std::map<Node, bool> cstatus;
  for (const Node& child : children)
  {
    Node atom = child;
    if (atom.getKind() == notk)
    {
      gpol = !gpol;
      atom = atom[0];
    }
    Trace("ext-rew-eqchain") << "...child : " << atom << std::endl;
    if (cstatus.find(atom) != cstatus.end())
    {
      // second occurrence: cancels with the first
      cstatus.erase(atom);
      if (isXor)
      {
        gpol = !gpol;
      }
    }
    else
    {
      cstatus[atom] = true;
    }
  }
  Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;

  if (cstatus.empty())
  {
    return TermUtil::mkTypeConst(tn, gpol);
  }

  // from here on, children collects the children produced by simplifications
  children.clear();

  // ---- step 3: compute the atoms of each child
  Trace("ext-rew-eqchain") << "eqchain-simplify: begin\n";
  Trace("ext-rew-eqchain") << "  eqchain-simplify: get atoms...\n";
  // child -> (atom -> polarity), child -> list of atoms, (child, #atoms)
  std::map<Node, std::map<Node, bool> > atoms;
  std::map<Node, std::vector<Node> > alist;
  std::vector<std::pair<Node, unsigned> > atom_count;
  for (std::pair<const Node, bool>& cp : cstatus)
  {
    if (!cp.second)
    {
      // already eliminated
      continue;
    }
    Node c = cp.first;
    Kind ck = c.getKind();
    Trace("ext-rew-eqchain") << "  process c = " << c << std::endl;
    if (ck == andk || ck == ork)
    {
      const unsigned numLits = c.getNumChildren();
      for (unsigned j = 0; j < numLits; ++j)
      {
        Node cl = c[j];
        const bool pol = cl.getKind() != notk;
        Node ca = pol ? cl : cl[0];
        // polarity is flipped when we are AND
        const bool newVal = (ck == andk) ? !pol : pol;
        Trace("ext-rew-eqchain")
            << "  atoms(" << c << ", " << ca << ") = " << newVal << std::endl;
        Assert(atoms[c].find(ca) == atoms[c].end());
        atoms[c][ca] = newVal;
        alist[c].push_back(ca);

        // if this already exists as a child of the equality chain, eliminate.
        // this catches cases like ( x & y ) = ( ( x & y ) | z ), where we
        // consider ( x & y ) a unit, whereas below it is expanded to
        // ~( ~x | ~y ).
        std::map<Node, bool>::iterator hit = cstatus.find(ca);
        if (hit == cstatus.end() || !hit->second)
        {
          continue;
        }
        // cancel both the atom and the child containing it
        hit->second = false;
        cp.second = false;
        // make new child
        // x = ( y | ~x ) ---> y & x
        // x = ( y | x ) ---> ~y | x
        // x = ( y & x ) ---> y | ~x
        // x = ( y & ~x ) ---> ~y & ~x
        std::vector<Node> others;
        for (unsigned k = 0; k < numLits; ++k)
        {
          if (k != j)
          {
            others.push_back(c[k]);
          }
        }
        Node kept = c[j];
        Node rest = others.size() == 1 ? others[0] : nm->mkNode(ck, others);
        // negate the proper child
        if ((ck == andk) == pol)
        {
          kept = TermUtil::mkNegate(notk, kept);
        }
        else
        {
          rest = TermUtil::mkNegate(notk, rest);
        }
        // store as new child
        children.push_back(nm->mkNode(pol ? ork : andk, kept, rest));
        if (isXor)
        {
          gpol = !gpol;
        }
        break;
      }
    }
    else
    {
      const bool pol = ck != notk;
      Node ca = pol ? c : c[0];
      atoms[c][ca] = pol;
      Trace("ext-rew-eqchain")
          << "  atoms(" << c << ", " << ca << ") = " << pol << std::endl;
      alist[c].push_back(ca);
    }
    atom_count.push_back(std::pair<Node, unsigned>(c, alist[c].size()));
  }

  // sort the atoms in each atom list
  for (std::pair<const Node, std::vector<Node> >& entry : alist)
  {
    std::sort(entry.second.begin(), entry.second.end());
  }

  // ---- step 4: check subsumptions
  // sort by #atoms
  // NOTE(review): atom_count is only read by the trace output below; the
  // trie is filled in the iteration order of cstatus.
  std::sort(atom_count.begin(), atom_count.end(), sortPairSecond);
  if (Trace.isOn("ext-rew-eqchain"))
  {
    for (const std::pair<Node, unsigned>& ac : atom_count)
    {
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << ac.first << " has "
                               << ac.second << " atoms." << std::endl;
    }
    Trace("ext-rew-eqchain") << "  eqchain-simplify: compute subsumptions...\n";
  }
  SimpSubsumeTrie sst;
  for (std::pair<const Node, bool>& cp : cstatus)
  {
    if (!cp.second)
    {
      // already eliminated
      continue;
    }
    Node c = cp.first;
    std::map<Node, std::map<Node, bool> >::iterator itc = atoms.find(c);
    Assert(itc != atoms.end());
    Trace("ext-rew-eqchain") << "  - add term " << c << " with atom list "
                             << alist[c] << "...\n";
    std::vector<Node> subsumes;
    sst.addTerm(c, alist[c], subsumes);
    for (const Node& cc : subsumes)
    {
      if (!cstatus[cc])
      {
        // subsumes a child that was already eliminated
        continue;
      }
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << c << " subsumes "
                               << cc << std::endl;
      std::map<Node, std::map<Node, bool> >::iterator itcc = atoms.find(cc);
      Assert(itcc != atoms.end());

      // Classify the atoms of cc by whether c gives them the same polarity
      // (common) or the opposite one (diff); literals carry c's polarity.
      std::vector<Node> common_children;
      std::vector<Node> diff_children;
      for (const std::pair<const Node, bool>& ap : itcc->second)
      {
        Node a = ap.first;
        bool polcc = ap.second;
        Assert(itc->second.find(a) != itc->second.end());
        bool polc = itc->second[a];
        Trace("ext-rew-eqchain") << "    eqchain-simplify: atom " << a
                                 << " has polarities : " << polc << " " << polcc
                                 << "\n";
        Node lit = polc ? a : TermUtil::mkNegate(notk, a);
        if (polc != polcc)
        {
          diff_children.push_back(lit);
        }
        else
        {
          common_children.push_back(lit);
        }
      }
      // the literals of c over atoms that do not occur in cc
      std::vector<Node> rem_children;
      for (const std::pair<const Node, bool>& ap : itc->second)
      {
        Node a = ap.first;
        if (itcc->second.find(a) == itcc->second.end())
        {
          rem_children.push_back(ap.second ? a
                                           : TermUtil::mkNegate(notk, a));
        }
      }
      Trace("ext-rew-eqchain")
          << "    #common/diff/rem: " << common_children.size() << "/"
          << diff_children.size() << "/" << rem_children.size() << "\n";

      bool applied = true;
      if (common_children.empty() && itc->second.size() == itcc->second.size()
          && itcc->second.size() == 2)
      {
        // x | y = ~x | ~y ---> ~( x = y )
        children.push_back(diff_children[0]);
        children.push_back(diff_children[1]);
        gpol = !gpol;
        Trace("ext-rew-eqchain") << "    apply 2-child all-diff\n";
      }
      else if (common_children.empty() && diff_children.size() == 1)
      {
        // x = ( ~x | y ) ---> ~( ~x | ~y )
        Node remn = TermUtil::mkNegate(notk, mkOrOrSingle(rem_children));
        children.push_back(nm->mkNode(ork, diff_children[0], remn));
        if (!isXor)
        {
          gpol = !gpol;
        }
        Trace("ext-rew-eqchain") << "    apply unit resolution\n";
      }
      else if (diff_children.size() == 1
               && itc->second.size() == itcc->second.size())
      {
        // ( x | y | z ) = ( x | ~y | z ) ---> ( x | z )
        Assert(!common_children.empty());
        children.push_back(mkOrOrSingle(common_children));
        if (isXor)
        {
          gpol = !gpol;
        }
        Trace("ext-rew-eqchain") << "    apply resolution\n";
      }
      else if (diff_children.empty())
      {
        if (rem_children.empty())
        {
          // x | y = x | y ---> true
          // this can happen if we have ( ~x & ~y ) = ( x | y )
          children.push_back(TermUtil::mkTypeMaxValue(tn));
          if (isXor)
          {
            gpol = !gpol;
          }
          Trace("ext-rew-eqchain") << "    apply cancel\n";
        }
        else
        {
          // x | y = ( x | y | z ) ---> ( x | y | ~z )
          Node remn = TermUtil::mkNegate(notk, mkOrOrSingle(rem_children));
          Node comn = mkOrOrSingle(common_children);
          children.push_back(nm->mkNode(ork, comn, remn));
          if (isXor)
          {
            gpol = !gpol;
          }
          Trace("ext-rew-eqchain") << "    apply subsume\n";
        }
      }
      else
      {
        applied = false;
      }

      if (applied)
      {
        // eliminate both children, reverse polarity as needed
        cp.second = false;
        if (c.getKind() == andk)
        {
          gpol = !gpol;
        }
        cstatus[cc] = false;
        if (cc.getKind() == andk)
        {
          gpol = !gpol;
        }
        break;
      }
    }
  }
  Trace("ext-rew-eqchain") << "eqchain-simplify: finish" << std::endl;

  // ---- step 5: sorted right associative chain
  bool has_nvar = false;
  unsigned nvar_index = 0;
  for (std::pair<const Node, bool>& cp : cstatus)
  {
    if (!cp.second)
    {
      continue;
    }
    if (!cp.first.isVar())
    {
      has_nvar = true;
      nvar_index = children.size();
    }
    children.push_back(cp.first);
  }
  std::sort(children.begin(), children.end());

  if (!gpol)
  {
    // negate the constant child if it exists
    // NOTE(review): nvar_index was recorded before the sort above, so it may
    // no longer designate a non-variable child when simplifications added
    // children. Negating any one child flips the chain's polarity, so only
    // the shape of the result depends on this.
    unsigned nindex = has_nvar ? nvar_index : 0;
    children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
  }
  Node new_ret = children.back();
  for (unsigned index = children.size() - 1; index > 0; --index)
  {
    new_ret = nm->mkNode(eqk, children[index - 1], new_ret);
  }
  new_ret = d_rew.rewrite(new_ret);
  if (new_ret != ret)
  {
    return new_ret;
  }
  return Node::null();
}
1533
1534
1942307
/**
 * Applies the substitution assign to n, descending only through permitted
 * kinds.
 *
 * A subterm that is a key of assign is replaced by its image. Otherwise its
 * children are visited only if its kind is not WITNESS and either rkinds is
 * empty or contains the kind; any other subterm is left untouched.
 *
 * @param n the node to substitute into
 * @param assign the substitution, as a map from terms to their images
 * @param rkinds if non-empty, the only kinds whose children are visited
 * @return the result of the substitution (n itself if nothing changed)
 */
Node ExtendedRewriter::partialSubstitute(
    Node n,
    const std::map<Node, Node>& assign,
    const std::map<Kind, bool>& rkinds) const
{
  // results of the traversal; a null entry marks a node whose children have
  // been scheduled but which has not been rebuilt yet
  std::unordered_map<TNode, Node> cache;
  std::vector<TNode> stack;
  stack.push_back(n);
  while (!stack.empty())
  {
    TNode cur = stack.back();
    stack.pop_back();
    std::unordered_map<TNode, Node>::iterator hit = cache.find(cur);

    if (hit == cache.end())
    {
      // first visit
      std::map<Node, Node>::const_iterator sub = assign.find(cur);
      if (sub != assign.end())
      {
        cache[cur] = sub->second;
        continue;
      }
      // If rkinds is non-empty, then can only recurse on its kinds.
      // We also always disallow substitution into witness. Notice that
      // we disallow witness here, due to unsoundness when applying contextual
      // substitutions over witness terms (see #4620).
      Kind k = cur.getKind();
      const bool recurse =
          k != WITNESS && (rkinds.empty() || rkinds.find(k) != rkinds.end());
      if (!recurse)
      {
        cache[cur] = cur;
        continue;
      }
      cache[cur] = Node::null();
      // revisit cur once all of its children have been processed
      stack.push_back(cur);
      for (const Node& cn : cur)
      {
        stack.push_back(cn);
      }
      continue;
    }
    if (!hit->second.isNull())
    {
      // already finished
      continue;
    }

    // second visit: rebuild cur from the results of its children
    std::vector<Node> children;
    if (cur.getMetaKind() == metakind::PARAMETERIZED)
    {
      children.push_back(cur.getOperator());
    }
    bool childChanged = false;
    for (const Node& cn : cur)
    {
      hit = cache.find(cn);
      Assert(hit != cache.end());
      Assert(!hit->second.isNull());
      childChanged = childChanged || cn != hit->second;
      children.push_back(hit->second);
    }
    Node rebuilt = cur;
    if (childChanged)
    {
      rebuilt = NodeManager::currentNM()->mkNode(cur.getKind(), children);
    }
    cache[cur] = rebuilt;
  }
  Assert(cache.find(n) != cache.end());
  Assert(!cache.find(n)->second.isNull());
  return cache[n];
}
1608
1609
1290720
/**
 * Same as the map-based overload, for a substitution given as two parallel
 * vectors: vars[i] is mapped to subs[i]. If a term occurs several times in
 * vars, its last image is the one used.
 *
 * @param n the node to substitute into
 * @param vars the terms to replace
 * @param subs their images; must have the same size as vars
 * @param rkinds if non-empty, the only kinds whose children are visited
 * @return the result of the substitution
 */
Node ExtendedRewriter::partialSubstitute(
    Node n,
    const std::vector<Node>& vars,
    const std::vector<Node>& subs,
    const std::map<Kind, bool>& rkinds) const
{
  Assert(vars.size() == subs.size());
  std::map<Node, Node> mapping;
  std::vector<Node>::const_iterator image = subs.begin();
  for (const Node& var : vars)
  {
    mapping[var] = *image;
    ++image;
  }
  return partialSubstitute(n, mapping, rkinds);
}
1623
1624
100167
/**
 * Placeholder for solving the equality n; currently always returns the null
 * node.
 *
 * @param n the equality; its kind must be EQUAL
 * @return the null node
 */
Node ExtendedRewriter::solveEquality(Node n) const
{
  Assert(n.getKind() == EQUAL);

  // TODO (#1706) : implement
  const Node unsolved = Node::null();
  return unsolved;
}
1631
1632
140995
bool ExtendedRewriter::inferSubstitution(Node n,
1633
                                         std::vector<Node>& vars,
1634
                                         std::vector<Node>& subs,
1635
                                         bool usePred) const
1636
{
1637
140995
  if (n.getKind() == AND)
1638
  {
1639
4482
    bool ret = false;
1640
25839
    for (const Node& nc : n)
1641
    {
1642
21357
      bool cret = inferSubstitution(nc, vars, subs, usePred);
1643
21357
      ret = ret || cret;
1644
    }
1645
4482
    return ret;
1646
  }
1647
136513
  if (n.getKind() == EQUAL)
1648
  {
1649
    // see if it can be put into form x = y
1650
108455
    Node slv_eq = solveEquality(n);
1651
100167
    if (!slv_eq.isNull())
1652
    {
1653
      n = slv_eq;
1654
    }
1655
300501
    Node v[2];
1656
255884
    for (unsigned i = 0; i < 2; i++)
1657
    {
1658
200161
      if (n[i].isConst())
1659
      {
1660
44444
        vars.push_back(n[1 - i]);
1661
44444
        subs.push_back(n[i]);
1662
44444
        return true;
1663
      }
1664
155717
      if (n[i].isVar())
1665
      {
1666
143788
        v[i] = n[i];
1667
      }
1668
11929
      else if (TermUtil::isNegate(n[i].getKind()) && n[i][0].isVar())
1669
      {
1670
269
        v[i] = n[i][0];
1671
      }
1672
    }
1673
75646
    for (unsigned i = 0; i < 2; i++)
1674
    {
1675
87281
      TNode r1 = v[i];
1676
87281
      Node r2 = v[1 - i];
1677
67358
      if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
1678
      {
1679
47435
        r2 = n[1 - i];
1680
47435
        if (v[i] != n[i])
1681
        {
1682
260
          Assert(TermUtil::isNegate(n[i].getKind()));
1683
260
          r2 = TermUtil::mkNegate(n[i].getKind(), r2);
1684
        }
1685
        // TODO (#1706) : union find
1686
47435
        if (std::find(vars.begin(), vars.end(), r1) == vars.end())
1687
        {
1688
47435
          vars.push_back(r1);
1689
47435
          subs.push_back(r2);
1690
47435
          return true;
1691
        }
1692
      }
1693
    }
1694
  }
1695
44634
  if (usePred)
1696
  {
1697
36987
    bool negated = n.getKind() == NOT;
1698
36987
    vars.push_back(negated ? n[0] : n);
1699
36987
    subs.push_back(negated ? d_false : d_true);
1700
36987
    return true;
1701
  }
1702
7647
  return false;
1703
}
1704
1705
4474
/**
 * Extended rewriting for strings: delegates equalities to
 * strings::SequencesRewriter::rewriteEqualityExt.
 *
 * @param ret the node to rewrite
 * @return the result of rewriteEqualityExt when ret is an EQUAL; the null
 * node for every other kind.
 */
Node ExtendedRewriter::extendedRewriteStrings(Node ret) const
{
  Trace("q-ext-rewrite-debug")
      << "Extended rewrite strings : " << ret << std::endl;

  if (ret.getKind() != EQUAL)
  {
    return Node::null();
  }
  // the rewriter is constructed with a null argument, as it is only needed
  // for this one call
  strings::SequencesRewriter seqRewriter(nullptr);
  return seqRewriter.rewriteEqualityExt(ret);
}
1718
1719
114572
void ExtendedRewriter::debugExtendedRewrite(Node n,
1720
                                            Node ret,
1721
                                            const char* c) const
1722
{
1723
114572
  if (Trace.isOn("q-ext-rewrite"))
1724
  {
1725
    if (!ret.isNull())
1726
    {
1727
      Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
1728
      Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
1729
                             << " rewrites to " << ret << std::endl;
1730
    }
1731
  }
1732
114572
}
1733
1734
}  // namespace quantifiers
1735
}  // namespace theory
1736
29577
}  // namespace cvc5