GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/extended_rewrite.cpp Lines: 777 843 92.2 %
Date: 2021-09-29 Branches: 1873 3794 49.4 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of extended rewriting techniques.
14
 */
15
16
#include "theory/quantifiers/extended_rewrite.h"
17
18
#include <sstream>
19
20
#include "theory/arith/arith_msum.h"
21
#include "theory/bv/theory_bv_utils.h"
22
#include "theory/datatypes/datatypes_rewriter.h"
23
#include "theory/quantifiers/term_util.h"
24
#include "theory/rewriter.h"
25
#include "theory/strings/sequences_rewriter.h"
26
#include "theory/theory.h"
27
28
using namespace cvc5::kind;
29
using namespace std;
30
31
namespace cvc5 {
32
namespace theory {
33
namespace quantifiers {
34
35
struct ExtRewriteAttributeId
36
{
37
};
38
typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
39
40
struct ExtRewriteAggAttributeId
41
{
42
};
43
typedef expr::Attribute<ExtRewriteAggAttributeId, Node> ExtRewriteAggAttribute;
44
45
243145
ExtendedRewriter::ExtendedRewriter(Rewriter& rew, bool aggr)
46
243145
    : d_rew(rew), d_aggr(aggr)
47
{
48
243145
  d_true = NodeManager::currentNM()->mkConst(true);
49
243145
  d_false = NodeManager::currentNM()->mkConst(false);
50
243145
}
51
52
254221
void ExtendedRewriter::setCache(Node n, Node ret) const
53
{
54
254221
  if (d_aggr)
55
  {
56
    ExtRewriteAggAttribute erga;
57
253573
    n.setAttribute(erga, ret);
58
  }
59
  else
60
  {
61
    ExtRewriteAttribute era;
62
648
    n.setAttribute(era, ret);
63
  }
64
254221
}
65
66
518269
Node ExtendedRewriter::getCache(Node n) const
67
{
68
518269
  if (d_aggr)
69
  {
70
517757
    if (n.hasAttribute(ExtRewriteAggAttribute()))
71
    {
72
389566
      return n.getAttribute(ExtRewriteAggAttribute());
73
    }
74
  }
75
  else
76
  {
77
512
    if (n.hasAttribute(ExtRewriteAttribute()))
78
    {
79
188
      return n.getAttribute(ExtRewriteAttribute());
80
    }
81
  }
82
128515
  return Node::null();
83
}
84
85
257248
bool ExtendedRewriter::addToChildren(Node nc,
86
                                     std::vector<Node>& children,
87
                                     bool dropDup) const
88
{
89
  // If the operator is non-additive, do not consider duplicates
90
257248
  if (dropDup
91
257248
      && std::find(children.begin(), children.end(), nc) != children.end())
92
  {
93
416
    return false;
94
  }
95
256832
  children.push_back(nc);
96
256832
  return true;
97
}
98
99
518269
Node ExtendedRewriter::extendedRewrite(Node n) const
100
{
101
518269
  n = d_rew.rewrite(n);
102
103
  // has it already been computed?
104
1036538
  Node ncache = getCache(n);
105
518269
  if (!ncache.isNull())
106
  {
107
389754
    return ncache;
108
  }
109
110
257030
  Node ret = n;
111
128515
  NodeManager* nm = NodeManager::currentNM();
112
113
  //--------------------pre-rewrite
114
128515
  if (d_aggr)
115
  {
116
253573
    Node pre_new_ret;
117
128191
    if (ret.getKind() == IMPLIES)
118
    {
119
23
      pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
120
23
      debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
121
    }
122
128168
    else if (ret.getKind() == XOR)
123
    {
124
161
      pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
125
161
      debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
126
    }
127
128007
    else if (ret.getKind() == NOT)
128
    {
129
12155
      pre_new_ret = extendedRewriteNnf(ret);
130
12155
      debugExtendedRewrite(ret, pre_new_ret, "NNF");
131
    }
132
128191
    if (!pre_new_ret.isNull())
133
    {
134
2809
      ret = extendedRewrite(pre_new_ret);
135
136
5618
      Trace("q-ext-rewrite-debug")
137
2809
          << "...ext-pre-rewrite : " << n << " -> " << pre_new_ret << std::endl;
138
2809
      setCache(n, ret);
139
2809
      return ret;
140
    }
141
  }
142
  //--------------------end pre-rewrite
143
144
  //--------------------rewrite children
145
125706
  if (n.getNumChildren() > 0)
146
  {
147
225208
    std::vector<Node> children;
148
112604
    if (n.getMetaKind() == metakind::PARAMETERIZED)
149
    {
150
2956
      children.push_back(n.getOperator());
151
    }
152
112604
    Kind k = n.getKind();
153
112604
    bool childChanged = false;
154
112604
    bool isNonAdditive = TermUtil::isNonAdditive(k);
155
    // We flatten associative operators below, which requires k to be n-ary.
156
112604
    bool isAssoc = TermUtil::isAssoc(k, true);
157
368539
    for (unsigned i = 0; i < n.getNumChildren(); i++)
158
    {
159
511870
      Node nc = extendedRewrite(n[i]);
160
255935
      childChanged = nc != n[i] || childChanged;
161
255935
      if (isAssoc && nc.getKind() == n.getKind())
162
      {
163
2747
        for (const Node& ncc : nc)
164
        {
165
2030
          if (!addToChildren(ncc, children, isNonAdditive))
166
          {
167
9
            childChanged = true;
168
          }
169
        }
170
      }
171
255218
      else if (!addToChildren(nc, children, isNonAdditive))
172
      {
173
407
        childChanged = true;
174
      }
175
    }
176
112604
    Assert(!children.empty());
177
    // Some commutative operators have rewriters that are agnostic to order,
178
    // thus, we sort here.
179
112604
    if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
180
    {
181
62660
      childChanged = true;
182
62660
      std::sort(children.begin(), children.end());
183
    }
184
112604
    if (childChanged)
185
    {
186
72753
      if (isNonAdditive && children.size() == 1)
187
      {
188
        // we may have subsumed children down to one
189
33
        ret = children[0];
190
      }
191
72720
      else if (isAssoc
192
72720
               && children.size() > kind::metakind::getMaxArityForKind(k))
193
      {
194
2
        Assert(kind::metakind::getMaxArityForKind(k) >= 2);
195
        // kind may require binary construction
196
2
        ret = children[0];
197
6
        for (unsigned i = 1, nchild = children.size(); i < nchild; i++)
198
        {
199
4
          ret = nm->mkNode(k, ret, children[i]);
200
        }
201
      }
202
      else
203
      {
204
72718
        ret = nm->mkNode(k, children);
205
      }
206
    }
207
  }
208
125706
  ret = d_rew.rewrite(ret);
209
  //--------------------end rewrite children
210
211
  // now, do extended rewrite
212
251412
  Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
213
125706
                               << " (from " << n << ")" << std::endl;
214
251412
  Node new_ret;
215
216
  //---------------------- theory-independent post-rewriting
217
125706
  if (ret.getKind() == ITE)
218
  {
219
11747
    new_ret = extendedRewriteIte(ITE, ret);
220
  }
221
113959
  else if (ret.getKind() == AND || ret.getKind() == OR)
222
  {
223
16366
    new_ret = extendedRewriteAndOr(ret);
224
  }
225
97593
  else if (ret.getKind() == EQUAL)
226
  {
227
16747
    new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
228
16747
    debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
229
  }
230
125706
  Assert(new_ret.isNull() || new_ret != ret);
231
125706
  if (new_ret.isNull() && ret.getKind() != ITE)
232
  {
233
    // simple ITE pulling
234
111010
    new_ret = extendedRewritePullIte(ITE, ret);
235
  }
236
  //----------------------end theory-independent post-rewriting
237
238
  //----------------------theory-specific post-rewriting
239
125706
  if (new_ret.isNull())
240
  {
241
    TheoryId tid;
242
110581
    if (ret.getKind() == ITE)
243
    {
244
10540
      tid = Theory::theoryOf(ret.getType());
245
    }
246
    else
247
    {
248
100041
      tid = Theory::theoryOf(ret);
249
    }
250
221162
    Trace("q-ext-rewrite-debug") << "theoryOf( " << ret << " )= " << tid
251
110581
                                 << std::endl;
252
110581
    if (tid == THEORY_STRINGS)
253
    {
254
4131
      new_ret = extendedRewriteStrings(ret);
255
    }
256
  }
257
  //----------------------end theory-specific post-rewriting
258
259
  //----------------------aggressive rewrites
260
125706
  if (new_ret.isNull() && d_aggr)
261
  {
262
110035
    new_ret = extendedRewriteAggr(ret);
263
  }
264
  //----------------------end aggressive rewrites
265
266
125706
  setCache(n, ret);
267
125706
  if (!new_ret.isNull())
268
  {
269
16380
    ret = extendedRewrite(new_ret);
270
  }
271
251412
  Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
272
125706
                               << std::endl;
273
125706
  if (Trace.isOn("q-ext-rewrite-nf"))
274
  {
275
    if (n == ret)
276
    {
277
      Trace("q-ext-rewrite-nf") << "ext-rew normal form : " << n << std::endl;
278
    }
279
  }
280
125706
  setCache(n, ret);
281
125706
  return ret;
282
}
283
284
110035
Node ExtendedRewriter::extendedRewriteAggr(Node n) const
285
{
286
110035
  Node new_ret;
287
220070
  Trace("q-ext-rewrite-debug2")
288
110035
      << "Do aggressive rewrites on " << n << std::endl;
289
110035
  bool polarity = n.getKind() != NOT;
290
220070
  Node ret_atom = n.getKind() == NOT ? n[0] : n;
291
230713
  if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
292
324829
      || ret_atom.getKind() == GEQ)
293
  {
294
    // ITE term removal in polynomials
295
    // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 )
296
39036
    Trace("q-ext-rewrite-debug2")
297
19518
        << "Compute monomial sum " << ret_atom << std::endl;
298
    // compute monomial sum
299
39036
    std::map<Node, Node> msum;
300
19518
    if (ArithMSum::getMonomialSumLit(ret_atom, msum))
301
    {
302
65640
      for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end();
303
           ++itm)
304
      {
305
93271
        Node v = itm->first;
306
94298
        Trace("q-ext-rewrite-debug2")
307
47149
            << itm->first << " * " << itm->second << std::endl;
308
47149
        if (v.getKind() == ITE)
309
        {
310
11831
          Node veq;
311
6429
          int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind());
312
6429
          if (res != 0)
313
          {
314
12858
            Trace("q-ext-rewrite-debug")
315
6429
                << "  have ITE relation, solved form : " << veq << std::endl;
316
            // try pulling ITE
317
6429
            new_ret = extendedRewritePullIte(ITE, veq);
318
6429
            if (!new_ret.isNull())
319
            {
320
1027
              if (!polarity)
321
              {
322
                new_ret = new_ret.negate();
323
              }
324
1027
              break;
325
            }
326
          }
327
          else
328
          {
329
            Trace("q-ext-rewrite-debug")
330
                << "  failed to isolate " << v << " in " << n << std::endl;
331
          }
332
        }
333
      }
334
    }
335
    else
336
    {
337
      Trace("q-ext-rewrite-debug")
338
          << "  failed to get monomial sum of " << n << std::endl;
339
    }
340
  }
341
  // TODO (#1706) : conditional rewriting, condition merging
342
220070
  return new_ret;
343
}
344
345
21360
Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const
346
{
347
21360
  Assert(n.getKind() == itek);
348
21360
  Assert(n[1] != n[2]);
349
350
21360
  NodeManager* nm = NodeManager::currentNM();
351
352
21360
  Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;
353
354
42720
  Node flip_cond;
355
21360
  if (n[0].getKind() == NOT)
356
  {
357
    flip_cond = n[0][0];
358
  }
359
21360
  else if (n[0].getKind() == OR)
360
  {
361
    // a | b ---> ~( ~a & ~b )
362
91
    flip_cond = TermUtil::simpleNegate(n[0]);
363
  }
364
21360
  if (!flip_cond.isNull())
365
  {
366
182
    Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
367
    // only print debug trace if full=true
368
91
    if (full)
369
    {
370
91
      debugExtendedRewrite(n, new_ret, "ITE flip");
371
    }
372
91
    return new_ret;
373
  }
374
  // Boolean true/false return
375
42538
  TypeNode tn = n.getType();
376
21269
  if (tn.isBoolean())
377
  {
378
31563
    for (unsigned i = 1; i <= 2; i++)
379
    {
380
21042
      if (n[i].isConst())
381
      {
382
        Node cond = i == 1 ? n[0] : n[0].negate();
383
        Node other = n[i == 1 ? 2 : 1];
384
        Kind retk = AND;
385
        if (n[i].getConst<bool>())
386
        {
387
          retk = OR;
388
        }
389
        else
390
        {
391
          cond = cond.negate();
392
        }
393
        Node new_ret = nm->mkNode(retk, cond, other);
394
        if (full)
395
        {
396
          // ite( A, true, B ) ---> A V B
397
          // ite( A, false, B ) ---> ~A /\ B
398
          // ite( A, B,  true ) ---> ~A V B
399
          // ite( A, B, false ) ---> A /\ B
400
          debugExtendedRewrite(n, new_ret, "ITE const return");
401
        }
402
        return new_ret;
403
      }
404
    }
405
  }
406
407
  // get entailed equalities in the condition
408
42538
  std::vector<Node> eq_conds;
409
21269
  Kind ck = n[0].getKind();
410
21269
  if (ck == EQUAL)
411
  {
412
7714
    eq_conds.push_back(n[0]);
413
  }
414
13555
  else if (ck == AND)
415
  {
416
4596
    for (const Node& cn : n[0])
417
    {
418
3601
      if (cn.getKind() == EQUAL)
419
      {
420
503
        eq_conds.push_back(cn);
421
      }
422
    }
423
  }
424
425
42538
  Node new_ret;
426
42538
  Node b;
427
42538
  Node e;
428
42538
  Node t1 = n[1];
429
42538
  Node t2 = n[2];
430
42538
  std::stringstream ss_reason;
431
432
29432
  for (const Node& eq : eq_conds)
433
  {
434
    // simple invariant ITE
435
24570
    for (unsigned i = 0; i <= 1; i++)
436
    {
437
      // ite( x = y ^ C, y, x ) ---> x
438
      // this is subsumed by the rewrites below
439
16407
      if (t2 == eq[i] && t1 == eq[1 - i])
440
      {
441
54
        new_ret = t2;
442
54
        ss_reason << "ITE simple rev subs";
443
54
        break;
444
      }
445
    }
446
8217
    if (!new_ret.isNull())
447
    {
448
54
      break;
449
    }
450
  }
451
21269
  if (new_ret.isNull())
452
  {
453
    // merging branches
454
63328
    for (unsigned i = 1; i <= 2; i++)
455
    {
456
42330
      if (n[i].getKind() == ITE)
457
      {
458
3412
        Node no = n[3 - i];
459
4799
        for (unsigned j = 1; j <= 2; j++)
460
        {
461
3310
          if (n[i][j] == no)
462
          {
463
            // e.g.
464
            // ite( C1, ite( C2, t1, t2 ), t1 ) ----> ite( C1 ^ ~C2, t2, t1 )
465
434
            Node nc1 = i == 2 ? n[0].negate() : n[0];
466
434
            Node nc2 = j == 1 ? n[i][0].negate() : n[i][0];
467
434
            Node new_cond = nm->mkNode(AND, nc1, nc2);
468
217
            new_ret = nm->mkNode(ITE, new_cond, n[i][3 - j], no);
469
217
            ss_reason << "ITE merge branch";
470
217
            break;
471
          }
472
        }
473
      }
474
42330
      if (!new_ret.isNull())
475
      {
476
217
        break;
477
      }
478
    }
479
  }
480
481
21269
  if (new_ret.isNull() && d_aggr)
482
  {
483
    // If x is less than t based on an ordering, then we use { x -> t } as a
484
    // substitution to the children of ite( x = t ^ C, s, t ) below.
485
41988
    std::vector<Node> vars;
486
41988
    std::vector<Node> subs;
487
20994
    inferSubstitution(n[0], vars, subs, true);
488
489
20994
    if (!vars.empty())
490
    {
491
      // reverse substitution to opposite child
492
      // r{ x -> t } = s  implies  ite( x=t ^ C, s, r ) ---> r
493
      // We can use ordinary substitute since the result of the substitution
494
      // is not being returned. In other words, nn is only being used to query
495
      // whether the second branch is a generalization of the first.
496
      Node nn =
497
41988
          t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
498
20994
      if (nn != t2)
499
      {
500
3196
        nn = d_rew.rewrite(nn);
501
3196
        if (nn == t1)
502
        {
503
5
          new_ret = t2;
504
5
          ss_reason << "ITE rev subs";
505
        }
506
      }
507
508
      // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
509
      // must use partial substitute here, to avoid substitution into witness
510
41988
      std::map<Kind, bool> rkinds;
511
20994
      nn = partialSubstitute(t1, vars, subs, rkinds);
512
20994
      nn = d_rew.rewrite(nn);
513
20994
      if (nn != t1)
514
      {
515
        // If full=false, then we've duplicated a term u in the children of n.
516
        // For example, when ITE pulling, we have n is of the form:
517
        //   ite( C, f( u, t1 ), f( u, t2 ) )
518
        // We must show that at least one copy of u dissappears in this case.
519
1776
        if (nn == t2)
520
        {
521
14
          new_ret = nn;
522
14
          ss_reason << "ITE subs invariant";
523
        }
524
1762
        else if (full || nn.isConst())
525
        {
526
801
          new_ret = nm->mkNode(itek, n[0], nn, t2);
527
801
          ss_reason << "ITE subs";
528
        }
529
      }
530
    }
531
20994
    if (new_ret.isNull())
532
    {
533
      // ite( C, t, s ) ----> ite( C, t, s { C -> false } )
534
      // use partial substitute to avoid substitution into witness
535
40348
      std::map<Node, Node> assign;
536
20174
      assign[n[0]] = d_false;
537
40348
      std::map<Kind, bool> rkinds;
538
40348
      Node nn = partialSubstitute(t2, assign, rkinds);
539
20174
      if (nn != t2)
540
      {
541
450
        nn = d_rew.rewrite(nn);
542
450
        if (nn == t1)
543
        {
544
3
          new_ret = nn;
545
3
          ss_reason << "ITE subs invariant false";
546
        }
547
447
        else if (full || nn.isConst())
548
        {
549
194
          new_ret = nm->mkNode(itek, n[0], t1, nn);
550
194
          ss_reason << "ITE subs false";
551
        }
552
      }
553
    }
554
  }
555
556
  // only print debug trace if full=true
557
21269
  if (!new_ret.isNull() && full)
558
  {
559
1116
    debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
560
  }
561
562
21269
  return new_ret;
563
}
564
565
16366
Node ExtendedRewriter::extendedRewriteAndOr(Node n) const
566
{
567
  // all the below rewrites are aggressive
568
16366
  if (!d_aggr)
569
  {
570
18
    return Node::null();
571
  }
572
32696
  Node new_ret;
573
  // we allow substitutions to recurse over any kind, except WITNESS which is
574
  // managed by partialSubstitute.
575
32696
  std::map<Kind, bool> bcp_kinds;
576
16348
  new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, n);
577
16348
  if (!new_ret.isNull())
578
  {
579
1651
    debugExtendedRewrite(n, new_ret, "Bool bcp");
580
1651
    return new_ret;
581
  }
582
  // factoring
583
14697
  new_ret = extendedRewriteFactoring(AND, OR, NOT, n);
584
14697
  if (!new_ret.isNull())
585
  {
586
416
    debugExtendedRewrite(n, new_ret, "Bool factoring");
587
416
    return new_ret;
588
  }
589
590
  // equality resolution
591
14281
  new_ret = extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, n, false);
592
14281
  debugExtendedRewrite(n, new_ret, "Bool eq res");
593
14281
  return new_ret;
594
}
595
596
117439
Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) const
597
{
598
117439
  Assert(n.getKind() != ITE);
599
117439
  if (n.isClosure())
600
  {
601
    // don't pull ITE out of quantifiers
602
177
    return n;
603
  }
604
117262
  NodeManager* nm = NodeManager::currentNM();
605
234524
  TypeNode tn = n.getType();
606
234524
  std::vector<Node> children;
607
117262
  bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
608
117262
  if (hasOp)
609
  {
610
2986
    children.push_back(n.getOperator());
611
  }
612
117262
  unsigned nchildren = n.getNumChildren();
613
337381
  for (unsigned i = 0; i < nchildren; i++)
614
  {
615
220119
    children.push_back(n[i]);
616
  }
617
234524
  std::map<unsigned, std::map<unsigned, Node> > ite_c;
618
321184
  for (unsigned i = 0; i < nchildren; i++)
619
  {
620
    // only pull ITEs apart if we are aggressive
621
646221
    if (n[i].getKind() == itek
622
646221
        && (d_aggr || (n[i][1].getKind() != ITE && n[i][2].getKind() != ITE)))
623
    {
624
21285
      unsigned ii = hasOp ? i + 1 : i;
625
63855
      for (unsigned j = 0; j < 2; j++)
626
      {
627
42570
        children[ii] = n[i][j + 1];
628
85140
        Node pull = nm->mkNode(n.getKind(), children);
629
85140
        Node pullr = d_rew.rewrite(pull);
630
42570
        children[ii] = n[i];
631
42570
        ite_c[i][j] = pullr;
632
      }
633
21285
      if (ite_c[i][0] == ite_c[i][1])
634
      {
635
        // ITE dual invariance
636
        // f( t1..s1..tn ) ---> t  and  f( t1..s2..tn ) ---> t implies
637
        // f( t1..ite( A, s1, s2 )..tn ) ---> t
638
351
        debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
639
351
        return ite_c[i][0];
640
      }
641
20934
      if (d_aggr)
642
      {
643
40808
        if (nchildren == 2 && (n[1 - i].isVar() || n[1 - i].isConst())
644
33983
            && !n[1 - i].getType().isBoolean() && tn.isBoolean())
645
        {
646
          // always pull variable or constant with binary (theory) predicate
647
          // e.g. P( x, ite( A, t1, t2 ) ) ---> ite( A, P( x, t1 ), P( x, t2 ) )
648
17734
          Node new_ret = nm->mkNode(ITE, n[i][0], ite_c[i][0], ite_c[i][1]);
649
8867
          debugExtendedRewrite(n, new_ret, "ITE pull var predicate");
650
8867
          return new_ret;
651
        }
652
32702
        for (unsigned j = 0; j < 2; j++)
653
        {
654
43541
          Node pullr = ite_c[i][j];
655
22904
          if (pullr.isConst() || pullr == n[i][j + 1])
656
          {
657
            // ITE single child elimination
658
            // f( t1..s1..tn ) ---> t  where t is a constant or s1 itself
659
            // implies
660
            // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
661
4534
            Node new_ret;
662
2267
            if (tn.isBoolean() && pullr.isConst())
663
            {
664
              // remove false/true child immediately
665
826
              bool pol = pullr.getConst<bool>();
666
1652
              std::vector<Node> new_children;
667
826
              new_children.push_back((j == 0) == pol ? n[i][0]
668
                                                     : n[i][0].negate());
669
826
              new_children.push_back(ite_c[i][1 - j]);
670
826
              new_ret = nm->mkNode(pol ? OR : AND, new_children);
671
826
              debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
672
            }
673
            else
674
            {
675
1441
              new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
676
1441
              debugExtendedRewrite(n, new_ret, "ITE single elim");
677
            }
678
2267
            return new_ret;
679
          }
680
        }
681
      }
682
    }
683
  }
684
105777
  if (d_aggr)
685
  {
686
114902
    for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
687
    {
688
19216
      Node nite = n[ip.first];
689
9775
      Assert(nite.getKind() == itek);
690
      // now, simply pull the ITE and try ITE rewrites
691
19216
      Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
692
9775
      pull_ite = d_rew.rewrite(pull_ite);
693
9775
      if (pull_ite.getKind() == ITE)
694
      {
695
19054
        Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
696
9613
        if (!new_pull_ite.isNull())
697
        {
698
172
          debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
699
172
          return new_pull_ite;
700
        }
701
      }
702
      else
703
      {
704
        // A general rewrite could eliminate the ITE by pulling.
705
        // An example is:
706
        //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
707
        //   ite( C, ~~x, ite( C, y, x ) ) --->
708
        //   x
709
        // where ~ is bitvector negation.
710
162
        debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
711
162
        return pull_ite;
712
      }
713
    }
714
  }
715
716
105443
  return Node::null();
717
}
718
719
12155
Node ExtendedRewriter::extendedRewriteNnf(Node ret) const
720
{
721
12155
  Assert(ret.getKind() == NOT);
722
723
12155
  Kind nk = ret[0].getKind();
724
12155
  bool neg_ch = false;
725
12155
  bool neg_ch_1 = false;
726
12155
  if (nk == AND || nk == OR)
727
  {
728
2400
    neg_ch = true;
729
2400
    nk = nk == AND ? OR : AND;
730
  }
731
9755
  else if (nk == IMPLIES)
732
  {
733
    neg_ch = true;
734
    neg_ch_1 = true;
735
    nk = AND;
736
  }
737
9755
  else if (nk == ITE)
738
  {
739
40
    neg_ch = true;
740
40
    neg_ch_1 = true;
741
  }
742
9715
  else if (nk == XOR)
743
  {
744
9
    nk = EQUAL;
745
  }
746
9706
  else if (nk == EQUAL && ret[0][0].getType().isBoolean())
747
  {
748
176
    neg_ch_1 = true;
749
  }
750
  else
751
  {
752
9530
    return Node::null();
753
  }
754
755
5250
  std::vector<Node> new_children;
756
11949
  for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++)
757
  {
758
18648
    Node c = ret[0][i];
759
9324
    c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c;
760
9324
    new_children.push_back(c);
761
  }
762
2625
  return NodeManager::currentNM()->mkNode(nk, new_children);
763
}
764
765
16348
Node ExtendedRewriter::extendedRewriteBcp(Kind andk,
766
                                          Kind ork,
767
                                          Kind notk,
768
                                          std::map<Kind, bool>& bcp_kinds,
769
                                          Node ret) const
770
{
771
16348
  Kind k = ret.getKind();
772
16348
  Assert(k == andk || k == ork);
773
16348
  Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;
774
775
16348
  NodeManager* nm = NodeManager::currentNM();
776
777
32696
  TypeNode tn = ret.getType();
778
32696
  Node truen = TermUtil::mkTypeMaxValue(tn);
779
32696
  Node falsen = TermUtil::mkTypeValue(tn, 0);
780
781
  // terms to process
782
32696
  std::vector<Node> to_process;
783
70062
  for (const Node& cn : ret)
784
  {
785
53714
    to_process.push_back(cn);
786
  }
787
  // the processing terms
788
32696
  std::vector<Node> clauses;
789
  // the terms we have propagated information to
790
32696
  std::unordered_set<Node> prop_clauses;
791
  // the assignment
792
32696
  std::map<Node, Node> assign;
793
32696
  std::vector<Node> avars;
794
32696
  std::vector<Node> asubs;
795
796
16348
  Kind ok = k == andk ? ork : andk;
797
  // global polarity : when k=ork, everything is negated
798
16348
  bool gpol = k == andk;
799
800
1505
  do
801
  {
802
    // process the current nodes
803
53517
    while (!to_process.empty())
804
    {
805
35919
      std::vector<Node> new_to_process;
806
73473
      for (const Node& cn : to_process)
807
      {
808
55641
        Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
809
55641
        Kind cnk = cn.getKind();
810
55641
        bool pol = cnk != notk;
811
111027
        Node cln = cnk == notk ? cn[0] : cn;
812
55641
        Assert(cln.getKind() != notk);
813
55641
        if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
814
        {
815
          // flatten
816
721
          for (const Node& ccln : cln)
817
          {
818
974
            Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
819
487
            new_to_process.push_back(lccln);
820
          }
821
        }
822
        else
823
        {
824
          // add it to the assignment
825
110559
          Node val = gpol == pol ? truen : falsen;
826
55407
          std::map<Node, Node>::iterator it = assign.find(cln);
827
110814
          Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
828
55407
                               << std::endl;
829
55407
          if (it != assign.end())
830
          {
831
264
            if (val != it->second)
832
            {
833
255
              Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
834
              // a conflicting assignment: we are done
835
255
              return gpol ? falsen : truen;
836
            }
837
          }
838
          else
839
          {
840
55143
            assign[cln] = val;
841
55143
            avars.push_back(cln);
842
55143
            asubs.push_back(val);
843
          }
844
845
          // also, treat it as clause if possible
846
110304
          if (cln.getNumChildren() > 0
847
91178
              && (bcp_kinds.empty()
848
55152
                  || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
849
          {
850
72052
            if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
851
36026
                && prop_clauses.find(cn) == prop_clauses.end())
852
            {
853
36020
              Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
854
36020
              clauses.push_back(cn);
855
            }
856
          }
857
        }
858
      }
859
17832
      to_process.clear();
860
17832
      to_process.insert(
861
35664
          to_process.end(), new_to_process.begin(), new_to_process.end());
862
    }
863
864
    // apply substitution to all subterms of clauses
865
35196
    std::vector<Node> new_clauses;
866
55671
    for (const Node& c : clauses)
867
    {
868
38073
      bool cpol = c.getKind() != notk;
869
76146
      Node ca = c.getKind() == notk ? c[0] : c;
870
38073
      bool childChanged = false;
871
76146
      std::vector<Node> ccs_children;
872
127641
      for (const Node& cc : ca)
873
      {
874
        // always use partial substitute, to avoid substitution in witness
875
89568
        Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
876
        // substitution is only applicable to compatible kinds in bcp_kinds
877
179136
        Node ccs = partialSubstitute(cc, assign, bcp_kinds);
878
89568
        childChanged = childChanged || ccs != cc;
879
89568
        ccs_children.push_back(ccs);
880
      }
881
38073
      if (childChanged)
882
      {
883
1606
        if (ca.getMetaKind() == metakind::PARAMETERIZED)
884
        {
885
          ccs_children.insert(ccs_children.begin(), ca.getOperator());
886
        }
887
3212
        Node ccs = nm->mkNode(ca.getKind(), ccs_children);
888
1606
        ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
889
3212
        Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
890
1606
                             << std::endl;
891
1606
        ccs = d_rew.rewrite(ccs);
892
1606
        Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
893
1606
        to_process.push_back(ccs);
894
        // store this as a node that propagation touched. This marks c so that
895
        // it will not be included in the final construction.
896
1606
        prop_clauses.insert(ca);
897
      }
898
      else
899
      {
900
36467
        new_clauses.push_back(c);
901
      }
902
    }
903
17598
    clauses.clear();
904
17598
    clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
905
17598
  } while (!to_process.empty());
906
907
  // remake the node
908
16093
  if (!prop_clauses.empty())
909
  {
910
2792
    std::vector<Node> children;
911
8107
    for (std::pair<const Node, Node>& l : assign)
912
    {
913
13422
      Node a = l.first;
914
      // if propagation did not touch a
915
6711
      if (prop_clauses.find(a) == prop_clauses.end())
916
      {
917
5126
        Assert(l.second == truen || l.second == falsen);
918
10252
        Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
919
5126
        children.push_back(ln);
920
      }
921
    }
922
2792
    Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
923
1396
    Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
924
1396
    return new_ret;
925
  }
926
927
14697
  return Node::null();
928
}
929
930
14697
Node ExtendedRewriter::extendedRewriteFactoring(Kind andk,
931
                                                Kind ork,
932
                                                Kind notk,
933
                                                Node n) const
934
{
935
14697
  Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl;
936
14697
  NodeManager* nm = NodeManager::currentNM();
937
938
14697
  Kind nk = n.getKind();
939
14697
  Assert(nk == andk || nk == ork);
940
14697
  Kind onk = nk == andk ? ork : andk;
941
  // count the number of times atoms occur
942
29394
  std::map<Node, std::vector<Node> > lit_to_cl;
943
29394
  std::map<Node, std::vector<Node> > cl_to_lits;
944
62583
  for (const Node& nc : n)
945
  {
946
47886
    Kind nck = nc.getKind();
947
47886
    if (nck == onk)
948
    {
949
34005
      for (const Node& ncl : nc)
950
      {
951
78708
        if (std::find(lit_to_cl[ncl].begin(), lit_to_cl[ncl].end(), nc)
952
78708
            == lit_to_cl[ncl].end())
953
        {
954
26236
          lit_to_cl[ncl].push_back(nc);
955
26236
          cl_to_lits[nc].push_back(ncl);
956
        }
957
      }
958
    }
959
    else
960
    {
961
40117
      lit_to_cl[nc].push_back(nc);
962
40117
      cl_to_lits[nc].push_back(nc);
963
    }
964
  }
965
  // get the maximum shared literal to factor
966
14697
  unsigned max_size = 0;
967
29394
  Node flit;
968
75565
  for (const std::pair<const Node, std::vector<Node> >& ltc : lit_to_cl)
969
  {
970
60868
    if (ltc.second.size() > max_size)
971
    {
972
14923
      max_size = ltc.second.size();
973
14923
      flit = ltc.first;
974
    }
975
  }
976
14697
  if (max_size > 1)
977
  {
978
    // do the factoring
979
832
    std::vector<Node> children;
980
832
    std::vector<Node> fchildren;
981
416
    std::map<Node, std::vector<Node> >::iterator itl = lit_to_cl.find(flit);
982
416
    std::vector<Node>& cls = itl->second;
983
2116
    for (const Node& nc : n)
984
    {
985
1700
      if (std::find(cls.begin(), cls.end(), nc) == cls.end())
986
      {
987
353
        children.push_back(nc);
988
      }
989
      else
990
      {
991
        // rebuild
992
1347
        std::vector<Node>& lits = cl_to_lits[nc];
993
        std::vector<Node>::iterator itlfl =
994
1347
            std::find(lits.begin(), lits.end(), flit);
995
1347
        Assert(itlfl != lits.end());
996
1347
        lits.erase(itlfl);
997
        // rebuild
998
1347
        if (!lits.empty())
999
        {
1000
2694
          Node new_cl = lits.size() == 1 ? lits[0] : nm->mkNode(onk, lits);
1001
1347
          fchildren.push_back(new_cl);
1002
        }
1003
      }
1004
    }
1005
    // rebuild the factored children
1006
416
    Assert(!fchildren.empty());
1007
832
    Node fcn = fchildren.size() == 1 ? fchildren[0] : nm->mkNode(nk, fchildren);
1008
416
    children.push_back(nm->mkNode(onk, flit, fcn));
1009
832
    Node ret = children.size() == 1 ? children[0] : nm->mkNode(nk, children);
1010
416
    Trace("ext-rew-factoring") << "Factoring: *** OUTPUT: " << ret << std::endl;
1011
416
    return ret;
1012
  }
1013
  else
1014
  {
1015
14281
    Trace("ext-rew-factoring") << "Factoring: no change" << std::endl;
1016
  }
1017
14281
  return Node::null();
1018
}
1019
1020
14281
Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
1021
                                            Kind ork,
1022
                                            Kind eqk,
1023
                                            Kind notk,
1024
                                            std::map<Kind, bool>& bcp_kinds,
1025
                                            Node n,
1026
                                            bool isXor) const
1027
{
1028
14281
  Assert(n.getKind() == andk || n.getKind() == ork);
1029
14281
  Trace("ext-rew-eqres") << "Eq res: **** INPUT: " << n << std::endl;
1030
1031
14281
  NodeManager* nm = NodeManager::currentNM();
1032
14281
  Kind nk = n.getKind();
1033
14281
  bool gpol = (nk == andk);
1034
57309
  for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
1035
  {
1036
86667
    Node lit = n[i];
1037
43639
    if (lit.getKind() == eqk)
1038
    {
1039
      // eq is the equality we are basing a substitution on
1040
13425
      Node eq;
1041
7018
      if (gpol == isXor)
1042
      {
1043
        // can only turn disequality into equality if types are the same
1044
2338
        if (lit[1].getType() == lit.getType())
1045
        {
1046
          // t != s ---> ~t = s
1047
36
          if (lit[1].getKind() == notk && lit[0].getKind() != notk)
1048
          {
1049
28
            eq = nm->mkNode(EQUAL, lit[0], TermUtil::mkNegate(notk, lit[1]));
1050
          }
1051
          else
1052
          {
1053
8
            eq = nm->mkNode(EQUAL, TermUtil::mkNegate(notk, lit[0]), lit[1]);
1054
          }
1055
        }
1056
      }
1057
      else
1058
      {
1059
4680
        eq = eqk == EQUAL ? lit : nm->mkNode(EQUAL, lit[0], lit[1]);
1060
      }
1061
7018
      if (!eq.isNull())
1062
      {
1063
        // see if it corresponds to a substitution
1064
8821
        std::vector<Node> vars;
1065
8821
        std::vector<Node> subs;
1066
4716
        if (inferSubstitution(eq, vars, subs))
1067
        {
1068
3595
          Assert(vars.size() == 1);
1069
6579
          std::vector<Node> children;
1070
3595
          bool childrenChanged = false;
1071
          // apply to all other children
1072
34355
          for (unsigned j = 0; j < nchild; j++)
1073
          {
1074
61520
            Node ccs = n[j];
1075
30760
            if (i != j)
1076
            {
1077
              // Substitution is only applicable to compatible kinds. We always
1078
              // use the partialSubstitute method to avoid substitution into
1079
              // witness terms.
1080
27165
              ccs = partialSubstitute(ccs, vars, subs, bcp_kinds);
1081
27165
              childrenChanged = childrenChanged || n[j] != ccs;
1082
            }
1083
30760
            children.push_back(ccs);
1084
          }
1085
3595
          if (childrenChanged)
1086
          {
1087
611
            return nm->mkNode(nk, children);
1088
          }
1089
        }
1090
      }
1091
    }
1092
  }
1093
1094
13670
  return Node::null();
1095
}
1096
1097
/** sort pairs by their second (unsigned) argument */
1098
33436
static bool sortPairSecond(const std::pair<Node, unsigned>& a,
1099
                           const std::pair<Node, unsigned>& b)
1100
{
1101
33436
  return (a.second < b.second);
1102
}
1103
1104
/** A simple subsumption trie used to compute pairwise list subsets */
1105
101368
class SimpSubsumeTrie
1106
{
1107
 public:
1108
  /** the children of this node */
1109
  std::map<Node, SimpSubsumeTrie> d_children;
1110
  /** the term at this node */
1111
  Node d_data;
1112
  /** add term to the trie
1113
   *
1114
   * This adds term c to this trie, whose atom list is alist. This adds terms
1115
   * s to subsumes such that the atom list of s is a subset of the atom list
1116
   * of c. For example, say:
1117
   *   c1.alist = { A }
1118
   *   c2.alist = { C }
1119
   *   c3.alist = { B, C }
1120
   *   c4.alist = { A, B, D }
1121
   *   c5.alist = { A, B, C }
1122
   * If these terms are added in the order c1, c2, c3, c4, c5, then:
1123
   *   addTerm c1 results in subsumes = {}
1124
   *   addTerm c2 results in subsumes = {}
1125
   *   addTerm c3 results in subsumes = { c2 }
1126
   *   addTerm c4 results in subsumes = { c1 }
1127
   *   addTerm c5 results in subsumes = { c1, c2, c3 }
1128
   * Notice that the intended use case of this trie is to add term t before t'
1129
   * only when size( t.alist ) <= size( t'.alist ).
1130
   *
1131
   * The last two arguments describe the state of the path [t0...tn] we
1132
   * have followed in the trie during the recursive call.
1133
   * If doAdd = true,
1134
   *   then n+1 = index and alist[1]...alist[n] = t1...tn. If index=alist.size()
1135
   *   we add c as the current node of this trie.
1136
   * If doAdd = false,
1137
   *   then t1...tn occur in alist.
1138
   */
1139
67386
  void addTerm(Node c,
1140
               std::vector<Node>& alist,
1141
               std::vector<Node>& subsumes,
1142
               unsigned index = 0,
1143
               bool doAdd = true)
1144
  {
1145
67386
    if (!d_data.isNull())
1146
    {
1147
12
      subsumes.push_back(d_data);
1148
    }
1149
67386
    if (doAdd)
1150
    {
1151
67368
      if (index == alist.size())
1152
      {
1153
33320
        d_data = c;
1154
33320
        return;
1155
      }
1156
    }
1157
    // try all children where we have this atom
1158
50873
    for (std::pair<const Node, SimpSubsumeTrie>& cp : d_children)
1159
    {
1160
16807
      if (std::find(alist.begin(), alist.end(), cp.first) != alist.end())
1161
      {
1162
18
        cp.second.addTerm(c, alist, subsumes, 0, false);
1163
      }
1164
    }
1165
34066
    if (doAdd)
1166
    {
1167
34048
      d_children[alist[index]].addTerm(c, alist, subsumes, index + 1, doAdd);
1168
    }
1169
  }
1170
};
1171
1172
16747
Node ExtendedRewriter::extendedRewriteEqChain(
1173
    Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor) const
1174
{
1175
16747
  Assert(ret.getKind() == eqk);
1176
1177
  // this rewrite is aggressive; it in fact has the precondition that other
1178
  // aggressive rewrites (including BCP) have been applied.
1179
16747
  if (!d_aggr)
1180
  {
1181
96
    return Node::null();
1182
  }
1183
1184
16651
  NodeManager* nm = NodeManager::currentNM();
1185
1186
33302
  TypeNode tn = ret[0].getType();
1187
1188
  // sort/cancelling for Boolean EQUAL/XOR-chains
1189
16651
  Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;
1190
1191
  // get the children on either side
1192
16651
  bool gpol = true;
1193
33302
  std::vector<Node> children;
1194
49953
  for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++)
1195
  {
1196
66604
    Node curr = ret[r];
1197
    // assume, if necessary, right associative
1198
33558
    while (curr.getKind() == eqk && curr[0].getType() == tn)
1199
    {
1200
128
      children.push_back(curr[0]);
1201
128
      curr = curr[1];
1202
    }
1203
33302
    children.push_back(curr);
1204
  }
1205
1206
33302
  std::map<Node, bool> cstatus;
1207
  // add children to status
1208
50081
  for (const Node& c : children)
1209
  {
1210
66860
    Node a = c;
1211
33430
    if (a.getKind() == notk)
1212
    {
1213
440
      gpol = !gpol;
1214
440
      a = a[0];
1215
    }
1216
33430
    Trace("ext-rew-eqchain") << "...child : " << a << std::endl;
1217
33430
    std::map<Node, bool>::iterator itc = cstatus.find(a);
1218
33430
    if (itc == cstatus.end())
1219
    {
1220
33398
      cstatus[a] = true;
1221
    }
1222
    else
1223
    {
1224
      // cancels
1225
32
      cstatus.erase(a);
1226
32
      if (isXor)
1227
      {
1228
        gpol = !gpol;
1229
      }
1230
    }
1231
  }
1232
16651
  Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;
1233
1234
16651
  if (cstatus.empty())
1235
  {
1236
2
    return TermUtil::mkTypeConst(tn, gpol);
1237
  }
1238
1239
16649
  children.clear();
1240
1241
  // compute the atoms of each child
1242
16649
  Trace("ext-rew-eqchain") << "eqchain-simplify: begin\n";
1243
16649
  Trace("ext-rew-eqchain") << "  eqchain-simplify: get atoms...\n";
1244
33298
  std::map<Node, std::map<Node, bool> > atoms;
1245
33298
  std::map<Node, std::vector<Node> > alist;
1246
33298
  std::vector<std::pair<Node, unsigned> > atom_count;
1247
50015
  for (std::pair<const Node, bool>& cp : cstatus)
1248
  {
1249
33366
    if (!cp.second)
1250
    {
1251
      // already eliminated
1252
      continue;
1253
    }
1254
66732
    Node c = cp.first;
1255
33366
    Kind ck = c.getKind();
1256
33366
    Trace("ext-rew-eqchain") << "  process c = " << c << std::endl;
1257
33366
    if (ck == andk || ck == ork)
1258
    {
1259
1397
      for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++)
1260
      {
1261
2147
        Node cl = c[j];
1262
1085
        bool pol = cl.getKind() != notk;
1263
2147
        Node ca = pol ? cl : cl[0];
1264
1085
        bool newVal = (ck == andk ? !pol : pol);
1265
2170
        Trace("ext-rew-eqchain")
1266
1085
            << "  atoms(" << c << ", " << ca << ") = " << newVal << std::endl;
1267
1085
        Assert(atoms[c].find(ca) == atoms[c].end());
1268
        // polarity is flipped when we are AND
1269
1085
        atoms[c][ca] = newVal;
1270
1085
        alist[c].push_back(ca);
1271
1272
        // if this already exists as a child of the equality chain, eliminate.
1273
        // this catches cases like ( x & y ) = ( ( x & y ) | z ), where we
1274
        // consider ( x & y ) a unit, whereas below it is expanded to
1275
        // ~( ~x | ~y ).
1276
1085
        std::map<Node, bool>::iterator itc = cstatus.find(ca);
1277
1085
        if (itc != cstatus.end() && itc->second)
1278
        {
1279
          // cancel it
1280
23
          cstatus[ca] = false;
1281
23
          cstatus[c] = false;
1282
          // make new child
1283
          // x = ( y | ~x ) ---> y & x
1284
          // x = ( y | x ) ---> ~y | x
1285
          // x = ( y & x ) ---> y | ~x
1286
          // x = ( y & ~x ) ---> ~y & ~x
1287
46
          std::vector<Node> new_children;
1288
83
          for (unsigned k = 0, nchildc = c.getNumChildren(); k < nchildc; k++)
1289
          {
1290
60
            if (j != k)
1291
            {
1292
37
              new_children.push_back(c[k]);
1293
            }
1294
          }
1295
46
          Node nc[2];
1296
23
          nc[0] = c[j];
1297
23
          nc[1] = new_children.size() == 1 ? new_children[0]
1298
                                           : nm->mkNode(ck, new_children);
1299
          // negate the proper child
1300
23
          unsigned nindex = (ck == andk) == pol ? 0 : 1;
1301
23
          nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]);
1302
23
          Kind nk = pol ? ork : andk;
1303
          // store as new child
1304
23
          children.push_back(nm->mkNode(nk, nc[0], nc[1]));
1305
23
          if (isXor)
1306
          {
1307
            gpol = !gpol;
1308
          }
1309
23
          break;
1310
        }
1311
335
      }
1312
    }
1313
    else
1314
    {
1315
33031
      bool pol = ck != notk;
1316
66062
      Node ca = pol ? c : c[0];
1317
33031
      atoms[c][ca] = pol;
1318
66062
      Trace("ext-rew-eqchain")
1319
33031
          << "  atoms(" << c << ", " << ca << ") = " << pol << std::endl;
1320
33031
      alist[c].push_back(ca);
1321
    }
1322
33366
    atom_count.push_back(std::pair<Node, unsigned>(c, alist[c].size()));
1323
  }
1324
  // sort the atoms in each atom list
1325
50015
  for (std::map<Node, std::vector<Node> >::iterator it = alist.begin();
1326
50015
       it != alist.end();
1327
       ++it)
1328
  {
1329
33366
    std::sort(it->second.begin(), it->second.end());
1330
  }
1331
  // check subsumptions
1332
  // sort by #atoms
1333
16649
  std::sort(atom_count.begin(), atom_count.end(), sortPairSecond);
1334
16649
  if (Trace.isOn("ext-rew-eqchain"))
1335
  {
1336
    for (const std::pair<Node, unsigned>& ac : atom_count)
1337
    {
1338
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << ac.first << " has "
1339
                               << ac.second << " atoms." << std::endl;
1340
    }
1341
    Trace("ext-rew-eqchain") << "  eqchain-simplify: compute subsumptions...\n";
1342
  }
1343
33298
  SimpSubsumeTrie sst;
1344
50015
  for (std::pair<const Node, bool>& cp : cstatus)
1345
  {
1346
33366
    if (!cp.second)
1347
    {
1348
      // already eliminated
1349
46
      continue;
1350
    }
1351
66640
    Node c = cp.first;
1352
33320
    std::map<Node, std::map<Node, bool> >::iterator itc = atoms.find(c);
1353
33320
    Assert(itc != atoms.end());
1354
66640
    Trace("ext-rew-eqchain") << "  - add term " << c << " with atom list "
1355
33320
                             << alist[c] << "...\n";
1356
66640
    std::vector<Node> subsumes;
1357
33320
    sst.addTerm(c, alist[c], subsumes);
1358
33320
    for (const Node& cc : subsumes)
1359
    {
1360
4
      if (!cstatus[cc])
1361
      {
1362
        // subsumes a child that was already eliminated
1363
        continue;
1364
      }
1365
8
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << c << " subsumes "
1366
4
                               << cc << std::endl;
1367
      // for each of the atoms in cc
1368
4
      std::map<Node, std::map<Node, bool> >::iterator itcc = atoms.find(cc);
1369
4
      Assert(itcc != atoms.end());
1370
4
      std::vector<Node> common_children;
1371
4
      std::vector<Node> diff_children;
1372
12
      for (const std::pair<const Node, bool>& ap : itcc->second)
1373
      {
1374
        // compare the polarity
1375
16
        Node a = ap.first;
1376
8
        bool polcc = ap.second;
1377
8
        Assert(itc->second.find(a) != itc->second.end());
1378
8
        bool polc = itc->second[a];
1379
16
        Trace("ext-rew-eqchain") << "    eqchain-simplify: atom " << a
1380
8
                                 << " has polarities : " << polc << " " << polcc
1381
8
                                 << "\n";
1382
16
        Node lit = polc ? a : TermUtil::mkNegate(notk, a);
1383
8
        if (polc != polcc)
1384
        {
1385
2
          diff_children.push_back(lit);
1386
        }
1387
        else
1388
        {
1389
6
          common_children.push_back(lit);
1390
        }
1391
      }
1392
4
      std::vector<Node> rem_children;
1393
12
      for (const std::pair<const Node, bool>& ap : itc->second)
1394
      {
1395
16
        Node a = ap.first;
1396
8
        if (atoms[cc].find(a) == atoms[cc].end())
1397
        {
1398
          bool polc = ap.second;
1399
          rem_children.push_back(polc ? a : TermUtil::mkNegate(notk, a));
1400
        }
1401
      }
1402
8
      Trace("ext-rew-eqchain")
1403
8
          << "    #common/diff/rem: " << common_children.size() << "/"
1404
4
          << diff_children.size() << "/" << rem_children.size() << "\n";
1405
4
      bool do_rewrite = false;
1406
9
      if (common_children.empty() && itc->second.size() == itcc->second.size()
1407
5
          && itcc->second.size() == 2)
1408
      {
1409
        // x | y = ~x | ~y ---> ~( x = y )
1410
1
        do_rewrite = true;
1411
1
        children.push_back(diff_children[0]);
1412
1
        children.push_back(diff_children[1]);
1413
1
        gpol = !gpol;
1414
1
        Trace("ext-rew-eqchain") << "    apply 2-child all-diff\n";
1415
      }
1416
3
      else if (common_children.empty() && diff_children.size() == 1)
1417
      {
1418
        do_rewrite = true;
1419
        // x = ( ~x | y ) ---> ~( ~x | ~y )
1420
        Node remn = rem_children.size() == 1 ? rem_children[0]
1421
                                             : nm->mkNode(ork, rem_children);
1422
        remn = TermUtil::mkNegate(notk, remn);
1423
        children.push_back(nm->mkNode(ork, diff_children[0], remn));
1424
        if (!isXor)
1425
        {
1426
          gpol = !gpol;
1427
        }
1428
        Trace("ext-rew-eqchain") << "    apply unit resolution\n";
1429
      }
1430
6
      else if (diff_children.size() == 1
1431
3
               && itc->second.size() == itcc->second.size())
1432
      {
1433
        // ( x | y | z ) = ( x | ~y | z ) ---> ( x | z )
1434
        do_rewrite = true;
1435
        Assert(!common_children.empty());
1436
        Node comn = common_children.size() == 1
1437
                        ? common_children[0]
1438
                        : nm->mkNode(ork, common_children);
1439
        children.push_back(comn);
1440
        if (isXor)
1441
        {
1442
          gpol = !gpol;
1443
        }
1444
        Trace("ext-rew-eqchain") << "    apply resolution\n";
1445
      }
1446
3
      else if (diff_children.empty())
1447
      {
1448
3
        do_rewrite = true;
1449
3
        if (rem_children.empty())
1450
        {
1451
          // x | y = x | y ---> true
1452
          // this can happen if we have ( ~x & ~y ) = ( x | y )
1453
3
          children.push_back(TermUtil::mkTypeMaxValue(tn));
1454
3
          if (isXor)
1455
          {
1456
            gpol = !gpol;
1457
          }
1458
3
          Trace("ext-rew-eqchain") << "    apply cancel\n";
1459
        }
1460
        else
1461
        {
1462
          // x | y = ( x | y | z ) ---> ( x | y | ~z )
1463
          Node remn = rem_children.size() == 1 ? rem_children[0]
1464
                                               : nm->mkNode(ork, rem_children);
1465
          remn = TermUtil::mkNegate(notk, remn);
1466
          Node comn = common_children.size() == 1
1467
                          ? common_children[0]
1468
                          : nm->mkNode(ork, common_children);
1469
          children.push_back(nm->mkNode(ork, comn, remn));
1470
          if (isXor)
1471
          {
1472
            gpol = !gpol;
1473
          }
1474
          Trace("ext-rew-eqchain") << "    apply subsume\n";
1475
        }
1476
      }
1477
4
      if (do_rewrite)
1478
      {
1479
        // eliminate the children, reverse polarity as needed
1480
12
        for (unsigned r = 0; r < 2; r++)
1481
        {
1482
16
          Node c_rem = r == 0 ? c : cc;
1483
8
          cstatus[c_rem] = false;
1484
8
          if (c_rem.getKind() == andk)
1485
          {
1486
5
            gpol = !gpol;
1487
          }
1488
        }
1489
4
        break;
1490
      }
1491
    }
1492
  }
1493
16649
  Trace("ext-rew-eqchain") << "eqchain-simplify: finish" << std::endl;
1494
1495
  // sorted right associative chain
1496
16649
  bool has_nvar = false;
1497
16649
  unsigned nvar_index = 0;
1498
50015
  for (std::pair<const Node, bool>& cp : cstatus)
1499
  {
1500
33366
    if (cp.second)
1501
    {
1502
33312
      if (!cp.first.isVar())
1503
      {
1504
22101
        has_nvar = true;
1505
22101
        nvar_index = children.size();
1506
      }
1507
33312
      children.push_back(cp.first);
1508
    }
1509
  }
1510
16649
  std::sort(children.begin(), children.end());
1511
1512
33298
  Node new_ret;
1513
16649
  if (!gpol)
1514
  {
1515
    // negate the constant child if it exists
1516
264
    unsigned nindex = has_nvar ? nvar_index : 0;
1517
264
    children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
1518
  }
1519
16649
  new_ret = children.back();
1520
16649
  unsigned index = children.size() - 1;
1521
50031
  while (index > 0)
1522
  {
1523
16691
    index--;
1524
16691
    new_ret = nm->mkNode(eqk, children[index], new_ret);
1525
  }
1526
16649
  new_ret = d_rew.rewrite(new_ret);
1527
16649
  if (new_ret != ret)
1528
  {
1529
269
    return new_ret;
1530
  }
1531
16380
  return Node::null();
1532
}
1533
1534
157901
Node ExtendedRewriter::partialSubstitute(
1535
    Node n,
1536
    const std::map<Node, Node>& assign,
1537
    const std::map<Kind, bool>& rkinds) const
1538
{
1539
315802
  std::unordered_map<TNode, Node> visited;
1540
157901
  std::unordered_map<TNode, Node>::iterator it;
1541
157901
  std::map<Node, Node>::const_iterator ita;
1542
315802
  std::vector<TNode> visit;
1543
315802
  TNode cur;
1544
157901
  visit.push_back(n);
1545
1527975
  do
1546
  {
1547
1685876
    cur = visit.back();
1548
1685876
    visit.pop_back();
1549
1685876
    it = visited.find(cur);
1550
1551
1685876
    if (it == visited.end())
1552
    {
1553
726499
      ita = assign.find(cur);
1554
726499
      if (ita != assign.end())
1555
      {
1556
5721
        visited[cur] = ita->second;
1557
      }
1558
      else
1559
      {
1560
        // If rkinds is non-empty, then can only recurse on its kinds.
1561
        // We also always disallow substitution into witness. Notice that
1562
        // we disallow witness here, due to unsoundness when applying contextual
1563
        // substitutions over witness terms (see #4620).
1564
720778
        Kind k = cur.getKind();
1565
720778
        if (k != WITNESS && (rkinds.empty() || rkinds.find(k) != rkinds.end()))
1566
        {
1567
720778
          visited[cur] = Node::null();
1568
720778
          visit.push_back(cur);
1569
1527975
          for (const Node& cn : cur)
1570
          {
1571
807197
            visit.push_back(cn);
1572
          }
1573
        }
1574
        else
1575
        {
1576
          visited[cur] = cur;
1577
        }
1578
      }
1579
    }
1580
959377
    else if (it->second.isNull())
1581
    {
1582
1441556
      Node ret = cur;
1583
720778
      bool childChanged = false;
1584
1441556
      std::vector<Node> children;
1585
720778
      if (cur.getMetaKind() == metakind::PARAMETERIZED)
1586
      {
1587
14810
        children.push_back(cur.getOperator());
1588
      }
1589
1527975
      for (const Node& cn : cur)
1590
      {
1591
807197
        it = visited.find(cn);
1592
807197
        Assert(it != visited.end());
1593
807197
        Assert(!it->second.isNull());
1594
807197
        childChanged = childChanged || cn != it->second;
1595
807197
        children.push_back(it->second);
1596
      }
1597
720778
      if (childChanged)
1598
      {
1599
12280
        ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
1600
      }
1601
720778
      visited[cur] = ret;
1602
    }
1603
1685876
  } while (!visit.empty());
1604
157901
  Assert(visited.find(n) != visited.end());
1605
157901
  Assert(!visited.find(n)->second.isNull());
1606
315802
  return visited[n];
1607
}
1608
1609
48159
Node ExtendedRewriter::partialSubstitute(
1610
    Node n,
1611
    const std::vector<Node>& vars,
1612
    const std::vector<Node>& subs,
1613
    const std::map<Kind, bool>& rkinds) const
1614
{
1615
48159
  Assert(vars.size() == subs.size());
1616
96318
  std::map<Node, Node> assign;
1617
98830
  for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
1618
  {
1619
50671
    assign[vars[i]] = subs[i];
1620
  }
1621
96318
  return partialSubstitute(n, assign, rkinds);
1622
}
1623
1624
12798
Node ExtendedRewriter::solveEquality(Node n) const
1625
{
1626
  // TODO (#1706) : implement
1627
12798
  Assert(n.getKind() == EQUAL);
1628
1629
12798
  return Node::null();
1630
}
1631
1632
29186
bool ExtendedRewriter::inferSubstitution(Node n,
1633
                                         std::vector<Node>& vars,
1634
                                         std::vector<Node>& subs,
1635
                                         bool usePred) const
1636
{
1637
29186
  if (n.getKind() == AND)
1638
  {
1639
964
    bool ret = false;
1640
4440
    for (const Node& nc : n)
1641
    {
1642
3476
      bool cret = inferSubstitution(nc, vars, subs, usePred);
1643
3476
      ret = ret || cret;
1644
    }
1645
964
    return ret;
1646
  }
1647
28222
  if (n.getKind() == EQUAL)
1648
  {
1649
    // see if it can be put into form x = y
1650
14054
    Node slv_eq = solveEquality(n);
1651
12798
    if (!slv_eq.isNull())
1652
    {
1653
      n = slv_eq;
1654
    }
1655
38394
    Node v[2];
1656
29350
    for (unsigned i = 0; i < 2; i++)
1657
    {
1658
25509
      if (n[i].isConst())
1659
      {
1660
8957
        vars.push_back(n[1 - i]);
1661
8957
        subs.push_back(n[i]);
1662
8957
        return true;
1663
      }
1664
16552
      if (n[i].isVar())
1665
      {
1666
12678
        v[i] = n[i];
1667
      }
1668
3874
      else if (TermUtil::isNegate(n[i].getKind()) && n[i][0].isVar())
1669
      {
1670
22
        v[i] = n[i][0];
1671
      }
1672
    }
1673
6405
    for (unsigned i = 0; i < 2; i++)
1674
    {
1675
7713
      TNode r1 = v[i];
1676
7713
      Node r2 = v[1 - i];
1677
5149
      if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
1678
      {
1679
2585
        r2 = n[1 - i];
1680
2585
        if (v[i] != n[i])
1681
        {
1682
13
          Assert(TermUtil::isNegate(n[i].getKind()));
1683
13
          r2 = TermUtil::mkNegate(n[i].getKind(), r2);
1684
        }
1685
        // TODO (#1706) : union find
1686
2585
        if (std::find(vars.begin(), vars.end(), r1) == vars.end())
1687
        {
1688
2585
          vars.push_back(r1);
1689
2585
          subs.push_back(r2);
1690
2585
          return true;
1691
        }
1692
      }
1693
    }
1694
  }
1695
16680
  if (usePred)
1696
  {
1697
15559
    bool negated = n.getKind() == NOT;
1698
15559
    vars.push_back(negated ? n[0] : n);
1699
15559
    subs.push_back(negated ? d_false : d_true);
1700
15559
    return true;
1701
  }
1702
1121
  return false;
1703
}
1704
1705
4131
Node ExtendedRewriter::extendedRewriteStrings(Node ret) const
1706
{
1707
4131
  Node new_ret;
1708
8262
  Trace("q-ext-rewrite-debug")
1709
4131
      << "Extended rewrite strings : " << ret << std::endl;
1710
1711
4131
  if (ret.getKind() == EQUAL)
1712
  {
1713
456
    strings::SequencesRewriter sr(&d_rew, nullptr);
1714
228
    new_ret = sr.rewriteEqualityExt(ret);
1715
  }
1716
1717
4131
  return new_ret;
1718
}
1719
1720
58460
void ExtendedRewriter::debugExtendedRewrite(Node n,
1721
                                            Node ret,
1722
                                            const char* c) const
1723
{
1724
58460
  if (Trace.isOn("q-ext-rewrite"))
1725
  {
1726
    if (!ret.isNull())
1727
    {
1728
      Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
1729
      Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
1730
                             << " rewrites to " << ret << std::endl;
1731
    }
1732
  }
1733
58460
}
1734
1735
}  // namespace quantifiers
1736
}  // namespace theory
1737
22746
}  // namespace cvc5