GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/extended_rewrite.cpp Lines: 803 863 93.0 %
Date: 2021-11-07 Branches: 1963 3928 50.0 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of extended rewriting techniques.
14
 */
15
16
#include "theory/quantifiers/extended_rewrite.h"
17
18
#include <sstream>
19
20
#include "theory/arith/arith_msum.h"
21
#include "theory/bv/theory_bv_utils.h"
22
#include "theory/datatypes/datatypes_rewriter.h"
23
#include "theory/quantifiers/term_util.h"
24
#include "theory/rewriter.h"
25
#include "theory/strings/arith_entail.h"
26
#include "theory/strings/sequences_rewriter.h"
27
#include "theory/strings/word.h"
28
#include "theory/theory.h"
29
30
using namespace cvc5::kind;
31
using namespace std;
32
33
namespace cvc5 {
34
namespace theory {
35
namespace quantifiers {
36
37
struct ExtRewriteAttributeId
38
{
39
};
40
typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
41
42
struct ExtRewriteAggAttributeId
43
{
44
};
45
typedef expr::Attribute<ExtRewriteAggAttributeId, Node> ExtRewriteAggAttribute;
46
47
483454
ExtendedRewriter::ExtendedRewriter(Rewriter& rew, bool aggr)
48
483454
    : d_rew(rew), d_aggr(aggr)
49
{
50
483454
  d_true = NodeManager::currentNM()->mkConst(true);
51
483454
  d_false = NodeManager::currentNM()->mkConst(false);
52
483454
  d_zero = NodeManager::currentNM()->mkConst(Rational(0));
53
483454
}
54
55
698521
void ExtendedRewriter::setCache(Node n, Node ret) const
56
{
57
698521
  if (d_aggr)
58
  {
59
    ExtRewriteAggAttribute erga;
60
697873
    n.setAttribute(erga, ret);
61
  }
62
  else
63
  {
64
    ExtRewriteAttribute era;
65
648
    n.setAttribute(era, ret);
66
  }
67
698521
}
68
69
1441293
Node ExtendedRewriter::getCache(Node n) const
70
{
71
1441293
  if (d_aggr)
72
  {
73
1440781
    if (n.hasAttribute(ExtRewriteAggAttribute()))
74
    {
75
1086649
      return n.getAttribute(ExtRewriteAggAttribute());
76
    }
77
  }
78
  else
79
  {
80
512
    if (n.hasAttribute(ExtRewriteAttribute()))
81
    {
82
188
      return n.getAttribute(ExtRewriteAttribute());
83
    }
84
  }
85
354456
  return Node::null();
86
}
87
88
922253
bool ExtendedRewriter::addToChildren(Node nc,
89
                                     std::vector<Node>& children,
90
                                     bool dropDup) const
91
{
92
  // If the operator is non-additive, do not consider duplicates
93
922253
  if (dropDup
94
922253
      && std::find(children.begin(), children.end(), nc) != children.end())
95
  {
96
2525
    return false;
97
  }
98
919728
  children.push_back(nc);
99
919728
  return true;
100
}
101
102
1441293
Node ExtendedRewriter::extendedRewrite(Node n) const
103
{
104
1441293
  n = d_rew.rewrite(n);
105
106
  // has it already been computed?
107
2882586
  Node ncache = getCache(n);
108
1441293
  if (!ncache.isNull())
109
  {
110
1086837
    return ncache;
111
  }
112
113
708912
  Node ret = n;
114
354456
  NodeManager* nm = NodeManager::currentNM();
115
116
  //--------------------pre-rewrite
117
354456
  if (d_aggr)
118
  {
119
697873
    Node pre_new_ret;
120
354132
    if (ret.getKind() == IMPLIES)
121
    {
122
1783
      pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
123
1783
      debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
124
    }
125
352349
    else if (ret.getKind() == XOR)
126
    {
127
429
      pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
128
429
      debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
129
    }
130
351920
    else if (ret.getKind() == NOT)
131
    {
132
33746
      pre_new_ret = extendedRewriteNnf(ret);
133
33746
      debugExtendedRewrite(ret, pre_new_ret, "NNF");
134
    }
135
354132
    if (!pre_new_ret.isNull())
136
    {
137
10391
      ret = extendedRewrite(pre_new_ret);
138
139
20782
      Trace("q-ext-rewrite-debug")
140
10391
          << "...ext-pre-rewrite : " << n << " -> " << pre_new_ret << std::endl;
141
10391
      setCache(n, ret);
142
10391
      return ret;
143
    }
144
  }
145
  //--------------------end pre-rewrite
146
147
  //--------------------rewrite children
148
344065
  if (n.getNumChildren() > 0)
149
  {
150
629458
    std::vector<Node> children;
151
314729
    if (n.getMetaKind() == metakind::PARAMETERIZED)
152
    {
153
9380
      children.push_back(n.getOperator());
154
    }
155
314729
    Kind k = n.getKind();
156
314729
    bool childChanged = false;
157
314729
    bool isNonAdditive = TermUtil::isNonAdditive(k);
158
    // We flatten associative operators below, which requires k to be n-ary.
159
314729
    bool isAssoc = TermUtil::isAssoc(k, true);
160
1221004
    for (unsigned i = 0; i < n.getNumChildren(); i++)
161
    {
162
1812550
      Node nc = extendedRewrite(n[i]);
163
906275
      childChanged = nc != n[i] || childChanged;
164
906275
      if (isAssoc && nc.getKind() == n.getKind())
165
      {
166
26400
        for (const Node& ncc : nc)
167
        {
168
21189
          if (!addToChildren(ncc, children, isNonAdditive))
169
          {
170
47
            childChanged = true;
171
          }
172
        }
173
      }
174
901064
      else if (!addToChildren(nc, children, isNonAdditive))
175
      {
176
2478
        childChanged = true;
177
      }
178
    }
179
314729
    Assert(!children.empty());
180
    // Some commutative operators have rewriters that are agnostic to order,
181
    // thus, we sort here.
182
314729
    if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
183
    {
184
184173
      childChanged = true;
185
184173
      std::sort(children.begin(), children.end());
186
    }
187
314729
    if (childChanged)
188
    {
189
218702
      if (isNonAdditive && children.size() == 1)
190
      {
191
        // we may have subsumed children down to one
192
85
        ret = children[0];
193
      }
194
218617
      else if (isAssoc
195
218617
               && children.size() > kind::metakind::getMaxArityForKind(k))
196
      {
197
2
        Assert(kind::metakind::getMaxArityForKind(k) >= 2);
198
        // kind may require binary construction
199
2
        ret = children[0];
200
6
        for (unsigned i = 1, nchild = children.size(); i < nchild; i++)
201
        {
202
4
          ret = nm->mkNode(k, ret, children[i]);
203
        }
204
      }
205
      else
206
      {
207
218615
        ret = nm->mkNode(k, children);
208
      }
209
    }
210
  }
211
344065
  ret = d_rew.rewrite(ret);
212
  //--------------------end rewrite children
213
214
  // now, do extended rewrite
215
688130
  Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
216
344065
                               << " (from " << n << ")" << std::endl;
217
688130
  Node new_ret;
218
219
  //---------------------- theory-independent post-rewriting
220
344065
  if (ret.getKind() == ITE)
221
  {
222
26967
    new_ret = extendedRewriteIte(ITE, ret);
223
  }
224
317098
  else if (ret.getKind() == AND || ret.getKind() == OR)
225
  {
226
80107
    new_ret = extendedRewriteAndOr(ret);
227
  }
228
236991
  else if (ret.getKind() == EQUAL)
229
  {
230
41789
    new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
231
41789
    debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
232
  }
233
344065
  Assert(new_ret.isNull() || new_ret != ret);
234
344065
  if (new_ret.isNull() && ret.getKind() != ITE)
235
  {
236
    // simple ITE pulling
237
303666
    new_ret = extendedRewritePullIte(ITE, ret);
238
  }
239
  //----------------------end theory-independent post-rewriting
240
241
  //----------------------theory-specific post-rewriting
242
344065
  if (new_ret.isNull())
243
  {
244
    TheoryId tid;
245
304343
    if (ret.getKind() == ITE)
246
    {
247
24358
      tid = Theory::theoryOf(ret.getType());
248
    }
249
    else
250
    {
251
279985
      tid = Theory::theoryOf(ret);
252
    }
253
608686
    Trace("q-ext-rewrite-debug") << "theoryOf( " << ret << " )= " << tid
254
304343
                                 << std::endl;
255
304343
    if (tid == THEORY_STRINGS)
256
    {
257
7835
      new_ret = extendedRewriteStrings(ret);
258
    }
259
  }
260
  //----------------------end theory-specific post-rewriting
261
262
  //----------------------aggressive rewrites
263
344065
  if (new_ret.isNull() && d_aggr)
264
  {
265
303657
    new_ret = extendedRewriteAggr(ret);
266
  }
267
  //----------------------end aggressive rewrites
268
269
344065
  setCache(n, ret);
270
344065
  if (!new_ret.isNull())
271
  {
272
41173
    ret = extendedRewrite(new_ret);
273
  }
274
688130
  Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
275
344065
                               << std::endl;
276
344065
  if (Trace.isOn("q-ext-rewrite-nf"))
277
  {
278
    if (n == ret)
279
    {
280
      Trace("q-ext-rewrite-nf") << "ext-rew normal form : " << n << std::endl;
281
    }
282
  }
283
344065
  setCache(n, ret);
284
344065
  return ret;
285
}
286
287
303657
Node ExtendedRewriter::extendedRewriteAggr(Node n) const
288
{
289
303657
  Node new_ret;
290
607314
  Trace("q-ext-rewrite-debug2")
291
303657
      << "Do aggressive rewrites on " << n << std::endl;
292
303657
  bool polarity = n.getKind() != NOT;
293
607314
  Node ret_atom = n.getKind() == NOT ? n[0] : n;
294
643527
  if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
295
892546
      || ret_atom.getKind() == GEQ)
296
  {
297
    // ITE term removal in polynomials
298
    // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 )
299
94648
    Trace("q-ext-rewrite-debug2")
300
47324
        << "Compute monomial sum " << ret_atom << std::endl;
301
    // compute monomial sum
302
94648
    std::map<Node, Node> msum;
303
47324
    if (ArithMSum::getMonomialSumLit(ret_atom, msum))
304
    {
305
154123
      for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end();
306
           ++itm)
307
      {
308
214681
        Node v = itm->first;
309
215764
        Trace("q-ext-rewrite-debug2")
310
107882
            << itm->first << " * " << itm->second << std::endl;
311
107882
        if (v.getKind() == ITE)
312
        {
313
14501
          Node veq;
314
7792
          int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind());
315
7792
          if (res != 0)
316
          {
317
15578
            Trace("q-ext-rewrite-debug")
318
7789
                << "  have ITE relation, solved form : " << veq << std::endl;
319
            // try pulling ITE
320
7789
            new_ret = extendedRewritePullIte(ITE, veq);
321
7789
            if (!new_ret.isNull())
322
            {
323
1083
              if (!polarity)
324
              {
325
                new_ret = new_ret.negate();
326
              }
327
1083
              break;
328
            }
329
          }
330
          else
331
          {
332
6
            Trace("q-ext-rewrite-debug")
333
3
                << "  failed to isolate " << v << " in " << n << std::endl;
334
          }
335
        }
336
      }
337
    }
338
    else
339
    {
340
      Trace("q-ext-rewrite-debug")
341
          << "  failed to get monomial sum of " << n << std::endl;
342
    }
343
  }
344
  // TODO (#1706) : conditional rewriting, condition merging
345
607314
  return new_ret;
346
}
347
348
55987
Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const
349
{
350
55987
  Assert(n.getKind() == itek);
351
55987
  Assert(n[1] != n[2]);
352
353
55987
  NodeManager* nm = NodeManager::currentNM();
354
355
55987
  Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;
356
357
111974
  Node flip_cond;
358
55987
  if (n[0].getKind() == NOT)
359
  {
360
    flip_cond = n[0][0];
361
  }
362
55987
  else if (n[0].getKind() == OR)
363
  {
364
    // a | b ---> ~( ~a & ~b )
365
317
    flip_cond = TermUtil::simpleNegate(n[0]);
366
  }
367
55987
  if (!flip_cond.isNull())
368
  {
369
634
    Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
370
    // only print debug trace if full=true
371
317
    if (full)
372
    {
373
317
      debugExtendedRewrite(n, new_ret, "ITE flip");
374
    }
375
317
    return new_ret;
376
  }
377
  // Boolean true/false return
378
111340
  TypeNode tn = n.getType();
379
55670
  if (tn.isBoolean())
380
  {
381
88242
    for (unsigned i = 1; i <= 2; i++)
382
    {
383
58828
      if (n[i].isConst())
384
      {
385
        Node cond = i == 1 ? n[0] : n[0].negate();
386
        Node other = n[i == 1 ? 2 : 1];
387
        Kind retk = AND;
388
        if (n[i].getConst<bool>())
389
        {
390
          retk = OR;
391
        }
392
        else
393
        {
394
          cond = cond.negate();
395
        }
396
        Node new_ret = nm->mkNode(retk, cond, other);
397
        if (full)
398
        {
399
          // ite( A, true, B ) ---> A V B
400
          // ite( A, false, B ) ---> ~A /\ B
401
          // ite( A, B,  true ) ---> ~A V B
402
          // ite( A, B, false ) ---> A /\ B
403
          debugExtendedRewrite(n, new_ret, "ITE const return");
404
        }
405
        return new_ret;
406
      }
407
    }
408
  }
409
410
  // get entailed equalities in the condition
411
111340
  std::vector<Node> eq_conds;
412
55670
  Kind ck = n[0].getKind();
413
55670
  if (ck == EQUAL)
414
  {
415
16409
    eq_conds.push_back(n[0]);
416
  }
417
39261
  else if (ck == AND)
418
  {
419
43059
    for (const Node& cn : n[0])
420
    {
421
33543
      if (cn.getKind() == EQUAL)
422
      {
423
2002
        eq_conds.push_back(cn);
424
      }
425
    }
426
  }
427
428
111340
  Node new_ret;
429
111340
  Node b;
430
111340
  Node e;
431
111340
  Node t1 = n[1];
432
111340
  Node t2 = n[2];
433
111340
  std::stringstream ss_reason;
434
435
73982
  for (const Node& eq : eq_conds)
436
  {
437
    // simple invariant ITE
438
55079
    for (unsigned i = 0; i <= 1; i++)
439
    {
440
      // ite( x = y ^ C, y, x ) ---> x
441
      // this is subsumed by the rewrites below
442
36767
      if (t2 == eq[i] && t1 == eq[1 - i])
443
      {
444
97
        new_ret = t2;
445
97
        ss_reason << "ITE simple rev subs";
446
97
        break;
447
      }
448
    }
449
18409
    if (!new_ret.isNull())
450
    {
451
97
      break;
452
    }
453
  }
454
55670
  if (new_ret.isNull())
455
  {
456
    // merging branches
457
165884
    for (unsigned i = 1; i <= 2; i++)
458
    {
459
110875
      if (n[i].getKind() == ITE)
460
      {
461
10272
        Node no = n[3 - i];
462
14598
        for (unsigned j = 1; j <= 2; j++)
463
        {
464
10026
          if (n[i][j] == no)
465
          {
466
            // e.g.
467
            // ite( C1, ite( C2, t1, t2 ), t1 ) ----> ite( C1 ^ ~C2, t2, t1 )
468
1128
            Node nc1 = i == 2 ? n[0].negate() : n[0];
469
1128
            Node nc2 = j == 1 ? n[i][0].negate() : n[i][0];
470
1128
            Node new_cond = nm->mkNode(AND, nc1, nc2);
471
564
            new_ret = nm->mkNode(ITE, new_cond, n[i][3 - j], no);
472
564
            ss_reason << "ITE merge branch";
473
564
            break;
474
          }
475
        }
476
      }
477
110875
      if (!new_ret.isNull())
478
      {
479
564
        break;
480
      }
481
    }
482
  }
483
484
55670
  if (new_ret.isNull() && d_aggr)
485
  {
486
    // If x is less than t based on an ordering, then we use { x -> t } as a
487
    // substitution to the children of ite( x = t ^ C, s, t ) below.
488
110010
    std::vector<Node> vars;
489
110010
    std::vector<Node> subs;
490
55005
    inferSubstitution(n[0], vars, subs, true);
491
492
55005
    if (!vars.empty())
493
    {
494
      // reverse substitution to opposite child
495
      // r{ x -> t } = s  implies  ite( x=t ^ C, s, r ) ---> r
496
      // We can use ordinary substitute since the result of the substitution
497
      // is not being returned. In other words, nn is only being used to query
498
      // whether the second branch is a generalization of the first.
499
      Node nn =
500
110010
          t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
501
55005
      if (nn != t2)
502
      {
503
11934
        nn = d_rew.rewrite(nn);
504
11934
        if (nn == t1)
505
        {
506
32
          new_ret = t2;
507
32
          ss_reason << "ITE rev subs";
508
        }
509
      }
510
511
      // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
512
      // must use partial substitute here, to avoid substitution into witness
513
110010
      std::map<Kind, bool> rkinds;
514
55005
      nn = partialSubstitute(t1, vars, subs, rkinds);
515
55005
      nn = d_rew.rewrite(nn);
516
55005
      if (nn != t1)
517
      {
518
        // If full=false, then we've duplicated a term u in the children of n.
519
        // For example, when ITE pulling, we have n is of the form:
520
        //   ite( C, f( u, t1 ), f( u, t2 ) )
521
        // We must show that at least one copy of u dissappears in this case.
522
8364
        if (nn == t2)
523
        {
524
53
          new_ret = nn;
525
53
          ss_reason << "ITE subs invariant";
526
        }
527
8311
        else if (full || nn.isConst())
528
        {
529
1564
          new_ret = nm->mkNode(itek, n[0], nn, t2);
530
1564
          ss_reason << "ITE subs";
531
        }
532
      }
533
    }
534
55005
    if (new_ret.isNull())
535
    {
536
      // ite( C, t, s ) ----> ite( C, t, s { C -> false } )
537
      // use partial substitute to avoid substitution into witness
538
106712
      std::map<Node, Node> assign;
539
53356
      assign[n[0]] = d_false;
540
106712
      std::map<Kind, bool> rkinds;
541
106712
      Node nn = partialSubstitute(t2, assign, rkinds);
542
53356
      if (nn != t2)
543
      {
544
5666
        nn = d_rew.rewrite(nn);
545
5666
        if (nn == t1)
546
        {
547
18
          new_ret = nn;
548
18
          ss_reason << "ITE subs invariant false";
549
        }
550
5648
        else if (full || nn.isConst())
551
        {
552
484
          new_ret = nm->mkNode(itek, n[0], t1, nn);
553
484
          ss_reason << "ITE subs false";
554
        }
555
      }
556
    }
557
  }
558
559
  // only print debug trace if full=true
560
55670
  if (!new_ret.isNull() && full)
561
  {
562
2292
    debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
563
  }
564
565
55670
  return new_ret;
566
}
567
568
80107
Node ExtendedRewriter::extendedRewriteAndOr(Node n) const
569
{
570
  // all the below rewrites are aggressive
571
80107
  if (!d_aggr)
572
  {
573
18
    return Node::null();
574
  }
575
160178
  Node new_ret;
576
  // we allow substitutions to recurse over any kind, except WITNESS which is
577
  // managed by partialSubstitute.
578
160178
  std::map<Kind, bool> bcp_kinds;
579
80089
  new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, n);
580
80089
  if (!new_ret.isNull())
581
  {
582
3871
    debugExtendedRewrite(n, new_ret, "Bool bcp");
583
3871
    return new_ret;
584
  }
585
  // factoring
586
76218
  new_ret = extendedRewriteFactoring(AND, OR, NOT, n);
587
76218
  if (!new_ret.isNull())
588
  {
589
5103
    debugExtendedRewrite(n, new_ret, "Bool factoring");
590
5103
    return new_ret;
591
  }
592
593
  // equality resolution
594
71115
  new_ret = extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, n, false);
595
71115
  debugExtendedRewrite(n, new_ret, "Bool eq res");
596
71115
  return new_ret;
597
}
598
599
311455
Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) const
600
{
601
311455
  Assert(n.getKind() != ITE);
602
311455
  if (n.isClosure())
603
  {
604
    // don't pull ITE out of quantifiers
605
749
    return n;
606
  }
607
310706
  NodeManager* nm = NodeManager::currentNM();
608
621412
  TypeNode tn = n.getType();
609
621412
  std::vector<Node> children;
610
310706
  bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
611
310706
  if (hasOp)
612
  {
613
9189
    children.push_back(n.getOperator());
614
  }
615
310706
  unsigned nchildren = n.getNumChildren();
616
1003522
  for (unsigned i = 0; i < nchildren; i++)
617
  {
618
692816
    children.push_back(n[i]);
619
  }
620
621412
  std::map<unsigned, std::map<unsigned, Node> > ite_c;
621
970597
  for (unsigned i = 0; i < nchildren; i++)
622
  {
623
    // only pull ITEs apart if we are aggressive
624
2049648
    if (n[i].getKind() == itek
625
2049648
        && (d_aggr || (n[i][1].getKind() != ITE && n[i][2].getKind() != ITE)))
626
    {
627
54114
      unsigned ii = hasOp ? i + 1 : i;
628
162342
      for (unsigned j = 0; j < 2; j++)
629
      {
630
108228
        children[ii] = n[i][j + 1];
631
216456
        Node pull = nm->mkNode(n.getKind(), children);
632
216456
        Node pullr = d_rew.rewrite(pull);
633
108228
        children[ii] = n[i];
634
108228
        ite_c[i][j] = pullr;
635
      }
636
54114
      if (ite_c[i][0] == ite_c[i][1])
637
      {
638
        // ITE dual invariance
639
        // f( t1..s1..tn ) ---> t  and  f( t1..s2..tn ) ---> t implies
640
        // f( t1..ite( A, s1, s2 )..tn ) ---> t
641
2047
        debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
642
2047
        return ite_c[i][0];
643
      }
644
52067
      if (d_aggr)
645
      {
646
94810
        if (nchildren == 2 && (n[1 - i].isVar() || n[1 - i].isConst())
647
72590
            && !n[1 - i].getType().isBoolean() && tn.isBoolean())
648
        {
649
          // always pull variable or constant with binary (theory) predicate
650
          // e.g. P( x, ite( A, t1, t2 ) ) ---> ite( A, P( x, t1 ), P( x, t2 ) )
651
27288
          Node new_ret = nm->mkNode(ITE, n[i][0], ite_c[i][0], ite_c[i][1]);
652
13644
          debugExtendedRewrite(n, new_ret, "ITE pull var predicate");
653
13644
          return new_ret;
654
        }
655
102255
        for (unsigned j = 0; j < 2; j++)
656
        {
657
135302
          Node pullr = ite_c[i][j];
658
71468
          if (pullr.isConst() || pullr == n[i][j + 1])
659
          {
660
            // ITE single child elimination
661
            // f( t1..s1..tn ) ---> t  where t is a constant or s1 itself
662
            // implies
663
            // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
664
15268
            Node new_ret;
665
7634
            if (tn.isBoolean() && pullr.isConst())
666
            {
667
              // remove false/true child immediately
668
1962
              bool pol = pullr.getConst<bool>();
669
3924
              std::vector<Node> new_children;
670
1962
              new_children.push_back((j == 0) == pol ? n[i][0]
671
                                                     : n[i][0].negate());
672
1962
              new_children.push_back(ite_c[i][1 - j]);
673
1962
              new_ret = nm->mkNode(pol ? OR : AND, new_children);
674
1962
              debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
675
            }
676
            else
677
            {
678
5672
              new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
679
5672
              debugExtendedRewrite(n, new_ret, "ITE single elim");
680
            }
681
7634
            return new_ret;
682
          }
683
        }
684
      }
685
    }
686
  }
687
287381
  if (d_aggr)
688
  {
689
315565
    for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
690
    {
691
57690
      Node nite = n[ip.first];
692
29190
      Assert(nite.getKind() == itek);
693
      // now, simply pull the ITE and try ITE rewrites
694
57690
      Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
695
29190
      pull_ite = d_rew.rewrite(pull_ite);
696
29190
      if (pull_ite.getKind() == ITE)
697
      {
698
57520
        Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
699
29020
        if (!new_pull_ite.isNull())
700
        {
701
520
          debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
702
520
          return new_pull_ite;
703
        }
704
      }
705
      else
706
      {
707
        // A general rewrite could eliminate the ITE by pulling.
708
        // An example is:
709
        //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
710
        //   ite( C, ~~x, ite( C, y, x ) ) --->
711
        //   x
712
        // where ~ is bitvector negation.
713
170
        debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
714
170
        return pull_ite;
715
      }
716
    }
717
  }
718
719
286691
  return Node::null();
720
}
721
722
33746
Node ExtendedRewriter::extendedRewriteNnf(Node ret) const
723
{
724
33746
  Assert(ret.getKind() == NOT);
725
726
33746
  Kind nk = ret[0].getKind();
727
33746
  bool neg_ch = false;
728
33746
  bool neg_ch_1 = false;
729
33746
  if (nk == AND || nk == OR)
730
  {
731
7478
    neg_ch = true;
732
7478
    nk = nk == AND ? OR : AND;
733
  }
734
26268
  else if (nk == IMPLIES)
735
  {
736
58
    neg_ch = true;
737
58
    neg_ch_1 = true;
738
58
    nk = AND;
739
  }
740
26210
  else if (nk == ITE)
741
  {
742
323
    neg_ch = true;
743
323
    neg_ch_1 = true;
744
  }
745
25887
  else if (nk == XOR)
746
  {
747
78
    nk = EQUAL;
748
  }
749
25809
  else if (nk == EQUAL && ret[0][0].getType().isBoolean())
750
  {
751
242
    neg_ch_1 = true;
752
  }
753
  else
754
  {
755
25567
    return Node::null();
756
  }
757
758
16358
  std::vector<Node> new_children;
759
37816
  for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++)
760
  {
761
59274
    Node c = ret[0][i];
762
29637
    c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c;
763
29637
    new_children.push_back(c);
764
  }
765
8179
  return NodeManager::currentNM()->mkNode(nk, new_children);
766
}
767
768
80089
Node ExtendedRewriter::extendedRewriteBcp(Kind andk,
769
                                          Kind ork,
770
                                          Kind notk,
771
                                          std::map<Kind, bool>& bcp_kinds,
772
                                          Node ret) const
773
{
774
80089
  Kind k = ret.getKind();
775
80089
  Assert(k == andk || k == ork);
776
80089
  Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;
777
778
80089
  NodeManager* nm = NodeManager::currentNM();
779
780
160178
  TypeNode tn = ret.getType();
781
160178
  Node truen = TermUtil::mkTypeMaxValue(tn);
782
160178
  Node falsen = TermUtil::mkTypeValue(tn, 0);
783
784
  // terms to process
785
160178
  std::vector<Node> to_process;
786
513798
  for (const Node& cn : ret)
787
  {
788
433709
    to_process.push_back(cn);
789
  }
790
  // the processing terms
791
160178
  std::vector<Node> clauses;
792
  // the terms we have propagated information to
793
160178
  std::unordered_set<Node> prop_clauses;
794
  // the assignment
795
160178
  std::map<Node, Node> assign;
796
160178
  std::vector<Node> avars;
797
160178
  std::vector<Node> asubs;
798
799
80089
  Kind ok = k == andk ? ork : andk;
800
  // global polarity : when k=ork, everything is negated
801
80089
  bool gpol = k == andk;
802
803
3593
  do
804
  {
805
    // process the current nodes
806
252326
    while (!to_process.empty())
807
    {
808
169166
      std::vector<Node> new_to_process;
809
525244
      for (const Node& cn : to_process)
810
      {
811
440922
        Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
812
440922
        Kind cnk = cn.getKind();
813
440922
        bool pol = cnk != notk;
814
881322
        Node cln = cnk == notk ? cn[0] : cn;
815
440922
        Assert(cln.getKind() != notk);
816
440922
        if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
817
        {
818
          // flatten
819
3988
          for (const Node& ccln : cln)
820
          {
821
5550
            Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
822
2775
            new_to_process.push_back(lccln);
823
          }
824
        }
825
        else
826
        {
827
          // add it to the assignment
828
878896
          Node val = gpol == pol ? truen : falsen;
829
439709
          std::map<Node, Node>::iterator it = assign.find(cln);
830
879418
          Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
831
439709
                               << std::endl;
832
439709
          if (it != assign.end())
833
          {
834
596
            if (val != it->second)
835
            {
836
522
              Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
837
              // a conflicting assignment: we are done
838
522
              return gpol ? falsen : truen;
839
            }
840
          }
841
          else
842
          {
843
439113
            assign[cln] = val;
844
439113
            avars.push_back(cln);
845
439113
            asubs.push_back(val);
846
          }
847
848
          // also, treat it as clause if possible
849
878374
          if (cln.getNumChildren() > 0
850
788161
              && (bcp_kinds.empty()
851
439187
                  || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
852
          {
853
697948
            if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
854
348974
                && prop_clauses.find(cn) == prop_clauses.end())
855
            {
856
348946
              Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
857
348946
              clauses.push_back(cn);
858
            }
859
          }
860
        }
861
      }
862
84322
      to_process.clear();
863
84322
      to_process.insert(
864
168644
          to_process.end(), new_to_process.begin(), new_to_process.end());
865
    }
866
867
    // apply substitution to all subterms of clauses
868
166320
    std::vector<Node> new_clauses;
869
445601
    for (const Node& c : clauses)
870
    {
871
362441
      bool cpol = c.getKind() != notk;
872
724882
      Node ca = c.getKind() == notk ? c[0] : c;
873
362441
      bool childChanged = false;
874
724882
      std::vector<Node> ccs_children;
875
1331433
      for (const Node& cc : ca)
876
      {
877
        // always use partial substitute, to avoid substitution in witness
878
968992
        Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
879
        // substitution is only applicable to compatible kinds in bcp_kinds
880
1937984
        Node ccs = partialSubstitute(cc, assign, bcp_kinds);
881
968992
        childChanged = childChanged || ccs != cc;
882
968992
        ccs_children.push_back(ccs);
883
      }
884
362441
      if (childChanged)
885
      {
886
4914
        if (ca.getMetaKind() == metakind::PARAMETERIZED)
887
        {
888
4
          ccs_children.insert(ccs_children.begin(), ca.getOperator());
889
        }
890
9828
        Node ccs = nm->mkNode(ca.getKind(), ccs_children);
891
4914
        ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
892
9828
        Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
893
4914
                             << std::endl;
894
4914
        ccs = d_rew.rewrite(ccs);
895
4914
        Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
896
4914
        to_process.push_back(ccs);
897
        // store this as a node that propagation touched. This marks c so that
898
        // it will not be included in the final construction.
899
4914
        prop_clauses.insert(ca);
900
      }
901
      else
902
      {
903
357527
        new_clauses.push_back(c);
904
      }
905
    }
906
83160
    clauses.clear();
907
83160
    clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
908
83160
  } while (!to_process.empty());
909
910
  // remake the node
911
79567
  if (!prop_clauses.empty())
912
  {
913
6698
    std::vector<Node> children;
914
28618
    for (std::pair<const Node, Node>& l : assign)
915
    {
916
50538
      Node a = l.first;
917
      // if propagation did not touch a
918
25269
      if (prop_clauses.find(a) == prop_clauses.end())
919
      {
920
20377
        Assert(l.second == truen || l.second == falsen);
921
40754
        Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
922
20377
        children.push_back(ln);
923
      }
924
    }
925
6698
    Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
926
3349
    Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
927
3349
    return new_ret;
928
  }
929
930
76218
  return Node::null();
931
}
932
933
76218
Node ExtendedRewriter::extendedRewriteFactoring(Kind andk,
934
                                                Kind ork,
935
                                                Kind notk,
936
                                                Node n) const
937
{
938
76218
  Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl;
939
76218
  NodeManager* nm = NodeManager::currentNM();
940
941
76218
  Kind nk = n.getKind();
942
76218
  Assert(nk == andk || nk == ork);
943
76218
  Kind onk = nk == andk ? ork : andk;
944
  // count the number of times atoms occur
945
152436
  std::map<Node, std::vector<Node> > lit_to_cl;
946
152436
  std::map<Node, std::vector<Node> > cl_to_lits;
947
488187
  for (const Node& nc : n)
948
  {
949
411969
    Kind nck = nc.getKind();
950
411969
    if (nck == onk)
951
    {
952
535334
      for (const Node& ncl : nc)
953
      {
954
1294878
        if (std::find(lit_to_cl[ncl].begin(), lit_to_cl[ncl].end(), nc)
955
1294878
            == lit_to_cl[ncl].end())
956
        {
957
431626
          lit_to_cl[ncl].push_back(nc);
958
431626
          cl_to_lits[nc].push_back(ncl);
959
        }
960
      }
961
    }
962
    else
963
    {
964
308261
      lit_to_cl[nc].push_back(nc);
965
308261
      cl_to_lits[nc].push_back(nc);
966
    }
967
  }
968
  // get the maximum shared literal to factor
969
76218
  unsigned max_size = 0;
970
152436
  Node flit;
971
647849
  for (const std::pair<const Node, std::vector<Node> >& ltc : lit_to_cl)
972
  {
973
571631
    if (ltc.second.size() > max_size)
974
    {
975
81166
      max_size = ltc.second.size();
976
81166
      flit = ltc.first;
977
    }
978
  }
979
76218
  if (max_size > 1)
980
  {
981
    // do the factoring
982
10206
    std::vector<Node> children;
983
10206
    std::vector<Node> fchildren;
984
5103
    std::map<Node, std::vector<Node> >::iterator itl = lit_to_cl.find(flit);
985
5103
    std::vector<Node>& cls = itl->second;
986
80416
    for (const Node& nc : n)
987
    {
988
75313
      if (std::find(cls.begin(), cls.end(), nc) == cls.end())
989
      {
990
57724
        children.push_back(nc);
991
      }
992
      else
993
      {
994
        // rebuild
995
17589
        std::vector<Node>& lits = cl_to_lits[nc];
996
        std::vector<Node>::iterator itlfl =
997
17589
            std::find(lits.begin(), lits.end(), flit);
998
17589
        Assert(itlfl != lits.end());
999
17589
        lits.erase(itlfl);
1000
        // rebuild
1001
17589
        if (!lits.empty())
1002
        {
1003
35178
          Node new_cl = lits.size() == 1 ? lits[0] : nm->mkNode(onk, lits);
1004
17589
          fchildren.push_back(new_cl);
1005
        }
1006
      }
1007
    }
1008
    // rebuild the factored children
1009
5103
    Assert(!fchildren.empty());
1010
10206
    Node fcn = fchildren.size() == 1 ? fchildren[0] : nm->mkNode(nk, fchildren);
1011
5103
    children.push_back(nm->mkNode(onk, flit, fcn));
1012
10206
    Node ret = children.size() == 1 ? children[0] : nm->mkNode(nk, children);
1013
5103
    Trace("ext-rew-factoring") << "Factoring: *** OUTPUT: " << ret << std::endl;
1014
5103
    return ret;
1015
  }
1016
  else
1017
  {
1018
71115
    Trace("ext-rew-factoring") << "Factoring: no change" << std::endl;
1019
  }
1020
71115
  return Node::null();
1021
}
1022
1023
71115
Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
1024
                                            Kind ork,
1025
                                            Kind eqk,
1026
                                            Kind notk,
1027
                                            std::map<Kind, bool>& bcp_kinds,
1028
                                            Node n,
1029
                                            bool isXor) const
1030
{
1031
71115
  Assert(n.getKind() == andk || n.getKind() == ork);
1032
71115
  Trace("ext-rew-eqres") << "Eq res: **** INPUT: " << n << std::endl;
1033
1034
71115
  NodeManager* nm = NodeManager::currentNM();
1035
71115
  Kind nk = n.getKind();
1036
71115
  bool gpol = (nk == andk);
1037
373835
  for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
1038
  {
1039
609319
    Node lit = n[i];
1040
306599
    if (lit.getKind() == eqk)
1041
    {
1042
      // eq is the equality we are basing a substitution on
1043
211801
      Node eq;
1044
107840
      if (gpol == isXor)
1045
      {
1046
        // can only turn disequality into equality if types are the same
1047
16824
        if (lit[1].getType() == lit.getType())
1048
        {
1049
          // t != s ---> ~t = s
1050
287
          if (lit[1].getKind() == notk && lit[0].getKind() != notk)
1051
          {
1052
41
            eq = nm->mkNode(EQUAL, lit[0], TermUtil::mkNegate(notk, lit[1]));
1053
          }
1054
          else
1055
          {
1056
246
            eq = nm->mkNode(EQUAL, TermUtil::mkNegate(notk, lit[0]), lit[1]);
1057
          }
1058
        }
1059
      }
1060
      else
1061
      {
1062
91016
        eq = eqk == EQUAL ? lit : nm->mkNode(EQUAL, lit[0], lit[1]);
1063
      }
1064
107840
      if (!eq.isNull())
1065
      {
1066
        // see if it corresponds to a substitution
1067
178727
        std::vector<Node> vars;
1068
178727
        std::vector<Node> subs;
1069
91303
        if (inferSubstitution(eq, vars, subs))
1070
        {
1071
82328
          Assert(vars.size() == 1);
1072
160777
          std::vector<Node> children;
1073
82328
          bool childrenChanged = false;
1074
          // apply to all other children
1075
1431287
          for (unsigned j = 0; j < nchild; j++)
1076
          {
1077
2697918
            Node ccs = n[j];
1078
1348959
            if (i != j)
1079
            {
1080
              // Substitution is only applicable to compatible kinds. We always
1081
              // use the partialSubstitute method to avoid substitution into
1082
              // witness terms.
1083
1266631
              ccs = partialSubstitute(ccs, vars, subs, bcp_kinds);
1084
1266631
              childrenChanged = childrenChanged || n[j] != ccs;
1085
            }
1086
1348959
            children.push_back(ccs);
1087
          }
1088
82328
          if (childrenChanged)
1089
          {
1090
3879
            return nm->mkNode(nk, children);
1091
          }
1092
        }
1093
      }
1094
    }
1095
  }
1096
1097
67236
  return Node::null();
1098
}
1099
1100
/** sort pairs by their second (unsigned) argument */
1101
85113
static bool sortPairSecond(const std::pair<Node, unsigned>& a,
1102
                           const std::pair<Node, unsigned>& b)
1103
{
1104
85113
  return (a.second < b.second);
1105
}
1106
1107
/** A simple subsumption trie used to compute pairwise list subsets */
1108
256196
class SimpSubsumeTrie
1109
{
1110
 public:
1111
  /** the children of this node */
1112
  std::map<Node, SimpSubsumeTrie> d_children;
1113
  /** the term at this node */
1114
  Node d_data;
1115
  /** add term to the trie
1116
   *
1117
   * This adds term c to this trie, whose atom list is alist. This adds terms
1118
   * s to subsumes such that the atom list of s is a subset of the atom list
1119
   * of c. For example, say:
1120
   *   c1.alist = { A }
1121
   *   c2.alist = { C }
1122
   *   c3.alist = { B, C }
1123
   *   c4.alist = { A, B, D }
1124
   *   c5.alist = { A, B, C }
1125
   * If these terms are added in the order c1, c2, c3, c4, c5, then:
1126
   *   addTerm c1 results in subsumes = {}
1127
   *   addTerm c2 results in subsumes = {}
1128
   *   addTerm c3 results in subsumes = { c2 }
1129
   *   addTerm c4 results in subsumes = { c1 }
1130
   *   addTerm c5 results in subsumes = { c1, c2, c3 }
1131
   * Notice that the intended use case of this trie is to add term t before t'
1132
   * only when size( t.alist ) <= size( t'.alist ).
1133
   *
1134
   * The last two arguments describe the state of the path [t0...tn] we
1135
   * have followed in the trie during the recursive call.
1136
   * If doAdd = true,
1137
   *   then n+1 = index and alist[1]...alist[n] = t1...tn. If index=alist.size()
1138
   *   we add c as the current node of this trie.
1139
   * If doAdd = false,
1140
   *   then t1...tn occur in alist.
1141
   */
1142
170507
  void addTerm(Node c,
1143
               std::vector<Node>& alist,
1144
               std::vector<Node>& subsumes,
1145
               unsigned index = 0,
1146
               bool doAdd = true)
1147
  {
1148
170507
    if (!d_data.isNull())
1149
    {
1150
24
      subsumes.push_back(d_data);
1151
    }
1152
170507
    if (doAdd)
1153
    {
1154
170465
      if (index == alist.size())
1155
      {
1156
84026
        d_data = c;
1157
84026
        return;
1158
      }
1159
    }
1160
    // try all children where we have this atom
1161
130218
    for (std::pair<const Node, SimpSubsumeTrie>& cp : d_children)
1162
    {
1163
43737
      if (std::find(alist.begin(), alist.end(), cp.first) != alist.end())
1164
      {
1165
42
        cp.second.addTerm(c, alist, subsumes, 0, false);
1166
      }
1167
    }
1168
86481
    if (doAdd)
1169
    {
1170
86439
      d_children[alist[index]].addTerm(c, alist, subsumes, index + 1, doAdd);
1171
    }
1172
  }
1173
};
1174
1175
41789
Node ExtendedRewriter::extendedRewriteEqChain(
1176
    Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor) const
1177
{
1178
41789
  Assert(ret.getKind() == eqk);
1179
1180
  // this rewrite is aggressive; it in fact has the precondition that other
1181
  // aggressive rewrites (including BCP) have been applied.
1182
41789
  if (!d_aggr)
1183
  {
1184
96
    return Node::null();
1185
  }
1186
1187
41693
  NodeManager* nm = NodeManager::currentNM();
1188
1189
83386
  TypeNode tn = ret[0].getType();
1190
1191
  // sort/cancelling for Boolean EQUAL/XOR-chains
1192
41693
  Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;
1193
1194
  // get the children on either side
1195
41693
  bool gpol = true;
1196
83386
  std::vector<Node> children;
1197
125079
  for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++)
1198
  {
1199
166772
    Node curr = ret[r];
1200
    // assume, if necessary, right associative
1201
85158
    while (curr.getKind() == eqk && curr[0].getType() == tn)
1202
    {
1203
886
      children.push_back(curr[0]);
1204
886
      curr = curr[1];
1205
    }
1206
83386
    children.push_back(curr);
1207
  }
1208
1209
83386
  std::map<Node, bool> cstatus;
1210
  // add children to status
1211
125965
  for (const Node& c : children)
1212
  {
1213
168544
    Node a = c;
1214
84272
    if (a.getKind() == notk)
1215
    {
1216
806
      gpol = !gpol;
1217
806
      a = a[0];
1218
    }
1219
84272
    Trace("ext-rew-eqchain") << "...child : " << a << std::endl;
1220
84272
    std::map<Node, bool>::iterator itc = cstatus.find(a);
1221
84272
    if (itc == cstatus.end())
1222
    {
1223
84223
      cstatus[a] = true;
1224
    }
1225
    else
1226
    {
1227
      // cancels
1228
49
      cstatus.erase(a);
1229
49
      if (isXor)
1230
      {
1231
        gpol = !gpol;
1232
      }
1233
    }
1234
  }
1235
41693
  Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;
1236
1237
41693
  if (cstatus.empty())
1238
  {
1239
2
    return TermUtil::mkTypeConst(tn, gpol);
1240
  }
1241
1242
41691
  children.clear();
1243
1244
  // compute the atoms of each child
1245
41691
  Trace("ext-rew-eqchain") << "eqchain-simplify: begin\n";
1246
41691
  Trace("ext-rew-eqchain") << "  eqchain-simplify: get atoms...\n";
1247
83382
  std::map<Node, std::map<Node, bool> > atoms;
1248
83382
  std::map<Node, std::vector<Node> > alist;
1249
83382
  std::vector<std::pair<Node, unsigned> > atom_count;
1250
125865
  for (std::pair<const Node, bool>& cp : cstatus)
1251
  {
1252
84174
    if (!cp.second)
1253
    {
1254
      // already eliminated
1255
      continue;
1256
    }
1257
168348
    Node c = cp.first;
1258
84174
    Kind ck = c.getKind();
1259
84174
    Trace("ext-rew-eqchain") << "  process c = " << c << std::endl;
1260
84174
    if (ck == andk || ck == ork)
1261
    {
1262
5509
      for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++)
1263
      {
1264
7980
        Node cl = c[j];
1265
4027
        bool pol = cl.getKind() != notk;
1266
7980
        Node ca = pol ? cl : cl[0];
1267
4027
        bool newVal = (ck == andk ? !pol : pol);
1268
8054
        Trace("ext-rew-eqchain")
1269
4027
            << "  atoms(" << c << ", " << ca << ") = " << newVal << std::endl;
1270
4027
        Assert(atoms[c].find(ca) == atoms[c].end());
1271
        // polarity is flipped when we are AND
1272
4027
        atoms[c][ca] = newVal;
1273
4027
        alist[c].push_back(ca);
1274
1275
        // if this already exists as a child of the equality chain, eliminate.
1276
        // this catches cases like ( x & y ) = ( ( x & y ) | z ), where we
1277
        // consider ( x & y ) a unit, whereas below it is expanded to
1278
        // ~( ~x | ~y ).
1279
4027
        std::map<Node, bool>::iterator itc = cstatus.find(ca);
1280
4027
        if (itc != cstatus.end() && itc->second)
1281
        {
1282
          // cancel it
1283
74
          cstatus[ca] = false;
1284
74
          cstatus[c] = false;
1285
          // make new child
1286
          // x = ( y | ~x ) ---> y & x
1287
          // x = ( y | x ) ---> ~y | x
1288
          // x = ( y & x ) ---> y | ~x
1289
          // x = ( y & ~x ) ---> ~y & ~x
1290
148
          std::vector<Node> new_children;
1291
250
          for (unsigned k = 0, nchildc = c.getNumChildren(); k < nchildc; k++)
1292
          {
1293
176
            if (j != k)
1294
            {
1295
102
              new_children.push_back(c[k]);
1296
            }
1297
          }
1298
148
          Node nc[2];
1299
74
          nc[0] = c[j];
1300
74
          nc[1] = new_children.size() == 1 ? new_children[0]
1301
                                           : nm->mkNode(ck, new_children);
1302
          // negate the proper child
1303
74
          unsigned nindex = (ck == andk) == pol ? 0 : 1;
1304
74
          nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]);
1305
74
          Kind nk = pol ? ork : andk;
1306
          // store as new child
1307
74
          children.push_back(nm->mkNode(nk, nc[0], nc[1]));
1308
74
          if (isXor)
1309
          {
1310
            gpol = !gpol;
1311
          }
1312
74
          break;
1313
        }
1314
1556
      }
1315
    }
1316
    else
1317
    {
1318
82618
      bool pol = ck != notk;
1319
165236
      Node ca = pol ? c : c[0];
1320
82618
      atoms[c][ca] = pol;
1321
165236
      Trace("ext-rew-eqchain")
1322
82618
          << "  atoms(" << c << ", " << ca << ") = " << pol << std::endl;
1323
82618
      alist[c].push_back(ca);
1324
    }
1325
84174
    atom_count.push_back(std::pair<Node, unsigned>(c, alist[c].size()));
1326
  }
1327
  // sort the atoms in each atom list
1328
125865
  for (std::map<Node, std::vector<Node> >::iterator it = alist.begin();
1329
125865
       it != alist.end();
1330
       ++it)
1331
  {
1332
84174
    std::sort(it->second.begin(), it->second.end());
1333
  }
1334
  // check subsumptions
1335
  // sort by #atoms
1336
41691
  std::sort(atom_count.begin(), atom_count.end(), sortPairSecond);
1337
41691
  if (Trace.isOn("ext-rew-eqchain"))
1338
  {
1339
    for (const std::pair<Node, unsigned>& ac : atom_count)
1340
    {
1341
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << ac.first << " has "
1342
                               << ac.second << " atoms." << std::endl;
1343
    }
1344
    Trace("ext-rew-eqchain") << "  eqchain-simplify: compute subsumptions...\n";
1345
  }
1346
83382
  SimpSubsumeTrie sst;
1347
125865
  for (std::pair<const Node, bool>& cp : cstatus)
1348
  {
1349
84174
    if (!cp.second)
1350
    {
1351
      // already eliminated
1352
148
      continue;
1353
    }
1354
168052
    Node c = cp.first;
1355
84026
    std::map<Node, std::map<Node, bool> >::iterator itc = atoms.find(c);
1356
84026
    Assert(itc != atoms.end());
1357
168052
    Trace("ext-rew-eqchain") << "  - add term " << c << " with atom list "
1358
84026
                             << alist[c] << "...\n";
1359
168052
    std::vector<Node> subsumes;
1360
84026
    sst.addTerm(c, alist[c], subsumes);
1361
84026
    for (const Node& cc : subsumes)
1362
    {
1363
8
      if (!cstatus[cc])
1364
      {
1365
        // subsumes a child that was already eliminated
1366
        continue;
1367
      }
1368
16
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << c << " subsumes "
1369
8
                               << cc << std::endl;
1370
      // for each of the atoms in cc
1371
8
      std::map<Node, std::map<Node, bool> >::iterator itcc = atoms.find(cc);
1372
8
      Assert(itcc != atoms.end());
1373
8
      std::vector<Node> common_children;
1374
8
      std::vector<Node> diff_children;
1375
24
      for (const std::pair<const Node, bool>& ap : itcc->second)
1376
      {
1377
        // compare the polarity
1378
32
        Node a = ap.first;
1379
16
        bool polcc = ap.second;
1380
16
        Assert(itc->second.find(a) != itc->second.end());
1381
16
        bool polc = itc->second[a];
1382
32
        Trace("ext-rew-eqchain") << "    eqchain-simplify: atom " << a
1383
16
                                 << " has polarities : " << polc << " " << polcc
1384
16
                                 << "\n";
1385
32
        Node lit = polc ? a : TermUtil::mkNegate(notk, a);
1386
16
        if (polc != polcc)
1387
        {
1388
4
          diff_children.push_back(lit);
1389
        }
1390
        else
1391
        {
1392
12
          common_children.push_back(lit);
1393
        }
1394
      }
1395
8
      std::vector<Node> rem_children;
1396
24
      for (const std::pair<const Node, bool>& ap : itc->second)
1397
      {
1398
32
        Node a = ap.first;
1399
16
        if (atoms[cc].find(a) == atoms[cc].end())
1400
        {
1401
          bool polc = ap.second;
1402
          rem_children.push_back(polc ? a : TermUtil::mkNegate(notk, a));
1403
        }
1404
      }
1405
16
      Trace("ext-rew-eqchain")
1406
16
          << "    #common/diff/rem: " << common_children.size() << "/"
1407
8
          << diff_children.size() << "/" << rem_children.size() << "\n";
1408
8
      bool do_rewrite = false;
1409
18
      if (common_children.empty() && itc->second.size() == itcc->second.size()
1410
10
          && itcc->second.size() == 2)
1411
      {
1412
        // x | y = ~x | ~y ---> ~( x = y )
1413
2
        do_rewrite = true;
1414
2
        children.push_back(diff_children[0]);
1415
2
        children.push_back(diff_children[1]);
1416
2
        gpol = !gpol;
1417
2
        Trace("ext-rew-eqchain") << "    apply 2-child all-diff\n";
1418
      }
1419
6
      else if (common_children.empty() && diff_children.size() == 1)
1420
      {
1421
        do_rewrite = true;
1422
        // x = ( ~x | y ) ---> ~( ~x | ~y )
1423
        Node remn = rem_children.size() == 1 ? rem_children[0]
1424
                                             : nm->mkNode(ork, rem_children);
1425
        remn = TermUtil::mkNegate(notk, remn);
1426
        children.push_back(nm->mkNode(ork, diff_children[0], remn));
1427
        if (!isXor)
1428
        {
1429
          gpol = !gpol;
1430
        }
1431
        Trace("ext-rew-eqchain") << "    apply unit resolution\n";
1432
      }
1433
12
      else if (diff_children.size() == 1
1434
6
               && itc->second.size() == itcc->second.size())
1435
      {
1436
        // ( x | y | z ) = ( x | ~y | z ) ---> ( x | z )
1437
        do_rewrite = true;
1438
        Assert(!common_children.empty());
1439
        Node comn = common_children.size() == 1
1440
                        ? common_children[0]
1441
                        : nm->mkNode(ork, common_children);
1442
        children.push_back(comn);
1443
        if (isXor)
1444
        {
1445
          gpol = !gpol;
1446
        }
1447
        Trace("ext-rew-eqchain") << "    apply resolution\n";
1448
      }
1449
6
      else if (diff_children.empty())
1450
      {
1451
6
        do_rewrite = true;
1452
6
        if (rem_children.empty())
1453
        {
1454
          // x | y = x | y ---> true
1455
          // this can happen if we have ( ~x & ~y ) = ( x | y )
1456
6
          children.push_back(TermUtil::mkTypeMaxValue(tn));
1457
6
          if (isXor)
1458
          {
1459
            gpol = !gpol;
1460
          }
1461
6
          Trace("ext-rew-eqchain") << "    apply cancel\n";
1462
        }
1463
        else
1464
        {
1465
          // x | y = ( x | y | z ) ---> ( x | y | ~z )
1466
          Node remn = rem_children.size() == 1 ? rem_children[0]
1467
                                               : nm->mkNode(ork, rem_children);
1468
          remn = TermUtil::mkNegate(notk, remn);
1469
          Node comn = common_children.size() == 1
1470
                          ? common_children[0]
1471
                          : nm->mkNode(ork, common_children);
1472
          children.push_back(nm->mkNode(ork, comn, remn));
1473
          if (isXor)
1474
          {
1475
            gpol = !gpol;
1476
          }
1477
          Trace("ext-rew-eqchain") << "    apply subsume\n";
1478
        }
1479
      }
1480
8
      if (do_rewrite)
1481
      {
1482
        // eliminate the children, reverse polarity as needed
1483
24
        for (unsigned r = 0; r < 2; r++)
1484
        {
1485
32
          Node c_rem = r == 0 ? c : cc;
1486
16
          cstatus[c_rem] = false;
1487
16
          if (c_rem.getKind() == andk)
1488
          {
1489
10
            gpol = !gpol;
1490
          }
1491
        }
1492
8
        break;
1493
      }
1494
    }
1495
  }
1496
41691
  Trace("ext-rew-eqchain") << "eqchain-simplify: finish" << std::endl;
1497
1498
  // sorted right associative chain
1499
41691
  bool has_nvar = false;
1500
41691
  unsigned nvar_index = 0;
1501
125865
  for (std::pair<const Node, bool>& cp : cstatus)
1502
  {
1503
84174
    if (cp.second)
1504
    {
1505
84010
      if (!cp.first.isVar())
1506
      {
1507
52708
        has_nvar = true;
1508
52708
        nvar_index = children.size();
1509
      }
1510
84010
      children.push_back(cp.first);
1511
    }
1512
  }
1513
41691
  std::sort(children.begin(), children.end());
1514
1515
83382
  Node new_ret;
1516
41691
  if (!gpol)
1517
  {
1518
    // negate the constant child if it exists
1519
620
    unsigned nindex = has_nvar ? nvar_index : 0;
1520
620
    children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
1521
  }
1522
41691
  new_ret = children.back();
1523
41691
  unsigned index = children.size() - 1;
1524
126497
  while (index > 0)
1525
  {
1526
42403
    index--;
1527
42403
    new_ret = nm->mkNode(eqk, children[index], new_ret);
1528
  }
1529
41691
  new_ret = d_rew.rewrite(new_ret);
1530
41691
  if (new_ret != ret)
1531
  {
1532
577
    return new_ret;
1533
  }
1534
41114
  return Node::null();
1535
}
1536
1537
2343984
Node ExtendedRewriter::partialSubstitute(
1538
    Node n,
1539
    const std::map<Node, Node>& assign,
1540
    const std::map<Kind, bool>& rkinds) const
1541
{
1542
4687968
  std::unordered_map<TNode, Node> visited;
1543
2343984
  std::unordered_map<TNode, Node>::iterator it;
1544
2343984
  std::map<Node, Node>::const_iterator ita;
1545
4687968
  std::vector<TNode> visit;
1546
4687968
  TNode cur;
1547
2343984
  visit.push_back(n);
1548
35874725
  do
1549
  {
1550
38218709
    cur = visit.back();
1551
38218709
    visit.pop_back();
1552
38218709
    it = visited.find(cur);
1553
1554
38218709
    if (it == visited.end())
1555
    {
1556
16265237
      ita = assign.find(cur);
1557
16265237
      if (ita != assign.end())
1558
      {
1559
38314
        visited[cur] = ita->second;
1560
      }
1561
      else
1562
      {
1563
        // If rkinds is non-empty, then can only recurse on its kinds.
1564
        // We also always disallow substitution into witness. Notice that
1565
        // we disallow witness here, due to unsoundness when applying contextual
1566
        // substitutions over witness terms (see #4620).
1567
16226923
        Kind k = cur.getKind();
1568
16226923
        if (k != WITNESS && (rkinds.empty() || rkinds.find(k) != rkinds.end()))
1569
        {
1570
16226923
          visited[cur] = Node::null();
1571
16226923
          visit.push_back(cur);
1572
35874725
          for (const Node& cn : cur)
1573
          {
1574
19647802
            visit.push_back(cn);
1575
          }
1576
        }
1577
        else
1578
        {
1579
          visited[cur] = cur;
1580
        }
1581
      }
1582
    }
1583
21953472
    else if (it->second.isNull())
1584
    {
1585
32453846
      Node ret = cur;
1586
16226923
      bool childChanged = false;
1587
32453846
      std::vector<Node> children;
1588
16226923
      if (cur.getMetaKind() == metakind::PARAMETERIZED)
1589
      {
1590
598306
        children.push_back(cur.getOperator());
1591
      }
1592
35874725
      for (const Node& cn : cur)
1593
      {
1594
19647802
        it = visited.find(cn);
1595
19647802
        Assert(it != visited.end());
1596
19647802
        Assert(!it->second.isNull());
1597
19647802
        childChanged = childChanged || cn != it->second;
1598
19647802
        children.push_back(it->second);
1599
      }
1600
16226923
      if (childChanged)
1601
      {
1602
145852
        ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
1603
      }
1604
16226923
      visited[cur] = ret;
1605
    }
1606
38218709
  } while (!visit.empty());
1607
2343984
  Assert(visited.find(n) != visited.end());
1608
2343984
  Assert(!visited.find(n)->second.isNull());
1609
4687968
  return visited[n];
1610
}
1611
1612
1321636
Node ExtendedRewriter::partialSubstitute(
1613
    Node n,
1614
    const std::vector<Node>& vars,
1615
    const std::vector<Node>& subs,
1616
    const std::map<Kind, bool>& rkinds) const
1617
{
1618
1321636
  Assert(vars.size() == subs.size());
1619
2643272
  std::map<Node, Node> assign;
1620
2667025
  for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
1621
  {
1622
1345389
    assign[vars[i]] = subs[i];
1623
  }
1624
2643272
  return partialSubstitute(n, assign, rkinds);
1625
}
1626
1627
109430
Node ExtendedRewriter::solveEquality(Node n) const
1628
{
1629
  // TODO (#1706) : implement
1630
109430
  Assert(n.getKind() == EQUAL);
1631
1632
109430
  return Node::null();
1633
}
1634
1635
179465
bool ExtendedRewriter::inferSubstitution(Node n,
1636
                                         std::vector<Node>& vars,
1637
                                         std::vector<Node>& subs,
1638
                                         bool usePred) const
1639
{
1640
179465
  if (n.getKind() == AND)
1641
  {
1642
9404
    bool ret = false;
1643
42561
    for (const Node& nc : n)
1644
    {
1645
33157
      bool cret = inferSubstitution(nc, vars, subs, usePred);
1646
33157
      ret = ret || cret;
1647
    }
1648
9404
    return ret;
1649
  }
1650
170061
  if (n.getKind() == EQUAL)
1651
  {
1652
    // see if it can be put into form x = y
1653
120995
    Node slv_eq = solveEquality(n);
1654
109430
    if (!slv_eq.isNull())
1655
    {
1656
      n = slv_eq;
1657
    }
1658
328290
    Node v[2];
1659
277528
    for (unsigned i = 0; i < 2; i++)
1660
    {
1661
218219
      if (n[i].isConst())
1662
      {
1663
50121
        vars.push_back(n[1 - i]);
1664
50121
        subs.push_back(n[i]);
1665
50121
        return true;
1666
      }
1667
168098
      if (n[i].isVar())
1668
      {
1669
145969
        v[i] = n[i];
1670
      }
1671
22129
      else if (TermUtil::isNegate(n[i].getKind()) && n[i][0].isVar())
1672
      {
1673
348
        v[i] = n[i][0];
1674
      }
1675
    }
1676
85810
    for (unsigned i = 0; i < 2; i++)
1677
    {
1678
100746
      TNode r1 = v[i];
1679
100746
      Node r2 = v[1 - i];
1680
74245
      if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
1681
      {
1682
47744
        r2 = n[1 - i];
1683
47744
        if (v[i] != n[i])
1684
        {
1685
286
          Assert(TermUtil::isNegate(n[i].getKind()));
1686
286
          r2 = TermUtil::mkNegate(n[i].getKind(), r2);
1687
        }
1688
        // TODO (#1706) : union find
1689
47744
        if (std::find(vars.begin(), vars.end(), r1) == vars.end())
1690
        {
1691
47744
          vars.push_back(r1);
1692
47744
          subs.push_back(r2);
1693
47744
          return true;
1694
        }
1695
      }
1696
    }
1697
  }
1698
72196
  if (usePred)
1699
  {
1700
63221
    bool negated = n.getKind() == NOT;
1701
63221
    vars.push_back(negated ? n[0] : n);
1702
63221
    subs.push_back(negated ? d_false : d_true);
1703
63221
    return true;
1704
  }
1705
8975
  return false;
1706
}
1707
1708
7835
Node ExtendedRewriter::extendedRewriteStrings(Node node) const
1709
{
1710
15670
  Trace("q-ext-rewrite-debug")
1711
7835
      << "Extended rewrite strings : " << node << std::endl;
1712
1713
7835
  Kind k = node.getKind();
1714
7835
  if (k == EQUAL)
1715
  {
1716
716
    strings::SequencesRewriter sr(&d_rew, nullptr);
1717
358
    return sr.rewriteEqualityExt(node);
1718
  }
1719
7477
  else if (k == STRING_SUBSTR)
1720
  {
1721
2245
    NodeManager* nm = NodeManager::currentNM();
1722
4480
    Node tot_len = d_rew.rewrite(nm->mkNode(STRING_LENGTH, node[0]));
1723
4480
    strings::ArithEntail aent(&d_rew);
1724
    // (str.substr s x y) --> "" if x < len(s) |= 0 >= y
1725
4480
    Node n1_lt_tot_len = d_rew.rewrite(nm->mkNode(LT, node[1], tot_len));
1726
2245
    if (aent.checkWithAssumption(n1_lt_tot_len, d_zero, node[2], false))
1727
    {
1728
12
      Node ret = strings::Word::mkEmptyWord(node.getType());
1729
6
      debugExtendedRewrite(node, ret, "SS_START_ENTAILS_ZERO_LEN");
1730
6
      return ret;
1731
    }
1732
1733
    // (str.substr s x y) --> "" if 0 < y |= x >= str.len(s)
1734
4474
    Node non_zero_len = d_rew.rewrite(nm->mkNode(LT, d_zero, node[2]));
1735
2239
    if (aent.checkWithAssumption(non_zero_len, node[1], tot_len, false))
1736
    {
1737
4
      Node ret = strings::Word::mkEmptyWord(node.getType());
1738
2
      debugExtendedRewrite(node, ret, "SS_NON_ZERO_LEN_ENTAILS_OOB");
1739
2
      return ret;
1740
    }
1741
    // (str.substr s x y) --> "" if x >= 0 |= 0 >= str.len(s)
1742
4472
    Node geq_zero_start = d_rew.rewrite(nm->mkNode(GEQ, node[1], d_zero));
1743
2237
    if (aent.checkWithAssumption(geq_zero_start, d_zero, tot_len, false))
1744
    {
1745
4
      Node ret = strings::Word::mkEmptyWord(node.getType());
1746
2
      debugExtendedRewrite(node, ret, "SS_GEQ_ZERO_START_ENTAILS_EMP_S");
1747
2
      return ret;
1748
    }
1749
  }
1750
1751
7467
  return Node::null();
1752
}
1753
1754
184470
void ExtendedRewriter::debugExtendedRewrite(Node n,
1755
                                            Node ret,
1756
                                            const char* c) const
1757
{
1758
184470
  if (Trace.isOn("q-ext-rewrite"))
1759
  {
1760
    if (!ret.isNull())
1761
    {
1762
      Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
1763
      Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
1764
                             << " rewrites to " << ret << std::endl;
1765
    }
1766
  }
1767
184470
}
1768
1769
}  // namespace quantifiers
1770
}  // namespace theory
1771
31137
}  // namespace cvc5