GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/extended_rewrite.cpp Lines: 764 841 90.8 %
Date: 2021-05-22 Branches: 1841 3800 48.4 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of extended rewriting techniques.
14
 */
15
16
#include "theory/quantifiers/extended_rewrite.h"
17
18
#include <sstream>
19
20
#include "theory/arith/arith_msum.h"
21
#include "theory/bv/theory_bv_utils.h"
22
#include "theory/datatypes/datatypes_rewriter.h"
23
#include "theory/quantifiers/term_util.h"
24
#include "theory/rewriter.h"
25
#include "theory/strings/sequences_rewriter.h"
26
#include "theory/theory.h"
27
28
using namespace cvc5::kind;
29
using namespace std;
30
31
namespace cvc5 {
32
namespace theory {
33
namespace quantifiers {
34
35
struct ExtRewriteAttributeId
36
{
37
};
38
typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
39
40
struct ExtRewriteAggAttributeId
41
{
42
};
43
typedef expr::Attribute<ExtRewriteAggAttributeId, Node> ExtRewriteAggAttribute;
44
45
11109
ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr)
46
{
47
11109
  d_true = NodeManager::currentNM()->mkConst(true);
48
11109
  d_false = NodeManager::currentNM()->mkConst(false);
49
11109
}
50
51
217532
void ExtendedRewriter::setCache(Node n, Node ret)
52
{
53
217532
  if (d_aggr)
54
  {
55
    ExtRewriteAggAttribute erga;
56
216884
    n.setAttribute(erga, ret);
57
  }
58
  else
59
  {
60
    ExtRewriteAttribute era;
61
648
    n.setAttribute(era, ret);
62
  }
63
217532
}
64
65
504875
Node ExtendedRewriter::getCache(Node n)
66
{
67
504875
  if (d_aggr)
68
  {
69
504363
    if (n.hasAttribute(ExtRewriteAggAttribute()))
70
    {
71
395098
      return n.getAttribute(ExtRewriteAggAttribute());
72
    }
73
  }
74
  else
75
  {
76
512
    if (n.hasAttribute(ExtRewriteAttribute()))
77
    {
78
188
      return n.getAttribute(ExtRewriteAttribute());
79
    }
80
  }
81
109589
  return Node::null();
82
}
83
84
218978
bool ExtendedRewriter::addToChildren(Node nc,
85
                                     std::vector<Node>& children,
86
                                     bool dropDup)
87
{
88
  // If the operator is non-additive, do not consider duplicates
89
218978
  if (dropDup
90
218978
      && std::find(children.begin(), children.end(), nc) != children.end())
91
  {
92
90
    return false;
93
  }
94
218888
  children.push_back(nc);
95
218888
  return true;
96
}
97
98
504875
Node ExtendedRewriter::extendedRewrite(Node n)
99
{
100
504875
  n = Rewriter::rewrite(n);
101
102
  // has it already been computed?
103
1009750
  Node ncache = getCache(n);
104
504875
  if (!ncache.isNull())
105
  {
106
395286
    return ncache;
107
  }
108
109
219178
  Node ret = n;
110
109589
  NodeManager* nm = NodeManager::currentNM();
111
112
  //--------------------pre-rewrite
113
109589
  if (d_aggr)
114
  {
115
216884
    Node pre_new_ret;
116
109265
    if (ret.getKind() == IMPLIES)
117
    {
118
22
      pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
119
22
      debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
120
    }
121
109243
    else if (ret.getKind() == XOR)
122
    {
123
114
      pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
124
114
      debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
125
    }
126
109129
    else if (ret.getKind() == NOT)
127
    {
128
11391
      pre_new_ret = extendedRewriteNnf(ret);
129
11391
      debugExtendedRewrite(ret, pre_new_ret, "NNF");
130
    }
131
109265
    if (!pre_new_ret.isNull())
132
    {
133
1646
      ret = extendedRewrite(pre_new_ret);
134
135
3292
      Trace("q-ext-rewrite-debug")
136
1646
          << "...ext-pre-rewrite : " << n << " -> " << pre_new_ret << std::endl;
137
1646
      setCache(n, ret);
138
1646
      return ret;
139
    }
140
  }
141
  //--------------------end pre-rewrite
142
143
  //--------------------rewrite children
144
107943
  if (n.getNumChildren() > 0)
145
  {
146
205734
    std::vector<Node> children;
147
102867
    if (n.getMetaKind() == metakind::PARAMETERIZED)
148
    {
149
3406
      children.push_back(n.getOperator());
150
    }
151
102867
    Kind k = n.getKind();
152
102867
    bool childChanged = false;
153
102867
    bool isNonAdditive = TermUtil::isNonAdditive(k);
154
    // We flatten associative operators below, which requires k to be n-ary.
155
102867
    bool isAssoc = TermUtil::isAssoc(k, true);
156
321518
    for (unsigned i = 0; i < n.getNumChildren(); i++)
157
    {
158
437302
      Node nc = extendedRewrite(n[i]);
159
218651
      childChanged = nc != n[i] || childChanged;
160
218651
      if (isAssoc && nc.getKind() == n.getKind())
161
      {
162
933
        for (const Node& ncc : nc)
163
        {
164
630
          if (!addToChildren(ncc, children, isNonAdditive))
165
          {
166
8
            childChanged = true;
167
          }
168
        }
169
      }
170
218348
      else if (!addToChildren(nc, children, isNonAdditive))
171
      {
172
82
        childChanged = true;
173
      }
174
    }
175
102867
    Assert(!children.empty());
176
    // Some commutative operators have rewriters that are agnostic to order,
177
    // thus, we sort here.
178
102867
    if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
179
    {
180
57303
      childChanged = true;
181
57303
      std::sort(children.begin(), children.end());
182
    }
183
102867
    if (childChanged)
184
    {
185
65932
      if (isNonAdditive && children.size() == 1)
186
      {
187
        // we may have subsumed children down to one
188
24
        ret = children[0];
189
      }
190
65908
      else if (isAssoc
191
65908
               && children.size() > kind::metakind::getMaxArityForKind(k))
192
      {
193
2
        Assert(kind::metakind::getMaxArityForKind(k) >= 2);
194
        // kind may require binary construction
195
2
        ret = children[0];
196
6
        for (unsigned i = 1, nchild = children.size(); i < nchild; i++)
197
        {
198
4
          ret = nm->mkNode(k, ret, children[i]);
199
        }
200
      }
201
      else
202
      {
203
65906
        ret = nm->mkNode(k, children);
204
      }
205
    }
206
  }
207
107943
  ret = Rewriter::rewrite(ret);
208
  //--------------------end rewrite children
209
210
  // now, do extended rewrite
211
215886
  Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
212
107943
                               << " (from " << n << ")" << std::endl;
213
215886
  Node new_ret;
214
215
  //---------------------- theory-independent post-rewriting
216
107943
  if (ret.getKind() == ITE)
217
  {
218
10210
    new_ret = extendedRewriteIte(ITE, ret);
219
  }
220
97733
  else if (ret.getKind() == AND || ret.getKind() == OR)
221
  {
222
14136
    new_ret = extendedRewriteAndOr(ret);
223
  }
224
83597
  else if (ret.getKind() == EQUAL)
225
  {
226
14167
    new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
227
14167
    debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
228
  }
229
107943
  Assert(new_ret.isNull() || new_ret != ret);
230
107943
  if (new_ret.isNull() && ret.getKind() != ITE)
231
  {
232
    // simple ITE pulling
233
95665
    new_ret = extendedRewritePullIte(ITE, ret);
234
  }
235
  //----------------------end theory-independent post-rewriting
236
237
  //----------------------theory-specific post-rewriting
238
107943
  if (new_ret.isNull())
239
  {
240
    TheoryId tid;
241
94234
    if (ret.getKind() == ITE)
242
    {
243
9456
      tid = Theory::theoryOf(ret.getType());
244
    }
245
    else
246
    {
247
84778
      tid = Theory::theoryOf(ret);
248
    }
249
188468
    Trace("q-ext-rewrite-debug") << "theoryOf( " << ret << " )= " << tid
250
94234
                                 << std::endl;
251
94234
    if (tid == THEORY_STRINGS)
252
    {
253
1287
      new_ret = extendedRewriteStrings(ret);
254
    }
255
  }
256
  //----------------------end theory-specific post-rewriting
257
258
  //----------------------aggressive rewrites
259
107943
  if (new_ret.isNull() && d_aggr)
260
  {
261
93739
    new_ret = extendedRewriteAggr(ret);
262
  }
263
  //----------------------end aggressive rewrites
264
265
107943
  setCache(n, ret);
266
107943
  if (!new_ret.isNull())
267
  {
268
14911
    ret = extendedRewrite(new_ret);
269
  }
270
215886
  Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
271
107943
                               << std::endl;
272
107943
  if (Trace.isOn("q-ext-rewrite-nf"))
273
  {
274
    if (n == ret)
275
    {
276
      Trace("q-ext-rewrite-nf") << "ext-rew normal form : " << n << std::endl;
277
    }
278
  }
279
107943
  setCache(n, ret);
280
107943
  return ret;
281
}
282
283
93739
Node ExtendedRewriter::extendedRewriteAggr(Node n)
284
{
285
93739
  Node new_ret;
286
187478
  Trace("q-ext-rewrite-debug2")
287
93739
      << "Do aggressive rewrites on " << n << std::endl;
288
93739
  bool polarity = n.getKind() != NOT;
289
187478
  Node ret_atom = n.getKind() == NOT ? n[0] : n;
290
196896
  if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
291
275252
      || ret_atom.getKind() == GEQ)
292
  {
293
    // ITE term removal in polynomials
294
    // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 )
295
39558
    Trace("q-ext-rewrite-debug2")
296
19779
        << "Compute monomial sum " << ret_atom << std::endl;
297
    // compute monomial sum
298
39558
    std::map<Node, Node> msum;
299
19779
    if (ArithMSum::getMonomialSumLit(ret_atom, msum))
300
    {
301
65760
      for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end();
302
           ++itm)
303
      {
304
92987
        Node v = itm->first;
305
94012
        Trace("q-ext-rewrite-debug2")
306
47006
            << itm->first << " * " << itm->second << std::endl;
307
47006
        if (v.getKind() == ITE)
308
        {
309
12157
          Node veq;
310
6591
          int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind());
311
6591
          if (res != 0)
312
          {
313
13182
            Trace("q-ext-rewrite-debug")
314
6591
                << "  have ITE relation, solved form : " << veq << std::endl;
315
            // try pulling ITE
316
6591
            new_ret = extendedRewritePullIte(ITE, veq);
317
6591
            if (!new_ret.isNull())
318
            {
319
1025
              if (!polarity)
320
              {
321
                new_ret = new_ret.negate();
322
              }
323
1025
              break;
324
            }
325
          }
326
          else
327
          {
328
            Trace("q-ext-rewrite-debug")
329
                << "  failed to isolate " << v << " in " << n << std::endl;
330
          }
331
        }
332
      }
333
    }
334
    else
335
    {
336
      Trace("q-ext-rewrite-debug")
337
          << "  failed to get monomial sum of " << n << std::endl;
338
    }
339
  }
340
  // TODO (#1706) : conditional rewriting, condition merging
341
187478
  return new_ret;
342
}
343
344
19364
Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
345
{
346
19364
  Assert(n.getKind() == itek);
347
19364
  Assert(n[1] != n[2]);
348
349
19364
  NodeManager* nm = NodeManager::currentNM();
350
351
19364
  Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;
352
353
38728
  Node flip_cond;
354
19364
  if (n[0].getKind() == NOT)
355
  {
356
    flip_cond = n[0][0];
357
  }
358
19364
  else if (n[0].getKind() == OR)
359
  {
360
    // a | b ---> ~( ~a & ~b )
361
80
    flip_cond = TermUtil::simpleNegate(n[0]);
362
  }
363
19364
  if (!flip_cond.isNull())
364
  {
365
160
    Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
366
    // only print debug trace if full=true
367
80
    if (full)
368
    {
369
80
      debugExtendedRewrite(n, new_ret, "ITE flip");
370
    }
371
80
    return new_ret;
372
  }
373
  // Boolean true/false return
374
38568
  TypeNode tn = n.getType();
375
19284
  if (tn.isBoolean())
376
  {
377
29274
    for (unsigned i = 1; i <= 2; i++)
378
    {
379
19516
      if (n[i].isConst())
380
      {
381
        Node cond = i == 1 ? n[0] : n[0].negate();
382
        Node other = n[i == 1 ? 2 : 1];
383
        Kind retk = AND;
384
        if (n[i].getConst<bool>())
385
        {
386
          retk = OR;
387
        }
388
        else
389
        {
390
          cond = cond.negate();
391
        }
392
        Node new_ret = nm->mkNode(retk, cond, other);
393
        if (full)
394
        {
395
          // ite( A, true, B ) ---> A V B
396
          // ite( A, false, B ) ---> ~A /\ B
397
          // ite( A, B,  true ) ---> ~A V B
398
          // ite( A, B, false ) ---> A /\ B
399
          debugExtendedRewrite(n, new_ret, "ITE const return");
400
        }
401
        return new_ret;
402
      }
403
    }
404
  }
405
406
  // get entailed equalities in the condition
407
38568
  std::vector<Node> eq_conds;
408
19284
  Kind ck = n[0].getKind();
409
19284
  if (ck == EQUAL)
410
  {
411
6337
    eq_conds.push_back(n[0]);
412
  }
413
12947
  else if (ck == AND)
414
  {
415
977
    for (const Node& cn : n[0])
416
    {
417
666
      if (cn.getKind() == EQUAL)
418
      {
419
111
        eq_conds.push_back(cn);
420
      }
421
    }
422
  }
423
424
38568
  Node new_ret;
425
38568
  Node b;
426
38568
  Node e;
427
38568
  Node t1 = n[1];
428
38568
  Node t2 = n[2];
429
38568
  std::stringstream ss_reason;
430
431
25676
  for (const Node& eq : eq_conds)
432
  {
433
    // simple invariant ITE
434
19261
    for (unsigned i = 0; i <= 1; i++)
435
    {
436
      // ite( x = y ^ C, y, x ) ---> x
437
      // this is subsumed by the rewrites below
438
12869
      if (t2 == eq[i] && t1 == eq[1 - i])
439
      {
440
56
        new_ret = t2;
441
56
        ss_reason << "ITE simple rev subs";
442
56
        break;
443
      }
444
    }
445
6448
    if (!new_ret.isNull())
446
    {
447
56
      break;
448
    }
449
  }
450
19284
  if (new_ret.isNull())
451
  {
452
    // merging branches
453
57479
    for (unsigned i = 1; i <= 2; i++)
454
    {
455
38396
      if (n[i].getKind() == ITE)
456
      {
457
1266
        Node no = n[3 - i];
458
1695
        for (unsigned j = 1; j <= 2; j++)
459
        {
460
1207
          if (n[i][j] == no)
461
          {
462
            // e.g.
463
            // ite( C1, ite( C2, t1, t2 ), t1 ) ----> ite( C1 ^ ~C2, t2, t1 )
464
290
            Node nc1 = i == 2 ? n[0].negate() : n[0];
465
290
            Node nc2 = j == 1 ? n[i][0].negate() : n[i][0];
466
290
            Node new_cond = nm->mkNode(AND, nc1, nc2);
467
145
            new_ret = nm->mkNode(ITE, new_cond, n[i][3 - j], no);
468
145
            ss_reason << "ITE merge branch";
469
145
            break;
470
          }
471
        }
472
      }
473
38396
      if (!new_ret.isNull())
474
      {
475
145
        break;
476
      }
477
    }
478
  }
479
480
19284
  if (new_ret.isNull() && d_aggr)
481
  {
482
    // If x is less than t based on an ordering, then we use { x -> t } as a
483
    // substitution to the children of ite( x = t ^ C, s, t ) below.
484
38158
    std::vector<Node> vars;
485
38158
    std::vector<Node> subs;
486
19079
    inferSubstitution(n[0], vars, subs, true);
487
488
19079
    if (!vars.empty())
489
    {
490
      // reverse substitution to opposite child
491
      // r{ x -> t } = s  implies  ite( x=t ^ C, s, r ) ---> r
492
      // We can use ordinary substitute since the result of the substitution
493
      // is not being returned. In other words, nn is only being used to query
494
      // whether the second branch is a generalization of the first.
495
      Node nn =
496
38158
          t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
497
19079
      if (nn != t2)
498
      {
499
2305
        nn = Rewriter::rewrite(nn);
500
2305
        if (nn == t1)
501
        {
502
7
          new_ret = t2;
503
7
          ss_reason << "ITE rev subs";
504
        }
505
      }
506
507
      // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
508
      // must use partial substitute here, to avoid substitution into witness
509
38158
      std::map<Kind, bool> rkinds;
510
19079
      nn = partialSubstitute(t1, vars, subs, rkinds);
511
19079
      if (nn != t1)
512
      {
513
        // If full=false, then we've duplicated a term u in the children of n.
514
        // For example, when ITE pulling, we have n is of the form:
515
        //   ite( C, f( u, t1 ), f( u, t2 ) )
516
        // We must show that at least one copy of u dissappears in this case.
517
1291
        nn = Rewriter::rewrite(nn);
518
1291
        if (nn == t2)
519
        {
520
5
          new_ret = nn;
521
5
          ss_reason << "ITE subs invariant";
522
        }
523
1286
        else if (full || nn.isConst())
524
        {
525
570
          new_ret = nm->mkNode(itek, n[0], nn, t2);
526
570
          ss_reason << "ITE subs";
527
        }
528
      }
529
    }
530
19079
    if (new_ret.isNull())
531
    {
532
      // ite( C, t, s ) ----> ite( C, t, s { C -> false } )
533
      // use partial substitute to avoid substitution into witness
534
36994
      std::map<Node, Node> assign;
535
18497
      assign[n[0]] = d_false;
536
36994
      std::map<Kind, bool> rkinds;
537
36994
      Node nn = partialSubstitute(t2, assign, rkinds);
538
18497
      if (nn != t2)
539
      {
540
124
        nn = Rewriter::rewrite(nn);
541
124
        if (nn == t1)
542
        {
543
2
          new_ret = nn;
544
2
          ss_reason << "ITE subs invariant false";
545
        }
546
122
        else if (full || nn.isConst())
547
        {
548
44
          new_ret = nm->mkNode(itek, n[0], t1, nn);
549
44
          ss_reason << "ITE subs false";
550
        }
551
      }
552
    }
553
  }
554
555
  // only print debug trace if full=true
556
19284
  if (!new_ret.isNull() && full)
557
  {
558
674
    debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
559
  }
560
561
19284
  return new_ret;
562
}
563
564
14136
Node ExtendedRewriter::extendedRewriteAndOr(Node n)
565
{
566
  // all the below rewrites are aggressive
567
14136
  if (!d_aggr)
568
  {
569
18
    return Node::null();
570
  }
571
28236
  Node new_ret;
572
  // we allow substitutions to recurse over any kind, except WITNESS which is
573
  // managed by partialSubstitute.
574
28236
  std::map<Kind, bool> bcp_kinds;
575
14118
  new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, n);
576
14118
  if (!new_ret.isNull())
577
  {
578
1465
    debugExtendedRewrite(n, new_ret, "Bool bcp");
579
1465
    return new_ret;
580
  }
581
  // factoring
582
12653
  new_ret = extendedRewriteFactoring(AND, OR, NOT, n);
583
12653
  if (!new_ret.isNull())
584
  {
585
75
    debugExtendedRewrite(n, new_ret, "Bool factoring");
586
75
    return new_ret;
587
  }
588
589
  // equality resolution
590
12578
  new_ret = extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, n, false);
591
12578
  debugExtendedRewrite(n, new_ret, "Bool eq res");
592
12578
  return new_ret;
593
}
594
595
102256
Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
596
{
597
102256
  Assert(n.getKind() != ITE);
598
102256
  if (n.isClosure())
599
  {
600
    // don't pull ITE out of quantifiers
601
100
    return n;
602
  }
603
102156
  NodeManager* nm = NodeManager::currentNM();
604
204312
  TypeNode tn = n.getType();
605
204312
  std::vector<Node> children;
606
102156
  bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
607
102156
  if (hasOp)
608
  {
609
3433
    children.push_back(n.getOperator());
610
  }
611
102156
  unsigned nchildren = n.getNumChildren();
612
296342
  for (unsigned i = 0; i < nchildren; i++)
613
  {
614
194186
    children.push_back(n[i]);
615
  }
616
204312
  std::map<unsigned, std::map<unsigned, Node> > ite_c;
617
280344
  for (unsigned i = 0; i < nchildren; i++)
618
  {
619
    // only pull ITEs apart if we are aggressive
620
569040
    if (n[i].getKind() == itek
621
569040
        && (d_aggr || (n[i][1].getKind() != ITE && n[i][2].getKind() != ITE)))
622
    {
623
20818
      unsigned ii = hasOp ? i + 1 : i;
624
62454
      for (unsigned j = 0; j < 2; j++)
625
      {
626
41636
        children[ii] = n[i][j + 1];
627
83272
        Node pull = nm->mkNode(n.getKind(), children);
628
83272
        Node pullr = Rewriter::rewrite(pull);
629
41636
        children[ii] = n[i];
630
41636
        ite_c[i][j] = pullr;
631
      }
632
20818
      if (ite_c[i][0] == ite_c[i][1])
633
      {
634
        // ITE dual invariance
635
        // f( t1..s1..tn ) ---> t  and  f( t1..s2..tn ) ---> t implies
636
        // f( t1..ite( A, s1, s2 )..tn ) ---> t
637
397
        debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
638
397
        return ite_c[i][0];
639
      }
640
20421
      if (d_aggr)
641
      {
642
40127
        if (nchildren == 2 && (n[1 - i].isVar() || n[1 - i].isConst())
643
33178
            && !n[1 - i].getType().isBoolean() && tn.isBoolean())
644
        {
645
          // always pull variable or constant with binary (theory) predicate
646
          // e.g. P( x, ite( A, t1, t2 ) ) ---> ite( A, P( x, t1 ), P( x, t2 ) )
647
16696
          Node new_ret = nm->mkNode(ITE, n[i][0], ite_c[i][0], ite_c[i][1]);
648
8348
          debugExtendedRewrite(n, new_ret, "ITE pull var predicate");
649
8348
          return new_ret;
650
        }
651
31990
        for (unsigned j = 0; j < 2; j++)
652
        {
653
42585
          Node pullr = ite_c[i][j];
654
22666
          if (pullr.isConst() || pullr == n[i][j + 1])
655
          {
656
            // ITE single child elimination
657
            // f( t1..s1..tn ) ---> t  where t is a constant or s1 itself
658
            // implies
659
            // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
660
5494
            Node new_ret;
661
2747
            if (tn.isBoolean() && pullr.isConst())
662
            {
663
              // remove false/true child immediately
664
824
              bool pol = pullr.getConst<bool>();
665
1648
              std::vector<Node> new_children;
666
824
              new_children.push_back((j == 0) == pol ? n[i][0]
667
                                                     : n[i][0].negate());
668
824
              new_children.push_back(ite_c[i][1 - j]);
669
824
              new_ret = nm->mkNode(pol ? OR : AND, new_children);
670
824
              debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
671
            }
672
            else
673
            {
674
1923
              new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
675
1923
              debugExtendedRewrite(n, new_ret, "ITE single elim");
676
            }
677
2747
            return new_ret;
678
          }
679
        }
680
      }
681
    }
682
  }
683
90664
  if (d_aggr)
684
  {
685
99347
    for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
686
    {
687
18318
      Node nite = n[ip.first];
688
9319
      Assert(nite.getKind() == itek);
689
      // now, simply pull the ITE and try ITE rewrites
690
18318
      Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
691
9319
      pull_ite = Rewriter::rewrite(pull_ite);
692
9319
      if (pull_ite.getKind() == ITE)
693
      {
694
18153
        Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
695
9154
        if (!new_pull_ite.isNull())
696
        {
697
155
          debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
698
155
          return new_pull_ite;
699
        }
700
      }
701
      else
702
      {
703
        // A general rewrite could eliminate the ITE by pulling.
704
        // An example is:
705
        //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
706
        //   ite( C, ~~x, ite( C, y, x ) ) --->
707
        //   x
708
        // where ~ is bitvector negation.
709
165
        debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
710
165
        return pull_ite;
711
      }
712
    }
713
  }
714
715
90344
  return Node::null();
716
}
717
718
11391
Node ExtendedRewriter::extendedRewriteNnf(Node ret)
719
{
720
11391
  Assert(ret.getKind() == NOT);
721
722
11391
  Kind nk = ret[0].getKind();
723
11391
  bool neg_ch = false;
724
11391
  bool neg_ch_1 = false;
725
11391
  if (nk == AND || nk == OR)
726
  {
727
1446
    neg_ch = true;
728
1446
    nk = nk == AND ? OR : AND;
729
  }
730
9945
  else if (nk == IMPLIES)
731
  {
732
    neg_ch = true;
733
    neg_ch_1 = true;
734
    nk = AND;
735
  }
736
9945
  else if (nk == ITE)
737
  {
738
51
    neg_ch = true;
739
51
    neg_ch_1 = true;
740
  }
741
9894
  else if (nk == XOR)
742
  {
743
7
    nk = EQUAL;
744
  }
745
9887
  else if (nk == EQUAL && ret[0][0].getType().isBoolean())
746
  {
747
6
    neg_ch_1 = true;
748
  }
749
  else
750
  {
751
9881
    return Node::null();
752
  }
753
754
3020
  std::vector<Node> new_children;
755
5174
  for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++)
756
  {
757
7328
    Node c = ret[0][i];
758
3664
    c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c;
759
3664
    new_children.push_back(c);
760
  }
761
1510
  return NodeManager::currentNM()->mkNode(nk, new_children);
762
}
763
764
14118
Node ExtendedRewriter::extendedRewriteBcp(
765
    Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node ret)
766
{
767
14118
  Kind k = ret.getKind();
768
14118
  Assert(k == andk || k == ork);
769
14118
  Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;
770
771
14118
  NodeManager* nm = NodeManager::currentNM();
772
773
28236
  TypeNode tn = ret.getType();
774
28236
  Node truen = TermUtil::mkTypeMaxValue(tn);
775
28236
  Node falsen = TermUtil::mkTypeValue(tn, 0);
776
777
  // terms to process
778
28236
  std::vector<Node> to_process;
779
49566
  for (const Node& cn : ret)
780
  {
781
35448
    to_process.push_back(cn);
782
  }
783
  // the processing terms
784
28236
  std::vector<Node> clauses;
785
  // the terms we have propagated information to
786
28236
  std::unordered_set<Node> prop_clauses;
787
  // the assignment
788
28236
  std::map<Node, Node> assign;
789
28236
  std::vector<Node> avars;
790
28236
  std::vector<Node> asubs;
791
792
14118
  Kind ok = k == andk ? ork : andk;
793
  // global polarity : when k=ork, everything is negated
794
14118
  bool gpol = k == andk;
795
796
1283
  do
797
  {
798
    // process the current nodes
799
45937
    while (!to_process.empty())
800
    {
801
30743
      std::vector<Node> new_to_process;
802
52090
      for (const Node& cn : to_process)
803
      {
804
36822
        Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
805
36822
        Kind cnk = cn.getKind();
806
36822
        bool pol = cnk != notk;
807
73437
        Node cln = cnk == notk ? cn[0] : cn;
808
36822
        Assert(cln.getKind() != notk);
809
36822
        if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
810
        {
811
          // flatten
812
222
          for (const Node& ccln : cln)
813
          {
814
296
            Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
815
148
            new_to_process.push_back(lccln);
816
          }
817
        }
818
        else
819
        {
820
          // add it to the assignment
821
73289
          Node val = gpol == pol ? truen : falsen;
822
36748
          std::map<Node, Node>::iterator it = assign.find(cln);
823
73496
          Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
824
36748
                               << std::endl;
825
36748
          if (it != assign.end())
826
          {
827
210
            if (val != it->second)
828
            {
829
207
              Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
830
              // a conflicting assignment: we are done
831
207
              return gpol ? falsen : truen;
832
            }
833
          }
834
          else
835
          {
836
36538
            assign[cln] = val;
837
36538
            avars.push_back(cln);
838
36538
            asubs.push_back(val);
839
          }
840
841
          // also, treat it as clause if possible
842
73082
          if (cln.getNumChildren() > 0
843
63719
              && (bcp_kinds.empty()
844
36541
                  || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
845
          {
846
54356
            if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
847
27178
                && prop_clauses.find(cn) == prop_clauses.end())
848
            {
849
27178
              Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
850
27178
              clauses.push_back(cn);
851
            }
852
          }
853
        }
854
      }
855
15268
      to_process.clear();
856
15268
      to_process.insert(
857
30536
          to_process.end(), new_to_process.begin(), new_to_process.end());
858
    }
859
860
    // apply substitution to all subterms of clauses
861
30388
    std::vector<Node> new_clauses;
862
43373
    for (const Node& c : clauses)
863
    {
864
28179
      bool cpol = c.getKind() != notk;
865
56358
      Node ca = c.getKind() == notk ? c[0] : c;
866
28179
      bool childChanged = false;
867
56358
      std::vector<Node> ccs_children;
868
87517
      for (const Node& cc : ca)
869
      {
870
        // always use partial substitute, to avoid substitution in witness
871
59338
        Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
872
        // substitution is only applicable to compatible kinds in bcp_kinds
873
118676
        Node ccs = partialSubstitute(cc, assign, bcp_kinds);
874
59338
        childChanged = childChanged || ccs != cc;
875
59338
        ccs_children.push_back(ccs);
876
      }
877
28179
      if (childChanged)
878
      {
879
1308
        if (ca.getMetaKind() == metakind::PARAMETERIZED)
880
        {
881
          ccs_children.insert(ccs_children.begin(), ca.getOperator());
882
        }
883
2616
        Node ccs = nm->mkNode(ca.getKind(), ccs_children);
884
1308
        ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
885
2616
        Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
886
1308
                             << std::endl;
887
1308
        ccs = Rewriter::rewrite(ccs);
888
1308
        Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
889
1308
        to_process.push_back(ccs);
890
        // store this as a node that propagation touched. This marks c so that
891
        // it will not be included in the final construction.
892
1308
        prop_clauses.insert(ca);
893
      }
894
      else
895
      {
896
26871
        new_clauses.push_back(c);
897
      }
898
    }
899
15194
    clauses.clear();
900
15194
    clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
901
15194
  } while (!to_process.empty());
902
903
  // remake the node
904
13911
  if (!prop_clauses.empty())
905
  {
906
2516
    std::vector<Node> children;
907
6262
    for (std::pair<const Node, Node>& l : assign)
908
    {
909
10008
      Node a = l.first;
910
      // if propagation did not touch a
911
5004
      if (prop_clauses.find(a) == prop_clauses.end())
912
      {
913
3696
        Assert(l.second == truen || l.second == falsen);
914
7392
        Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
915
3696
        children.push_back(ln);
916
      }
917
    }
918
2516
    Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
919
1258
    Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
920
1258
    return new_ret;
921
  }
922
923
12653
  return Node::null();
924
}
925
926
12653
Node ExtendedRewriter::extendedRewriteFactoring(Kind andk,
927
                                                Kind ork,
928
                                                Kind notk,
929
                                                Node n)
930
{
931
12653
  Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl;
932
12653
  NodeManager* nm = NodeManager::currentNM();
933
934
12653
  Kind nk = n.getKind();
935
12653
  Assert(nk == andk || nk == ork);
936
12653
  Kind onk = nk == andk ? ork : andk;
937
  // count the number of times atoms occur
938
25306
  std::map<Node, std::vector<Node> > lit_to_cl;
939
25306
  std::map<Node, std::vector<Node> > cl_to_lits;
940
43806
  for (const Node& nc : n)
941
  {
942
31153
    Kind nck = nc.getKind();
943
31153
    if (nck == onk)
944
    {
945
13844
      for (const Node& ncl : nc)
946
      {
947
28788
        if (std::find(lit_to_cl[ncl].begin(), lit_to_cl[ncl].end(), nc)
948
28788
            == lit_to_cl[ncl].end())
949
        {
950
9596
          lit_to_cl[ncl].push_back(nc);
951
9596
          cl_to_lits[nc].push_back(ncl);
952
        }
953
      }
954
    }
955
    else
956
    {
957
26905
      lit_to_cl[nc].push_back(nc);
958
26905
      cl_to_lits[nc].push_back(nc);
959
    }
960
  }
961
  // get the maximum shared literal to factor
962
12653
  unsigned max_size = 0;
963
25306
  Node flit;
964
48975
  for (const std::pair<const Node, std::vector<Node> >& ltc : lit_to_cl)
965
  {
966
36322
    if (ltc.second.size() > max_size)
967
    {
968
12685
      max_size = ltc.second.size();
969
12685
      flit = ltc.first;
970
    }
971
  }
972
12653
  if (max_size > 1)
973
  {
974
    // do the factoring
975
150
    std::vector<Node> children;
976
150
    std::vector<Node> fchildren;
977
75
    std::map<Node, std::vector<Node> >::iterator itl = lit_to_cl.find(flit);
978
75
    std::vector<Node>& cls = itl->second;
979
439
    for (const Node& nc : n)
980
    {
981
364
      if (std::find(cls.begin(), cls.end(), nc) == cls.end())
982
      {
983
147
        children.push_back(nc);
984
      }
985
      else
986
      {
987
        // rebuild
988
217
        std::vector<Node>& lits = cl_to_lits[nc];
989
        std::vector<Node>::iterator itlfl =
990
217
            std::find(lits.begin(), lits.end(), flit);
991
217
        Assert(itlfl != lits.end());
992
217
        lits.erase(itlfl);
993
        // rebuild
994
217
        if (!lits.empty())
995
        {
996
434
          Node new_cl = lits.size() == 1 ? lits[0] : nm->mkNode(onk, lits);
997
217
          fchildren.push_back(new_cl);
998
        }
999
      }
1000
    }
1001
    // rebuild the factored children
1002
75
    Assert(!fchildren.empty());
1003
150
    Node fcn = fchildren.size() == 1 ? fchildren[0] : nm->mkNode(nk, fchildren);
1004
75
    children.push_back(nm->mkNode(onk, flit, fcn));
1005
150
    Node ret = children.size() == 1 ? children[0] : nm->mkNode(nk, children);
1006
75
    Trace("ext-rew-factoring") << "Factoring: *** OUTPUT: " << ret << std::endl;
1007
75
    return ret;
1008
  }
1009
  else
1010
  {
1011
12578
    Trace("ext-rew-factoring") << "Factoring: no change" << std::endl;
1012
  }
1013
12578
  return Node::null();
1014
}
1015
1016
12578
Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
1017
                                            Kind ork,
1018
                                            Kind eqk,
1019
                                            Kind notk,
1020
                                            std::map<Kind, bool>& bcp_kinds,
1021
                                            Node n,
1022
                                            bool isXor)
1023
{
1024
12578
  Assert(n.getKind() == andk || n.getKind() == ork);
1025
12578
  Trace("ext-rew-eqres") << "Eq res: **** INPUT: " << n << std::endl;
1026
1027
12578
  NodeManager* nm = NodeManager::currentNM();
1028
12578
  Kind nk = n.getKind();
1029
12578
  bool gpol = (nk == andk);
1030
42668
  for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
1031
  {
1032
60496
    Node lit = n[i];
1033
30406
    if (lit.getKind() == eqk)
1034
    {
1035
      // eq is the equality we are basing a substitution on
1036
9444
      Node eq;
1037
4880
      if (gpol == isXor)
1038
      {
1039
        // can only turn disequality into equality if types are the same
1040
2399
        if (lit[1].getType() == lit.getType())
1041
        {
1042
          // t != s ---> ~t = s
1043
43
          if (lit[1].getKind() == notk && lit[0].getKind() != notk)
1044
          {
1045
29
            eq = nm->mkNode(EQUAL, lit[0], TermUtil::mkNegate(notk, lit[1]));
1046
          }
1047
          else
1048
          {
1049
14
            eq = nm->mkNode(EQUAL, TermUtil::mkNegate(notk, lit[0]), lit[1]);
1050
          }
1051
        }
1052
      }
1053
      else
1054
      {
1055
2481
        eq = eqk == EQUAL ? lit : nm->mkNode(EQUAL, lit[0], lit[1]);
1056
      }
1057
4880
      if (!eq.isNull())
1058
      {
1059
        // see if it corresponds to a substitution
1060
4732
        std::vector<Node> vars;
1061
4732
        std::vector<Node> subs;
1062
2524
        if (inferSubstitution(eq, vars, subs))
1063
        {
1064
2095
          Assert(vars.size() == 1);
1065
3874
          std::vector<Node> children;
1066
2095
          bool childrenChanged = false;
1067
          // apply to all other children
1068
11402
          for (unsigned j = 0; j < nchild; j++)
1069
          {
1070
18614
            Node ccs = n[j];
1071
9307
            if (i != j)
1072
            {
1073
              // Substitution is only applicable to compatible kinds. We always
1074
              // use the partialSubstitute method to avoid substitution into
1075
              // witness terms.
1076
7212
              ccs = partialSubstitute(ccs, vars, subs, bcp_kinds);
1077
7212
              childrenChanged = childrenChanged || n[j] != ccs;
1078
            }
1079
9307
            children.push_back(ccs);
1080
          }
1081
2095
          if (childrenChanged)
1082
          {
1083
316
            return nm->mkNode(nk, children);
1084
          }
1085
        }
1086
      }
1087
    }
1088
  }
1089
1090
12262
  return Node::null();
1091
}
1092
1093
/** sort pairs by their second (unsigned) argument */
1094
28222
static bool sortPairSecond(const std::pair<Node, unsigned>& a,
1095
                           const std::pair<Node, unsigned>& b)
1096
{
1097
28222
  return (a.second < b.second);
1098
}
1099
1100
/** A simple subsumption trie used to compute pairwise list subsets */
1101
85352
class SimpSubsumeTrie
1102
{
1103
 public:
1104
  /** the children of this node */
1105
  std::map<Node, SimpSubsumeTrie> d_children;
1106
  /** the term at this node */
1107
  Node d_data;
1108
  /** add term to the trie
1109
   *
1110
   * This adds term c to this trie, whose atom list is alist. This adds terms
1111
   * s to subsumes such that the atom list of s is a subset of the atom list
1112
   * of c. For example, say:
1113
   *   c1.alist = { A }
1114
   *   c2.alist = { C }
1115
   *   c3.alist = { B, C }
1116
   *   c4.alist = { A, B, D }
1117
   *   c5.alist = { A, B, C }
1118
   * If these terms are added in the order c1, c2, c3, c4, c5, then:
1119
   *   addTerm c1 results in subsumes = {}
1120
   *   addTerm c2 results in subsumes = {}
1121
   *   addTerm c3 results in subsumes = { c2 }
1122
   *   addTerm c4 results in subsumes = { c1 }
1123
   *   addTerm c5 results in subsumes = { c1, c2, c3 }
1124
   * Notice that the intended use case of this trie is to add term t before t'
1125
   * only when size( t.alist ) <= size( t'.alist ).
1126
   *
1127
   * The last two arguments describe the state of the path [t0...tn] we
1128
   * have followed in the trie during the recursive call.
1129
   * If doAdd = true,
1130
   *   then n+1 = index and alist[1]...alist[n] = t1...tn. If index=alist.size()
1131
   *   we add c as the current node of this trie.
1132
   * If doAdd = false,
1133
   *   then t1...tn occur in alist.
1134
   */
1135
56784
  void addTerm(Node c,
1136
               std::vector<Node>& alist,
1137
               std::vector<Node>& subsumes,
1138
               unsigned index = 0,
1139
               bool doAdd = true)
1140
  {
1141
56784
    if (!d_data.isNull())
1142
    {
1143
3
      subsumes.push_back(d_data);
1144
    }
1145
56784
    if (doAdd)
1146
    {
1147
56778
      if (index == alist.size())
1148
      {
1149
28167
        d_data = c;
1150
28167
        return;
1151
      }
1152
    }
1153
    // try all children where we have this atom
1154
42818
    for (std::pair<const Node, SimpSubsumeTrie>& cp : d_children)
1155
    {
1156
14201
      if (std::find(alist.begin(), alist.end(), cp.first) != alist.end())
1157
      {
1158
6
        cp.second.addTerm(c, alist, subsumes, 0, false);
1159
      }
1160
    }
1161
28617
    if (doAdd)
1162
    {
1163
28611
      d_children[alist[index]].addTerm(c, alist, subsumes, index + 1, doAdd);
1164
    }
1165
  }
1166
};
1167
1168
14167
Node ExtendedRewriter::extendedRewriteEqChain(
1169
    Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor)
1170
{
1171
14167
  Assert(ret.getKind() == eqk);
1172
1173
  // this rewrite is aggressive; it in fact has the precondition that other
1174
  // aggressive rewrites (including BCP) have been applied.
1175
14167
  if (!d_aggr)
1176
  {
1177
96
    return Node::null();
1178
  }
1179
1180
14071
  NodeManager* nm = NodeManager::currentNM();
1181
1182
28142
  TypeNode tn = ret[0].getType();
1183
1184
  // sort/cancelling for Boolean EQUAL/XOR-chains
1185
14071
  Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;
1186
1187
  // get the children on either side
1188
14071
  bool gpol = true;
1189
28142
  std::vector<Node> children;
1190
42213
  for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++)
1191
  {
1192
56284
    Node curr = ret[r];
1193
    // assume, if necessary, right associative
1194
28400
    while (curr.getKind() == eqk && curr[0].getType() == tn)
1195
    {
1196
129
      children.push_back(curr[0]);
1197
129
      curr = curr[1];
1198
    }
1199
28142
    children.push_back(curr);
1200
  }
1201
1202
28142
  std::map<Node, bool> cstatus;
1203
  // add children to status
1204
42342
  for (const Node& c : children)
1205
  {
1206
56542
    Node a = c;
1207
28271
    if (a.getKind() == notk)
1208
    {
1209
262
      gpol = !gpol;
1210
262
      a = a[0];
1211
    }
1212
28271
    Trace("ext-rew-eqchain") << "...child : " << a << std::endl;
1213
28271
    std::map<Node, bool>::iterator itc = cstatus.find(a);
1214
28271
    if (itc == cstatus.end())
1215
    {
1216
28225
      cstatus[a] = true;
1217
    }
1218
    else
1219
    {
1220
      // cancels
1221
46
      cstatus.erase(a);
1222
46
      if (isXor)
1223
      {
1224
        gpol = !gpol;
1225
      }
1226
    }
1227
  }
1228
14071
  Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;
1229
1230
14071
  if (cstatus.empty())
1231
  {
1232
2
    return TermUtil::mkTypeConst(tn, gpol);
1233
  }
1234
1235
14069
  children.clear();
1236
1237
  // compute the atoms of each child
1238
14069
  Trace("ext-rew-eqchain") << "eqchain-simplify: begin\n";
1239
14069
  Trace("ext-rew-eqchain") << "  eqchain-simplify: get atoms...\n";
1240
28138
  std::map<Node, std::map<Node, bool> > atoms;
1241
28138
  std::map<Node, std::vector<Node> > alist;
1242
28138
  std::vector<std::pair<Node, unsigned> > atom_count;
1243
42248
  for (std::pair<const Node, bool>& cp : cstatus)
1244
  {
1245
28179
    if (!cp.second)
1246
    {
1247
      // already eliminated
1248
      continue;
1249
    }
1250
56358
    Node c = cp.first;
1251
28179
    Kind ck = c.getKind();
1252
28179
    Trace("ext-rew-eqchain") << "  process c = " << c << std::endl;
1253
28179
    if (ck == andk || ck == ork)
1254
    {
1255
642
      for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++)
1256
      {
1257
1086
        Node cl = c[j];
1258
546
        bool pol = cl.getKind() != notk;
1259
1086
        Node ca = pol ? cl : cl[0];
1260
546
        bool newVal = (ck == andk ? !pol : pol);
1261
1092
        Trace("ext-rew-eqchain")
1262
546
            << "  atoms(" << c << ", " << ca << ") = " << newVal << std::endl;
1263
546
        Assert(atoms[c].find(ca) == atoms[c].end());
1264
        // polarity is flipped when we are AND
1265
546
        atoms[c][ca] = newVal;
1266
546
        alist[c].push_back(ca);
1267
1268
        // if this already exists as a child of the equality chain, eliminate.
1269
        // this catches cases like ( x & y ) = ( ( x & y ) | z ), where we
1270
        // consider ( x & y ) a unit, whereas below it is expanded to
1271
        // ~( ~x | ~y ).
1272
546
        std::map<Node, bool>::iterator itc = cstatus.find(ca);
1273
546
        if (itc != cstatus.end() && itc->second)
1274
        {
1275
          // cancel it
1276
6
          cstatus[ca] = false;
1277
6
          cstatus[c] = false;
1278
          // make new child
1279
          // x = ( y | ~x ) ---> y & x
1280
          // x = ( y | x ) ---> ~y | x
1281
          // x = ( y & x ) ---> y | ~x
1282
          // x = ( y & ~x ) ---> ~y & ~x
1283
12
          std::vector<Node> new_children;
1284
18
          for (unsigned k = 0, nchildc = c.getNumChildren(); k < nchildc; k++)
1285
          {
1286
12
            if (j != k)
1287
            {
1288
6
              new_children.push_back(c[k]);
1289
            }
1290
          }
1291
12
          Node nc[2];
1292
6
          nc[0] = c[j];
1293
6
          nc[1] = new_children.size() == 1 ? new_children[0]
1294
                                           : nm->mkNode(ck, new_children);
1295
          // negate the proper child
1296
6
          unsigned nindex = (ck == andk) == pol ? 0 : 1;
1297
6
          nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]);
1298
6
          Kind nk = pol ? ork : andk;
1299
          // store as new child
1300
6
          children.push_back(nm->mkNode(nk, nc[0], nc[1]));
1301
6
          if (isXor)
1302
          {
1303
            gpol = !gpol;
1304
          }
1305
6
          break;
1306
        }
1307
102
      }
1308
    }
1309
    else
1310
    {
1311
28077
      bool pol = ck != notk;
1312
56154
      Node ca = pol ? c : c[0];
1313
28077
      atoms[c][ca] = pol;
1314
56154
      Trace("ext-rew-eqchain")
1315
28077
          << "  atoms(" << c << ", " << ca << ") = " << pol << std::endl;
1316
28077
      alist[c].push_back(ca);
1317
    }
1318
28179
    atom_count.push_back(std::pair<Node, unsigned>(c, alist[c].size()));
1319
  }
1320
  // sort the atoms in each atom list
1321
42248
  for (std::map<Node, std::vector<Node> >::iterator it = alist.begin();
1322
42248
       it != alist.end();
1323
       ++it)
1324
  {
1325
28179
    std::sort(it->second.begin(), it->second.end());
1326
  }
1327
  // check subsumptions
1328
  // sort by #atoms
1329
14069
  std::sort(atom_count.begin(), atom_count.end(), sortPairSecond);
1330
14069
  if (Trace.isOn("ext-rew-eqchain"))
1331
  {
1332
    for (const std::pair<Node, unsigned>& ac : atom_count)
1333
    {
1334
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << ac.first << " has "
1335
                               << ac.second << " atoms." << std::endl;
1336
    }
1337
    Trace("ext-rew-eqchain") << "  eqchain-simplify: compute subsumptions...\n";
1338
  }
1339
28138
  SimpSubsumeTrie sst;
1340
42248
  for (std::pair<const Node, bool>& cp : cstatus)
1341
  {
1342
28179
    if (!cp.second)
1343
    {
1344
      // already eliminated
1345
12
      continue;
1346
    }
1347
56334
    Node c = cp.first;
1348
28167
    std::map<Node, std::map<Node, bool> >::iterator itc = atoms.find(c);
1349
28167
    Assert(itc != atoms.end());
1350
56334
    Trace("ext-rew-eqchain") << "  - add term " << c << " with atom list "
1351
28167
                             << alist[c] << "...\n";
1352
56334
    std::vector<Node> subsumes;
1353
28167
    sst.addTerm(c, alist[c], subsumes);
1354
28167
    for (const Node& cc : subsumes)
1355
    {
1356
1
      if (!cstatus[cc])
1357
      {
1358
        // subsumes a child that was already eliminated
1359
        continue;
1360
      }
1361
2
      Trace("ext-rew-eqchain") << "  eqchain-simplify: " << c << " subsumes "
1362
1
                               << cc << std::endl;
1363
      // for each of the atoms in cc
1364
1
      std::map<Node, std::map<Node, bool> >::iterator itcc = atoms.find(cc);
1365
1
      Assert(itcc != atoms.end());
1366
1
      std::vector<Node> common_children;
1367
1
      std::vector<Node> diff_children;
1368
3
      for (const std::pair<const Node, bool>& ap : itcc->second)
1369
      {
1370
        // compare the polarity
1371
4
        Node a = ap.first;
1372
2
        bool polcc = ap.second;
1373
2
        Assert(itc->second.find(a) != itc->second.end());
1374
2
        bool polc = itc->second[a];
1375
4
        Trace("ext-rew-eqchain") << "    eqchain-simplify: atom " << a
1376
2
                                 << " has polarities : " << polc << " " << polcc
1377
2
                                 << "\n";
1378
4
        Node lit = polc ? a : TermUtil::mkNegate(notk, a);
1379
2
        if (polc != polcc)
1380
        {
1381
2
          diff_children.push_back(lit);
1382
        }
1383
        else
1384
        {
1385
          common_children.push_back(lit);
1386
        }
1387
      }
1388
1
      std::vector<Node> rem_children;
1389
3
      for (const std::pair<const Node, bool>& ap : itc->second)
1390
      {
1391
4
        Node a = ap.first;
1392
2
        if (atoms[cc].find(a) == atoms[cc].end())
1393
        {
1394
          bool polc = ap.second;
1395
          rem_children.push_back(polc ? a : TermUtil::mkNegate(notk, a));
1396
        }
1397
      }
1398
2
      Trace("ext-rew-eqchain")
1399
2
          << "    #common/diff/rem: " << common_children.size() << "/"
1400
1
          << diff_children.size() << "/" << rem_children.size() << "\n";
1401
1
      bool do_rewrite = false;
1402
3
      if (common_children.empty() && itc->second.size() == itcc->second.size()
1403
2
          && itcc->second.size() == 2)
1404
      {
1405
        // x | y = ~x | ~y ---> ~( x = y )
1406
1
        do_rewrite = true;
1407
1
        children.push_back(diff_children[0]);
1408
1
        children.push_back(diff_children[1]);
1409
1
        gpol = !gpol;
1410
1
        Trace("ext-rew-eqchain") << "    apply 2-child all-diff\n";
1411
      }
1412
      else if (common_children.empty() && diff_children.size() == 1)
1413
      {
1414
        do_rewrite = true;
1415
        // x = ( ~x | y ) ---> ~( ~x | ~y )
1416
        Node remn = rem_children.size() == 1 ? rem_children[0]
1417
                                             : nm->mkNode(ork, rem_children);
1418
        remn = TermUtil::mkNegate(notk, remn);
1419
        children.push_back(nm->mkNode(ork, diff_children[0], remn));
1420
        if (!isXor)
1421
        {
1422
          gpol = !gpol;
1423
        }
1424
        Trace("ext-rew-eqchain") << "    apply unit resolution\n";
1425
      }
1426
      else if (diff_children.size() == 1
1427
               && itc->second.size() == itcc->second.size())
1428
      {
1429
        // ( x | y | z ) = ( x | ~y | z ) ---> ( x | z )
1430
        do_rewrite = true;
1431
        Assert(!common_children.empty());
1432
        Node comn = common_children.size() == 1
1433
                        ? common_children[0]
1434
                        : nm->mkNode(ork, common_children);
1435
        children.push_back(comn);
1436
        if (isXor)
1437
        {
1438
          gpol = !gpol;
1439
        }
1440
        Trace("ext-rew-eqchain") << "    apply resolution\n";
1441
      }
1442
      else if (diff_children.empty())
1443
      {
1444
        do_rewrite = true;
1445
        if (rem_children.empty())
1446
        {
1447
          // x | y = x | y ---> true
1448
          // this can happen if we have ( ~x & ~y ) = ( x | y )
1449
          children.push_back(TermUtil::mkTypeMaxValue(tn));
1450
          if (isXor)
1451
          {
1452
            gpol = !gpol;
1453
          }
1454
          Trace("ext-rew-eqchain") << "    apply cancel\n";
1455
        }
1456
        else
1457
        {
1458
          // x | y = ( x | y | z ) ---> ( x | y | ~z )
1459
          Node remn = rem_children.size() == 1 ? rem_children[0]
1460
                                               : nm->mkNode(ork, rem_children);
1461
          remn = TermUtil::mkNegate(notk, remn);
1462
          Node comn = common_children.size() == 1
1463
                          ? common_children[0]
1464
                          : nm->mkNode(ork, common_children);
1465
          children.push_back(nm->mkNode(ork, comn, remn));
1466
          if (isXor)
1467
          {
1468
            gpol = !gpol;
1469
          }
1470
          Trace("ext-rew-eqchain") << "    apply subsume\n";
1471
        }
1472
      }
1473
1
      if (do_rewrite)
1474
      {
1475
        // eliminate the children, reverse polarity as needed
1476
3
        for (unsigned r = 0; r < 2; r++)
1477
        {
1478
4
          Node c_rem = r == 0 ? c : cc;
1479
2
          cstatus[c_rem] = false;
1480
2
          if (c_rem.getKind() == andk)
1481
          {
1482
            gpol = !gpol;
1483
          }
1484
        }
1485
1
        break;
1486
      }
1487
    }
1488
  }
1489
14069
  Trace("ext-rew-eqchain") << "eqchain-simplify: finish" << std::endl;
1490
1491
  // sorted right associative chain
1492
14069
  bool has_nvar = false;
1493
14069
  unsigned nvar_index = 0;
1494
42248
  for (std::pair<const Node, bool>& cp : cstatus)
1495
  {
1496
28179
    if (cp.second)
1497
    {
1498
28165
      if (!cp.first.isVar())
1499
      {
1500
19330
        has_nvar = true;
1501
19330
        nvar_index = children.size();
1502
      }
1503
28165
      children.push_back(cp.first);
1504
    }
1505
  }
1506
14069
  std::sort(children.begin(), children.end());
1507
1508
28138
  Node new_ret;
1509
14069
  if (!gpol)
1510
  {
1511
    // negate the constant child if it exists
1512
203
    unsigned nindex = has_nvar ? nvar_index : 0;
1513
203
    children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
1514
  }
1515
14069
  new_ret = children.back();
1516
14069
  unsigned index = children.size() - 1;
1517
42277
  while (index > 0)
1518
  {
1519
14104
    index--;
1520
14104
    new_ret = nm->mkNode(eqk, children[index], new_ret);
1521
  }
1522
14069
  new_ret = Rewriter::rewrite(new_ret);
1523
14069
  if (new_ret != ret)
1524
  {
1525
210
    return new_ret;
1526
  }
1527
13859
  return Node::null();
1528
}
1529
1530
104126
Node ExtendedRewriter::partialSubstitute(Node n,
1531
                                         const std::map<Node, Node>& assign,
1532
                                         const std::map<Kind, bool>& rkinds)
1533
{
1534
208252
  std::unordered_map<TNode, Node> visited;
1535
104126
  std::unordered_map<TNode, Node>::iterator it;
1536
104126
  std::map<Node, Node>::const_iterator ita;
1537
208252
  std::vector<TNode> visit;
1538
208252
  TNode cur;
1539
104126
  visit.push_back(n);
1540
646065
  do
1541
  {
1542
750191
    cur = visit.back();
1543
750191
    visit.pop_back();
1544
750191
    it = visited.find(cur);
1545
1546
750191
    if (it == visited.end())
1547
    {
1548
353583
      ita = assign.find(cur);
1549
353583
      if (ita != assign.end())
1550
      {
1551
3200
        visited[cur] = ita->second;
1552
      }
1553
      else
1554
      {
1555
        // If rkinds is non-empty, then can only recurse on its kinds.
1556
        // We also always disallow substitution into witness. Notice that
1557
        // we disallow witness here, due to unsoundness when applying contextual
1558
        // substitutions over witness terms (see #4620).
1559
350383
        Kind k = cur.getKind();
1560
350383
        if (k != WITNESS && (rkinds.empty() || rkinds.find(k) != rkinds.end()))
1561
        {
1562
350383
          visited[cur] = Node::null();
1563
350383
          visit.push_back(cur);
1564
646065
          for (const Node& cn : cur)
1565
          {
1566
295682
            visit.push_back(cn);
1567
          }
1568
        }
1569
        else
1570
        {
1571
          visited[cur] = cur;
1572
        }
1573
      }
1574
    }
1575
396608
    else if (it->second.isNull())
1576
    {
1577
700766
      Node ret = cur;
1578
350383
      bool childChanged = false;
1579
700766
      std::vector<Node> children;
1580
350383
      if (cur.getMetaKind() == metakind::PARAMETERIZED)
1581
      {
1582
4537
        children.push_back(cur.getOperator());
1583
      }
1584
646065
      for (const Node& cn : cur)
1585
      {
1586
295682
        it = visited.find(cn);
1587
295682
        Assert(it != visited.end());
1588
295682
        Assert(!it->second.isNull());
1589
295682
        childChanged = childChanged || cn != it->second;
1590
295682
        children.push_back(it->second);
1591
      }
1592
350383
      if (childChanged)
1593
      {
1594
4804
        ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
1595
      }
1596
350383
      visited[cur] = ret;
1597
    }
1598
750191
  } while (!visit.empty());
1599
104126
  Assert(visited.find(n) != visited.end());
1600
104126
  Assert(!visited.find(n)->second.isNull());
1601
208252
  return visited[n];
1602
}
1603
1604
26291
Node ExtendedRewriter::partialSubstitute(Node n,
1605
                                         const std::vector<Node>& vars,
1606
                                         const std::vector<Node>& subs,
1607
                                         const std::map<Kind, bool>& rkinds)
1608
{
1609
26291
  Assert(vars.size() == subs.size());
1610
52582
  std::map<Node, Node> assign;
1611
52927
  for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
1612
  {
1613
26636
    assign[vars[i]] = subs[i];
1614
  }
1615
52582
  return partialSubstitute(n, assign, rkinds);
1616
}
1617
1618
8877
Node ExtendedRewriter::solveEquality(Node n)
1619
{
1620
  // TODO (#1706) : implement
1621
8877
  Assert(n.getKind() == EQUAL);
1622
1623
8877
  return Node::null();
1624
}
1625
1626
22249
bool ExtendedRewriter::inferSubstitution(Node n,
1627
                                         std::vector<Node>& vars,
1628
                                         std::vector<Node>& subs,
1629
                                         bool usePred)
1630
{
1631
22249
  if (n.getKind() == AND)
1632
  {
1633
301
    bool ret = false;
1634
947
    for (const Node& nc : n)
1635
    {
1636
646
      bool cret = inferSubstitution(nc, vars, subs, usePred);
1637
646
      ret = ret || cret;
1638
    }
1639
301
    return ret;
1640
  }
1641
21948
  if (n.getKind() == EQUAL)
1642
  {
1643
    // see if it can be put into form x = y
1644
9423
    Node slv_eq = solveEquality(n);
1645
8877
    if (!slv_eq.isNull())
1646
    {
1647
      n = slv_eq;
1648
    }
1649
26631
    Node v[2];
1650
20652
    for (unsigned i = 0; i < 2; i++)
1651
    {
1652
17699
      if (n[i].isConst())
1653
      {
1654
5924
        vars.push_back(n[1 - i]);
1655
5924
        subs.push_back(n[i]);
1656
5924
        return true;
1657
      }
1658
11775
      if (n[i].isVar())
1659
      {
1660
9102
        v[i] = n[i];
1661
      }
1662
2673
      else if (TermUtil::isNegate(n[i].getKind()) && n[i][0].isVar())
1663
      {
1664
22
        v[i] = n[i][0];
1665
      }
1666
    }
1667
4084
    for (unsigned i = 0; i < 2; i++)
1668
    {
1669
4669
      TNode r1 = v[i];
1670
4669
      Node r2 = v[1 - i];
1671
3538
      if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
1672
      {
1673
2407
        r2 = n[1 - i];
1674
2407
        if (v[i] != n[i])
1675
        {
1676
10
          Assert(TermUtil::isNegate(n[i].getKind()));
1677
10
          r2 = TermUtil::mkNegate(n[i].getKind(), r2);
1678
        }
1679
        // TODO (#1706) : union find
1680
2407
        if (std::find(vars.begin(), vars.end(), r1) == vars.end())
1681
        {
1682
2407
          vars.push_back(r1);
1683
2407
          subs.push_back(r2);
1684
2407
          return true;
1685
        }
1686
      }
1687
    }
1688
  }
1689
13617
  if (usePred)
1690
  {
1691
13188
    bool negated = n.getKind() == NOT;
1692
13188
    vars.push_back(negated ? n[0] : n);
1693
13188
    subs.push_back(negated ? d_false : d_true);
1694
13188
    return true;
1695
  }
1696
429
  return false;
1697
}
1698
1699
1287
Node ExtendedRewriter::extendedRewriteStrings(Node ret)
1700
{
1701
1287
  Node new_ret;
1702
2574
  Trace("q-ext-rewrite-debug")
1703
1287
      << "Extended rewrite strings : " << ret << std::endl;
1704
1705
1287
  if (ret.getKind() == EQUAL)
1706
  {
1707
177
    new_ret = strings::SequencesRewriter(nullptr).rewriteEqualityExt(ret);
1708
  }
1709
1710
1287
  return new_ret;
1711
}
1712
1713
52378
void ExtendedRewriter::debugExtendedRewrite(Node n,
1714
                                            Node ret,
1715
                                            const char* c) const
1716
{
1717
52378
  if (Trace.isOn("q-ext-rewrite"))
1718
  {
1719
    if (!ret.isNull())
1720
    {
1721
      Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
1722
      Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
1723
                             << " rewrites to " << ret << std::endl;
1724
    }
1725
  }
1726
52378
}
1727
1728
}  // namespace quantifiers
1729
}  // namespace theory
1730
28194
}  // namespace cvc5