GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/datatypes/datatypes_rewriter.cpp Lines: 404 434 93.1 %
Date: 2021-03-23 Branches: 889 1928 46.1 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file datatypes_rewriter.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Mudathir Mohamed, Mathias Preiner
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Implementation of rewriter for the theory of (co)inductive datatypes.
13
 **
14
 ** Implementation of rewriter for the theory of (co)inductive datatypes.
15
 **/
16
17
#include "theory/datatypes/datatypes_rewriter.h"
18
19
#include "expr/dtype.h"
20
#include "expr/dtype_cons.h"
21
#include "expr/node_algorithm.h"
22
#include "expr/sygus_datatype.h"
23
#include "options/datatypes_options.h"
24
#include "theory/datatypes/sygus_datatype_utils.h"
25
#include "theory/datatypes/theory_datatypes_utils.h"
26
27
using namespace CVC4;
28
using namespace CVC4::kind;
29
30
namespace CVC4 {
31
namespace theory {
32
namespace datatypes {
33
34
508041
RewriteResponse DatatypesRewriter::postRewrite(TNode in)
35
{
36
508041
  Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl;
37
508041
  Kind kind = in.getKind();
38
508041
  NodeManager* nm = NodeManager::currentNM();
39
508041
  if (kind == kind::APPLY_CONSTRUCTOR)
40
  {
41
134225
    return rewriteConstructor(in);
42
  }
43
373816
  else if (kind == kind::APPLY_SELECTOR_TOTAL || kind == kind::APPLY_SELECTOR)
44
  {
45
75764
    return rewriteSelector(in);
46
  }
47
298052
  else if (kind == kind::APPLY_TESTER)
48
  {
49
39558
    return rewriteTester(in);
50
  }
51
258494
  else if (kind == kind::DT_SIZE)
52
  {
53
21819
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
54
    {
55
22726
      std::vector<Node> children;
56
30246
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
57
      {
58
18883
        if (in[0][i].getType().isDatatype())
59
        {
60
18718
          children.push_back(nm->mkNode(kind::DT_SIZE, in[0][i]));
61
        }
62
      }
63
22726
      TNode constructor = in[0].getOperator();
64
11363
      size_t constructorIndex = utils::indexOf(constructor);
65
11363
      const DType& dt = utils::datatypeOf(constructor);
66
11363
      const DTypeConstructor& c = dt[constructorIndex];
67
11363
      unsigned weight = c.getWeight();
68
11363
      children.push_back(nm->mkConst(Rational(weight)));
69
      Node res =
70
22726
          children.size() == 1 ? children[0] : nm->mkNode(kind::PLUS, children);
71
22726
      Trace("datatypes-rewrite")
72
11363
          << "DatatypesRewriter::postRewrite: rewrite size " << in << " to "
73
11363
          << res << std::endl;
74
11363
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
75
    }
76
  }
77
236675
  else if (kind == kind::DT_HEIGHT_BOUND)
78
  {
79
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
80
    {
81
      std::vector<Node> children;
82
      Node res;
83
      Rational r = in[1].getConst<Rational>();
84
      Rational rmo = Rational(r - Rational(1));
85
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
86
      {
87
        if (in[0][i].getType().isDatatype())
88
        {
89
          if (r.isZero())
90
          {
91
            res = nm->mkConst(false);
92
            break;
93
          }
94
          children.push_back(
95
              nm->mkNode(kind::DT_HEIGHT_BOUND, in[0][i], nm->mkConst(rmo)));
96
        }
97
      }
98
      if (res.isNull())
99
      {
100
        res = children.size() == 0
101
                  ? nm->mkConst(true)
102
                  : (children.size() == 1 ? children[0]
103
                                          : nm->mkNode(kind::AND, children));
104
      }
105
      Trace("datatypes-rewrite")
106
          << "DatatypesRewriter::postRewrite: rewrite height " << in << " to "
107
          << res << std::endl;
108
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
109
    }
110
  }
111
236675
  else if (kind == kind::DT_SIZE_BOUND)
112
  {
113
    if (in[0].isConst())
114
    {
115
      Node res = nm->mkNode(kind::LEQ, nm->mkNode(kind::DT_SIZE, in[0]), in[1]);
116
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
117
    }
118
  }
119
236675
  else if (kind == DT_SYGUS_EVAL)
120
  {
121
    // sygus evaluation function
122
121721
    Node ev = in[0];
123
99465
    if (ev.getKind() == APPLY_CONSTRUCTOR)
124
    {
125
77209
      Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
126
77209
      Trace("dt-sygus-util") << "Type is " << in.getType() << std::endl;
127
154418
      std::vector<Node> args;
128
746680
      for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++)
129
      {
130
669471
        args.push_back(in[j]);
131
      }
132
154418
      Node ret = utils::sygusToBuiltinEval(ev, args);
133
77209
      Trace("dt-sygus-util") << "...got " << ret << "\n";
134
77209
      Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl;
135
77209
      Assert(in.getType().isComparableTo(ret.getType()));
136
77209
      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
137
    }
138
  }
139
137210
  else if (kind == MATCH)
140
  {
141
12
    Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl;
142
24
    Node h = in[0];
143
24
    std::vector<Node> cases;
144
24
    std::vector<Node> rets;
145
24
    TypeNode t = h.getType();
146
12
    const DType& dt = t.getDType();
147
40
    for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++)
148
    {
149
56
      Node c = in[k];
150
56
      Node cons;
151
28
      Kind ck = c.getKind();
152
28
      if (ck == MATCH_CASE)
153
      {
154
20
        Assert(c[0].getKind() == APPLY_CONSTRUCTOR);
155
20
        cons = c[0].getOperator();
156
      }
157
8
      else if (ck == MATCH_BIND_CASE)
158
      {
159
8
        if (c[1].getKind() == APPLY_CONSTRUCTOR)
160
        {
161
4
          cons = c[1].getOperator();
162
        }
163
      }
164
      else
165
      {
166
        AlwaysAssert(false);
167
      }
168
28
      size_t cindex = 0;
169
      // cons is null in the default case
170
28
      if (!cons.isNull())
171
      {
172
24
        cindex = utils::indexOf(cons);
173
      }
174
56
      Node body;
175
28
      if (ck == MATCH_CASE)
176
      {
177
20
        body = c[1];
178
      }
179
8
      else if (ck == MATCH_BIND_CASE)
180
      {
181
16
        std::vector<Node> vars;
182
16
        std::vector<Node> subs;
183
8
        if (cons.isNull())
184
        {
185
4
          Assert(c[1].getKind() == BOUND_VARIABLE);
186
4
          vars.push_back(c[1]);
187
4
          subs.push_back(h);
188
        }
189
        else
190
        {
191
12
          for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++)
192
          {
193
8
            vars.push_back(c[0][i]);
194
            Node sc = nm->mkNode(
195
16
                APPLY_SELECTOR_TOTAL, dt[cindex].getSelectorInternal(t, i), h);
196
8
            subs.push_back(sc);
197
          }
198
        }
199
8
        body =
200
16
            c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
201
      }
202
28
      if (!cons.isNull())
203
      {
204
24
        cases.push_back(utils::mkTester(h, cindex, dt));
205
      }
206
      else
207
      {
208
        // variables have no constraints
209
4
        cases.push_back(nm->mkConst(true));
210
      }
211
28
      rets.push_back(body);
212
    }
213
12
    Assert(!cases.empty());
214
    // now make the ITE
215
12
    std::reverse(cases.begin(), cases.end());
216
12
    std::reverse(rets.begin(), rets.end());
217
24
    Node ret = rets[0];
218
12
    AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors());
219
28
    for (unsigned i = 1, ncases = cases.size(); i < ncases; i++)
220
    {
221
16
      ret = nm->mkNode(ITE, cases[i], rets[i], ret);
222
    }
223
24
    Trace("dt-rewrite-match")
224
12
        << "Rewrite match: " << in << " ... " << ret << std::endl;
225
12
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
226
  }
227
137198
  else if (kind == TUPLE_PROJECT)
228
  {
229
    // returns a tuple that represents
230
    // (mkTuple ((_ tupSel i_1) t) ... ((_ tupSel i_n) t))
231
    // where each i_j is less than the length of t
232
233
4
    Trace("dt-rewrite-project") << "Rewrite project: " << in << std::endl;
234
8
    TupleProjectOp op = in.getOperator().getConst<TupleProjectOp>();
235
8
    std::vector<uint32_t> indices = op.getIndices();
236
8
    Node tuple = in[0];
237
8
    std::vector<TypeNode> tupleTypes = tuple.getType().getTupleTypes();
238
8
    std::vector<TypeNode> types;
239
8
    std::vector<Node> elements;
240
8
    for (uint32_t index : indices)
241
    {
242
8
      TypeNode type = tupleTypes[index];
243
4
      types.push_back(type);
244
    }
245
8
    TypeNode projectType = nm->mkTupleType(types);
246
4
    const DType& dt = projectType.getDType();
247
4
    elements.push_back(dt[0].getConstructor());
248
4
    const DType& tupleDType = tuple.getType().getDType();
249
4
    const DTypeConstructor& constructor = tupleDType[0];
250
8
    for (uint32_t index : indices)
251
    {
252
8
      Node selector = constructor[index].getSelector();
253
8
      Node element = nm->mkNode(kind::APPLY_SELECTOR, selector, tuple);
254
4
      elements.push_back(element);
255
    }
256
8
    Node ret = nm->mkNode(kind::APPLY_CONSTRUCTOR, elements);
257
258
8
    Trace("dt-rewrite-project")
259
4
        << "Rewrite project: " << in << " ... " << ret << std::endl;
260
4
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
261
  }
262
263
169906
  if (kind == kind::EQUAL)
264
  {
265
135594
    if (in[0] == in[1])
266
    {
267
2420
      return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
268
    }
269
236033
    std::vector<Node> rew;
270
133174
    if (utils::checkClash(in[0], in[1], rew))
271
    {
272
3084
      Trace("datatypes-rewrite")
273
1542
          << "Rewrite clashing equality " << in << " to false" << std::endl;
274
1542
      return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
275
      //}else if( rew.size()==1 && rew[0]!=in ){
276
      //  Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " <<
277
      //  rew[0] << std::endl;
278
      //  return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] );
279
    }
280
131632
    else if (in[1] < in[0])
281
    {
282
57546
      Node ins = nm->mkNode(in.getKind(), in[1], in[0]);
283
57546
      Trace("datatypes-rewrite")
284
28773
          << "Swap equality " << in << " to " << ins << std::endl;
285
28773
      return RewriteResponse(REWRITE_DONE, ins);
286
    }
287
205718
    Trace("datatypes-rewrite-debug")
288
205718
        << "Did not rewrite equality " << in << " " << in[0].getKind() << " "
289
102859
        << in[1].getKind() << std::endl;
290
  }
291
292
137171
  return RewriteResponse(REWRITE_DONE, in);
293
}
294
295
304837
RewriteResponse DatatypesRewriter::preRewrite(TNode in)
296
{
297
304837
  Trace("datatypes-rewrite-debug") << "pre-rewriting " << in << std::endl;
298
  // must prewrite to apply type ascriptions since rewriting does not preserve
299
  // types
300
304837
  if (in.getKind() == kind::APPLY_CONSTRUCTOR)
301
  {
302
128765
    TypeNode tn = in.getType();
303
304
    // check for parametric datatype constructors
305
    // to ensure a normal form, all parameteric datatype constructors must have
306
    // a type ascription
307
64434
    if (tn.isParametricDatatype())
308
    {
309
317
      if (in.getOperator().getKind() != kind::APPLY_TYPE_ASCRIPTION)
310
      {
311
206
        Trace("datatypes-rewrite-debug")
312
103
            << "Ascribing type to parametric datatype constructor " << in
313
103
            << std::endl;
314
206
        Node op = in.getOperator();
315
        // get the constructor object
316
103
        const DTypeConstructor& dtc = utils::datatypeOf(op)[utils::indexOf(op)];
317
        // create ascribed constructor type
318
        Node tc = NodeManager::currentNM()->mkConst(
319
206
            AscriptionType(dtc.getSpecializedConstructorType(tn)));
320
        Node op_new = NodeManager::currentNM()->mkNode(
321
206
            kind::APPLY_TYPE_ASCRIPTION, tc, op);
322
        // make new node
323
206
        std::vector<Node> children;
324
103
        children.push_back(op_new);
325
103
        children.insert(children.end(), in.begin(), in.end());
326
        Node inr =
327
206
            NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, children);
328
103
        Trace("datatypes-rewrite-debug") << "Created " << inr << std::endl;
329
103
        return RewriteResponse(REWRITE_DONE, inr);
330
      }
331
    }
332
  }
333
304734
  return RewriteResponse(REWRITE_DONE, in);
334
}
335
336
134225
RewriteResponse DatatypesRewriter::rewriteConstructor(TNode in)
337
{
338
134225
  if (in.isConst())
339
  {
340
145388
    Trace("datatypes-rewrite-debug") << "Normalizing constant " << in
341
72694
                                     << std::endl;
342
145388
    Node inn = normalizeConstant(in);
343
    // constant may be a subterm of another constant, so cannot assume that this
344
    // will succeed for codatatypes
345
    // Assert( !inn.isNull() );
346
72694
    if (!inn.isNull() && inn != in)
347
    {
348
16
      Trace("datatypes-rewrite") << "Normalized constant " << in << " -> "
349
8
                                 << inn << std::endl;
350
8
      return RewriteResponse(REWRITE_DONE, inn);
351
    }
352
72686
    return RewriteResponse(REWRITE_DONE, in);
353
  }
354
61531
  return RewriteResponse(REWRITE_DONE, in);
355
}
356
357
75764
RewriteResponse DatatypesRewriter::rewriteSelector(TNode in)
358
{
359
75764
  Kind k = in.getKind();
360
75764
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
361
  {
362
    // Have to be careful not to rewrite well-typed expressions
363
    // where the selector doesn't match the constructor,
364
    // e.g. "pred(zero)".
365
25870
    TypeNode tn = in.getType();
366
25870
    TypeNode argType = in[0].getType();
367
25870
    Node selector = in.getOperator();
368
25870
    TNode constructor = in[0].getOperator();
369
23530
    size_t constructorIndex = utils::indexOf(constructor);
370
23530
    const DType& dt = utils::datatypeOf(selector);
371
23530
    const DTypeConstructor& c = dt[constructorIndex];
372
47060
    Trace("datatypes-rewrite-debug") << "Rewriting collapsable selector : "
373
23530
                                     << in;
374
47060
    Trace("datatypes-rewrite-debug") << ", cindex = " << constructorIndex
375
23530
                                     << ", selector is " << selector
376
23530
                                     << std::endl;
377
    // The argument that the selector extracts, or -1 if the selector is
378
    // is wrongly applied.
379
23530
    int selectorIndex = -1;
380
23530
    if (k == kind::APPLY_SELECTOR_TOTAL)
381
    {
382
      // The argument index of internal selectors is obtained by
383
      // getSelectorIndexInternal.
384
18310
      selectorIndex = c.getSelectorIndexInternal(selector);
385
    }
386
    else
387
    {
388
      // The argument index of external selectors (applications of
389
      // APPLY_SELECTOR) is given by an attribute and obtained via indexOf below
390
      // The argument is only valid if it is the proper constructor.
391
5220
      selectorIndex = utils::indexOf(selector);
392
10440
      if (selectorIndex < 0
393
5220
          || selectorIndex >= static_cast<int>(c.getNumArgs()))
394
      {
395
806
        selectorIndex = -1;
396
      }
397
4414
      else if (c[selectorIndex].getSelector() != selector)
398
      {
399
1534
        selectorIndex = -1;
400
      }
401
    }
402
47060
    Trace("datatypes-rewrite-debug") << "Internal selector index is "
403
23530
                                     << selectorIndex << std::endl;
404
23530
    if (selectorIndex >= 0)
405
    {
406
20462
      Assert(selectorIndex < (int)c.getNumArgs());
407
20462
      if (dt.isCodatatype() && in[0][selectorIndex].isConst())
408
      {
409
        // must replace all debruijn indices with self
410
14
        Node sub = replaceDebruijn(in[0][selectorIndex], in[0], argType, 0);
411
28
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
412
14
                                   << "Rewrite trivial codatatype selector "
413
14
                                   << in << " to " << sub << std::endl;
414
14
        if (sub != in)
415
        {
416
14
          return RewriteResponse(REWRITE_AGAIN_FULL, sub);
417
        }
418
      }
419
      else
420
      {
421
40896
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
422
20448
                                   << "Rewrite trivial selector " << in
423
20448
                                   << std::endl;
424
20448
        return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]);
425
      }
426
    }
427
3068
    else if (k == kind::APPLY_SELECTOR_TOTAL)
428
    {
429
      // evaluates to the first ground value of type tn.
430
1456
      Node gt = tn.mkGroundValue();
431
728
      Assert(!gt.isNull());
432
728
      if (tn.isDatatype() && !tn.isInstantiatedDatatype())
433
      {
434
        gt = NodeManager::currentNM()->mkNode(
435
            kind::APPLY_TYPE_ASCRIPTION,
436
            NodeManager::currentNM()->mkConst(AscriptionType(tn)),
437
            gt);
438
      }
439
1456
      Trace("datatypes-rewrite")
440
728
          << "DatatypesRewriter::postRewrite: "
441
728
          << "Rewrite trivial selector " << in
442
728
          << " to distinguished ground term " << gt << std::endl;
443
728
      return RewriteResponse(REWRITE_DONE, gt);
444
    }
445
  }
446
54574
  return RewriteResponse(REWRITE_DONE, in);
447
}
448
449
39558
RewriteResponse DatatypesRewriter::rewriteTester(TNode in)
450
{
451
39558
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
452
  {
453
    bool result =
454
1980
        utils::indexOf(in.getOperator()) == utils::indexOf(in[0].getOperator());
455
3960
    Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
456
1980
                               << "Rewrite trivial tester " << in << " "
457
1980
                               << result << std::endl;
458
    return RewriteResponse(REWRITE_DONE,
459
1980
                           NodeManager::currentNM()->mkConst(result));
460
  }
461
37578
  const DType& dt = in[0].getType().getDType();
462
37578
  if (dt.getNumConstructors() == 1 && !dt.isSygus())
463
  {
464
    // only one constructor, so it must be
465
1212
    Trace("datatypes-rewrite")
466
606
        << "DatatypesRewriter::postRewrite: "
467
1212
        << "only one ctor for " << dt.getName() << " and that is "
468
606
        << dt[0].getName() << std::endl;
469
    return RewriteResponse(REWRITE_DONE,
470
606
                           NodeManager::currentNM()->mkConst(true));
471
  }
472
  // could try dt.getNumConstructors()==2 && indexOf(in.getOperator())==1 ?
473
36972
  return RewriteResponse(REWRITE_DONE, in);
474
}
475
476
965
Node DatatypesRewriter::normalizeCodatatypeConstant(Node n)
477
{
478
965
  Trace("dt-nconst") << "Normalize " << n << std::endl;
479
1930
  std::map<Node, Node> rf;
480
1930
  std::vector<Node> sk;
481
1930
  std::vector<Node> rf_pending;
482
1930
  std::vector<Node> terms;
483
1930
  std::map<Node, bool> cdts;
484
1930
  Node s = collectRef(n, sk, rf, rf_pending, terms, cdts);
485
965
  if (!s.isNull())
486
  {
487
584
    Trace("dt-nconst") << "...symbolic normalized is : " << s << std::endl;
488
814
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
489
    {
490
460
      Trace("dt-nconst") << "  " << it->first << " = " << it->second
491
230
                         << std::endl;
492
    }
493
    // now run DFA minimization on term structure
494
1168
    Trace("dt-nconst") << "  " << terms.size()
495
584
                       << " total subterms :" << std::endl;
496
584
    int eqc_count = 0;
497
1168
    std::map<Node, int> eqc_op_map;
498
1168
    std::map<Node, int> eqc;
499
1168
    std::map<int, std::vector<Node> > eqc_nodes;
500
    // partition based on top symbol
501
2807
    for (unsigned j = 0, size = terms.size(); j < size; j++)
502
    {
503
4446
      Node t = terms[j];
504
2223
      Trace("dt-nconst") << "    " << t << ", cdt=" << cdts[t] << std::endl;
505
      int e;
506
2223
      if (cdts[t])
507
      {
508
1699
        Assert(t.getKind() == kind::APPLY_CONSTRUCTOR);
509
3398
        Node op = t.getOperator();
510
1699
        std::map<Node, int>::iterator it = eqc_op_map.find(op);
511
1699
        if (it == eqc_op_map.end())
512
        {
513
935
          e = eqc_count;
514
935
          eqc_op_map[op] = eqc_count;
515
935
          eqc_count++;
516
        }
517
        else
518
        {
519
764
          e = it->second;
520
        }
521
      }
522
      else
523
      {
524
524
        e = eqc_count;
525
524
        eqc_count++;
526
      }
527
2223
      eqc[t] = e;
528
2223
      eqc_nodes[e].push_back(t);
529
    }
530
    // partition until fixed point
531
584
    int eqc_curr = 0;
532
584
    bool success = true;
533
2860
    while (success)
534
    {
535
1138
      success = false;
536
1138
      int eqc_end = eqc_count;
537
6374
      while (eqc_curr < eqc_end)
538
      {
539
2618
        if (eqc_nodes[eqc_curr].size() > 1)
540
        {
541
          // look at all nodes in this equivalence class
542
661
          unsigned nchildren = eqc_nodes[eqc_curr][0].getNumChildren();
543
1322
          std::map<int, std::vector<Node> > prt;
544
1358
          for (unsigned j = 0; j < nchildren; j++)
545
          {
546
1267
            prt.clear();
547
            // partition based on children : for the first child that causes a
548
            // split, break
549
5064
            for (unsigned k = 0, size = eqc_nodes[eqc_curr].size(); k < size;
550
                 k++)
551
            {
552
7594
              Node t = eqc_nodes[eqc_curr][k];
553
3797
              Assert(t.getNumChildren() == nchildren);
554
7594
              Node tc = t[j];
555
              // refer to loops
556
3797
              std::map<Node, Node>::iterator itr = rf.find(tc);
557
3797
              if (itr != rf.end())
558
              {
559
95
                tc = itr->second;
560
              }
561
3797
              Assert(eqc.find(tc) != eqc.end());
562
3797
              prt[eqc[tc]].push_back(t);
563
            }
564
1267
            if (prt.size() > 1)
565
            {
566
570
              success = true;
567
570
              break;
568
            }
569
          }
570
          // move into new eqc(s)
571
1910
          for (const std::pair<const int, std::vector<Node> >& p : prt)
572
          {
573
1249
            int e = eqc_count;
574
3214
            for (unsigned j = 0, size = p.second.size(); j < size; j++)
575
            {
576
3930
              Node t = p.second[j];
577
1965
              eqc[t] = e;
578
1965
              eqc_nodes[e].push_back(t);
579
            }
580
1249
            eqc_count++;
581
          }
582
        }
583
2618
        eqc_nodes.erase(eqc_curr);
584
2618
        eqc_curr++;
585
      }
586
    }
587
    // add in already occurring loop variables
588
814
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
589
    {
590
460
      Trace("dt-nconst-debug") << "Mapping equivalence class of " << it->first
591
230
                               << " -> " << it->second << std::endl;
592
230
      Assert(eqc.find(it->second) != eqc.end());
593
230
      eqc[it->first] = eqc[it->second];
594
    }
595
    // we now have a partition of equivalent terms
596
584
    Trace("dt-nconst") << "Computed equivalence classes ids : " << std::endl;
597
3037
    for (std::map<Node, int>::iterator it = eqc.begin(); it != eqc.end(); ++it)
598
    {
599
4906
      Trace("dt-nconst") << "  " << it->first << " -> " << it->second
600
2453
                         << std::endl;
601
    }
602
    // traverse top-down based on equivalence class
603
1168
    std::map<int, int> eqc_stack;
604
584
    return normalizeCodatatypeConstantEqc(s, eqc_stack, eqc, 0);
605
  }
606
381
  Trace("dt-nconst") << "...invalid." << std::endl;
607
381
  return Node::null();
608
}
609
610
// normalize constant : apply to top-level codatatype constants
611
425510
Node DatatypesRewriter::normalizeConstant(Node n)
612
{
613
851020
  TypeNode tn = n.getType();
614
425510
  if (tn.isDatatype())
615
  {
616
387863
    if (tn.isCodatatype())
617
    {
618
200
      return normalizeCodatatypeConstant(n);
619
    }
620
    else
621
    {
622
775326
      std::vector<Node> children;
623
387663
      bool childrenChanged = false;
624
740431
      for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
625
      {
626
705536
        Node nc = normalizeConstant(n[i]);
627
352768
        children.push_back(nc);
628
352768
        childrenChanged = childrenChanged || nc != n[i];
629
      }
630
387663
      if (childrenChanged)
631
      {
632
        return NodeManager::currentNM()->mkNode(n.getKind(), children);
633
      }
634
    }
635
  }
636
425310
  return n;
637
}
638
639
4618
Node DatatypesRewriter::collectRef(Node n,
640
                                   std::vector<Node>& sk,
641
                                   std::map<Node, Node>& rf,
642
                                   std::vector<Node>& rf_pending,
643
                                   std::vector<Node>& terms,
644
                                   std::map<Node, bool>& cdts)
645
{
646
4618
  Assert(n.isConst());
647
9236
  TypeNode tn = n.getType();
648
9236
  Node ret = n;
649
4618
  bool isCdt = false;
650
4618
  if (tn.isDatatype())
651
  {
652
2930
    if (!tn.isCodatatype())
653
    {
654
      // nested datatype within codatatype : can be normalized independently
655
      // since all loops should be self-contained
656
48
      ret = normalizeConstant(n);
657
    }
658
    else
659
    {
660
2882
      isCdt = true;
661
2882
      if (n.getKind() == kind::APPLY_CONSTRUCTOR)
662
      {
663
2271
        sk.push_back(n);
664
2271
        rf_pending.push_back(Node::null());
665
3997
        std::vector<Node> children;
666
2271
        children.push_back(n.getOperator());
667
2271
        bool childChanged = false;
668
5379
        for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
669
        {
670
6761
          Node nc = collectRef(n[i], sk, rf, rf_pending, terms, cdts);
671
3653
          if (nc.isNull())
672
          {
673
545
            return Node::null();
674
          }
675
3108
          childChanged = nc != n[i] || childChanged;
676
3108
          children.push_back(nc);
677
        }
678
1726
        sk.pop_back();
679
1726
        if (childChanged)
680
        {
681
495
          ret = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR,
682
                                                 children);
683
495
          if (!rf_pending.back().isNull())
684
          {
685
230
            rf[rf_pending.back()] = ret;
686
          }
687
        }
688
        else
689
        {
690
1231
          Assert(rf_pending.back().isNull());
691
        }
692
1726
        rf_pending.pop_back();
693
      }
694
      else
695
      {
696
        // a loop
697
611
        const Integer& i = n.getConst<UninterpretedConstant>().getIndex();
698
611
        uint32_t index = i.toUnsignedInt();
699
611
        if (index >= sk.size())
700
        {
701
381
          return Node::null();
702
        }
703
230
        Assert(sk.size() == rf_pending.size());
704
460
        Node r = rf_pending[rf_pending.size() - 1 - index];
705
230
        if (r.isNull())
706
        {
707
460
          r = NodeManager::currentNM()->mkBoundVar(
708
460
              sk[rf_pending.size() - 1 - index].getType());
709
230
          rf_pending[rf_pending.size() - 1 - index] = r;
710
        }
711
230
        return r;
712
      }
713
    }
714
  }
715
6924
  Trace("dt-nconst-debug") << "Return term : " << ret << " from " << n
716
3462
                           << std::endl;
717
3462
  if (std::find(terms.begin(), terms.end(), ret) == terms.end())
718
  {
719
2628
    terms.push_back(ret);
720
2628
    Assert(ret.getType() == tn);
721
2628
    cdts[ret] = isCdt;
722
  }
723
3462
  return ret;
724
}
725
// eqc_stack stores depth
726
2853
Node DatatypesRewriter::normalizeCodatatypeConstantEqc(
727
    Node n, std::map<int, int>& eqc_stack, std::map<Node, int>& eqc, int depth)
728
{
729
5706
  Trace("dt-nconst-debug") << "normalizeCodatatypeConstantEqc: " << n
730
2853
                           << " depth=" << depth << std::endl;
731
2853
  if (eqc.find(n) != eqc.end())
732
  {
733
2803
    int e = eqc[n];
734
2803
    std::map<int, int>::iterator it = eqc_stack.find(e);
735
2803
    if (it != eqc_stack.end())
736
    {
737
230
      int debruijn = depth - it->second - 1;
738
      return NodeManager::currentNM()->mkConst(
739
230
          UninterpretedConstant(n.getType(), debruijn));
740
    }
741
4827
    std::vector<Node> children;
742
2573
    bool childChanged = false;
743
2573
    eqc_stack[e] = depth;
744
4842
    for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
745
    {
746
4538
      Node nc = normalizeCodatatypeConstantEqc(n[i], eqc_stack, eqc, depth + 1);
747
2269
      children.push_back(nc);
748
2269
      childChanged = childChanged || nc != n[i];
749
    }
750
2573
    eqc_stack.erase(e);
751
2573
    if (childChanged)
752
    {
753
319
      Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
754
319
      children.insert(children.begin(), n.getOperator());
755
319
      return NodeManager::currentNM()->mkNode(n.getKind(), children);
756
    }
757
  }
758
2304
  return n;
759
}
760
761
62
Node DatatypesRewriter::replaceDebruijn(Node n,
762
                                        Node orig,
763
                                        TypeNode orig_tn,
764
                                        unsigned depth)
765
{
766
62
  if (n.getKind() == kind::UNINTERPRETED_CONSTANT && n.getType() == orig_tn)
767
  {
768
    unsigned index =
769
1
        n.getConst<UninterpretedConstant>().getIndex().toUnsignedInt();
770
1
    if (index == depth)
771
    {
772
1
      return orig;
773
    }
774
  }
775
61
  else if (n.getNumChildren() > 0)
776
  {
777
48
    std::vector<Node> children;
778
24
    bool childChanged = false;
779
72
    for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
780
    {
781
96
      Node nc = replaceDebruijn(n[i], orig, orig_tn, depth + 1);
782
48
      children.push_back(nc);
783
48
      childChanged = childChanged || nc != n[i];
784
    }
785
24
    if (childChanged)
786
    {
787
      if (n.hasOperator())
788
      {
789
        children.insert(children.begin(), n.getOperator());
790
      }
791
      return NodeManager::currentNM()->mkNode(n.getKind(), children);
792
    }
793
  }
794
61
  return n;
795
}
796
797
} /* CVC4::theory::datatypes namespace */
798
} /* CVC4::theory namespace */
799
26685
} /* CVC4 namespace */