GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/datatypes/datatypes_rewriter.cpp Lines: 480 512 93.8 %
Date: 2021-08-17 Branches: 1067 2307 46.3 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mudathir Mohamed, Mathias Preiner
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of rewriter for the theory of (co)inductive datatypes.
14
 */
15
16
#include "theory/datatypes/datatypes_rewriter.h"
17
18
#include "expr/ascription_type.h"
19
#include "expr/dtype.h"
20
#include "expr/dtype_cons.h"
21
#include "expr/node_algorithm.h"
22
#include "expr/skolem_manager.h"
23
#include "expr/sygus_datatype.h"
24
#include "expr/uninterpreted_constant.h"
25
#include "options/datatypes_options.h"
26
#include "theory/datatypes/sygus_datatype_utils.h"
27
#include "theory/datatypes/theory_datatypes_utils.h"
28
#include "theory/datatypes/tuple_project_op.h"
29
#include "util/rational.h"
30
31
using namespace cvc5;
32
using namespace cvc5::kind;
33
34
namespace cvc5 {
35
namespace theory {
36
namespace datatypes {
37
38
518075
RewriteResponse DatatypesRewriter::postRewrite(TNode in)
39
{
40
518075
  Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl;
41
518075
  Kind kind = in.getKind();
42
518075
  NodeManager* nm = NodeManager::currentNM();
43
518075
  if (kind == kind::APPLY_CONSTRUCTOR)
44
  {
45
142684
    return rewriteConstructor(in);
46
  }
47
375391
  else if (kind == kind::APPLY_SELECTOR_TOTAL || kind == kind::APPLY_SELECTOR)
48
  {
49
81418
    return rewriteSelector(in);
50
  }
51
293973
  else if (kind == kind::APPLY_TESTER)
52
  {
53
35091
    return rewriteTester(in);
54
  }
55
258882
  else if (kind == APPLY_UPDATER)
56
  {
57
53
    return rewriteUpdater(in);
58
  }
59
258829
  else if (kind == kind::DT_SIZE)
60
  {
61
13638
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
62
    {
63
14876
      std::vector<Node> children;
64
19893
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
65
      {
66
12455
        if (in[0][i].getType().isDatatype())
67
        {
68
12368
          children.push_back(nm->mkNode(kind::DT_SIZE, in[0][i]));
69
        }
70
      }
71
14876
      TNode constructor = in[0].getOperator();
72
7438
      size_t constructorIndex = utils::indexOf(constructor);
73
7438
      const DType& dt = utils::datatypeOf(constructor);
74
7438
      const DTypeConstructor& c = dt[constructorIndex];
75
7438
      unsigned weight = c.getWeight();
76
7438
      children.push_back(nm->mkConst(Rational(weight)));
77
      Node res =
78
14876
          children.size() == 1 ? children[0] : nm->mkNode(kind::PLUS, children);
79
14876
      Trace("datatypes-rewrite")
80
7438
          << "DatatypesRewriter::postRewrite: rewrite size " << in << " to "
81
7438
          << res << std::endl;
82
7438
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
83
    }
84
  }
85
245191
  else if (kind == kind::DT_HEIGHT_BOUND)
86
  {
87
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
88
    {
89
      std::vector<Node> children;
90
      Node res;
91
      Rational r = in[1].getConst<Rational>();
92
      Rational rmo = Rational(r - Rational(1));
93
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
94
      {
95
        if (in[0][i].getType().isDatatype())
96
        {
97
          if (r.isZero())
98
          {
99
            res = nm->mkConst(false);
100
            break;
101
          }
102
          children.push_back(
103
              nm->mkNode(kind::DT_HEIGHT_BOUND, in[0][i], nm->mkConst(rmo)));
104
        }
105
      }
106
      if (res.isNull())
107
      {
108
        res = children.size() == 0
109
                  ? nm->mkConst(true)
110
                  : (children.size() == 1 ? children[0]
111
                                          : nm->mkNode(kind::AND, children));
112
      }
113
      Trace("datatypes-rewrite")
114
          << "DatatypesRewriter::postRewrite: rewrite height " << in << " to "
115
          << res << std::endl;
116
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
117
    }
118
  }
119
245191
  else if (kind == kind::DT_SIZE_BOUND)
120
  {
121
    if (in[0].isConst())
122
    {
123
      Node res = nm->mkNode(kind::LEQ, nm->mkNode(kind::DT_SIZE, in[0]), in[1]);
124
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
125
    }
126
  }
127
245191
  else if (kind == DT_SYGUS_EVAL)
128
  {
129
    // sygus evaluation function
130
82947
    Node ev = in[0];
131
70787
    if (ev.getKind() == APPLY_CONSTRUCTOR)
132
    {
133
58627
      Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
134
58627
      Trace("dt-sygus-util") << "Type is " << in.getType() << std::endl;
135
117254
      std::vector<Node> args;
136
524032
      for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++)
137
      {
138
465405
        args.push_back(in[j]);
139
      }
140
117254
      Node ret = utils::sygusToBuiltinEval(ev, args);
141
58627
      Trace("dt-sygus-util") << "...got " << ret << "\n";
142
58627
      Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl;
143
58627
      Assert(in.getType().isComparableTo(ret.getType()));
144
58627
      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
145
    }
146
  }
147
174404
  else if (kind == MATCH)
148
  {
149
12
    Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl;
150
24
    Node h = in[0];
151
24
    std::vector<Node> cases;
152
24
    std::vector<Node> rets;
153
24
    TypeNode t = h.getType();
154
12
    const DType& dt = t.getDType();
155
40
    for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++)
156
    {
157
56
      Node c = in[k];
158
56
      Node cons;
159
28
      Kind ck = c.getKind();
160
28
      if (ck == MATCH_CASE)
161
      {
162
20
        Assert(c[0].getKind() == APPLY_CONSTRUCTOR);
163
20
        cons = c[0].getOperator();
164
      }
165
8
      else if (ck == MATCH_BIND_CASE)
166
      {
167
8
        if (c[1].getKind() == APPLY_CONSTRUCTOR)
168
        {
169
4
          cons = c[1].getOperator();
170
        }
171
      }
172
      else
173
      {
174
        AlwaysAssert(false);
175
      }
176
28
      size_t cindex = 0;
177
      // cons is null in the default case
178
28
      if (!cons.isNull())
179
      {
180
24
        cindex = utils::indexOf(cons);
181
      }
182
56
      Node body;
183
28
      if (ck == MATCH_CASE)
184
      {
185
20
        body = c[1];
186
      }
187
8
      else if (ck == MATCH_BIND_CASE)
188
      {
189
16
        std::vector<Node> vars;
190
16
        std::vector<Node> subs;
191
8
        if (cons.isNull())
192
        {
193
4
          Assert(c[1].getKind() == BOUND_VARIABLE);
194
4
          vars.push_back(c[1]);
195
4
          subs.push_back(h);
196
        }
197
        else
198
        {
199
12
          for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++)
200
          {
201
8
            vars.push_back(c[0][i]);
202
            Node sc = nm->mkNode(
203
16
                APPLY_SELECTOR_TOTAL, dt[cindex].getSelectorInternal(t, i), h);
204
8
            subs.push_back(sc);
205
          }
206
        }
207
8
        body =
208
16
            c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
209
      }
210
28
      if (!cons.isNull())
211
      {
212
24
        cases.push_back(utils::mkTester(h, cindex, dt));
213
      }
214
      else
215
      {
216
        // variables have no constraints
217
4
        cases.push_back(nm->mkConst(true));
218
      }
219
28
      rets.push_back(body);
220
    }
221
12
    Assert(!cases.empty());
222
    // now make the ITE
223
12
    std::reverse(cases.begin(), cases.end());
224
12
    std::reverse(rets.begin(), rets.end());
225
24
    Node ret = rets[0];
226
12
    AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors());
227
28
    for (unsigned i = 1, ncases = cases.size(); i < ncases; i++)
228
    {
229
16
      ret = nm->mkNode(ITE, cases[i], rets[i], ret);
230
    }
231
24
    Trace("dt-rewrite-match")
232
12
        << "Rewrite match: " << in << " ... " << ret << std::endl;
233
12
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
234
  }
235
174392
  else if (kind == TUPLE_PROJECT)
236
  {
237
    // returns a tuple that represents
238
    // (mkTuple ((_ tupSel i_1) t) ... ((_ tupSel i_n) t))
239
    // where each i_j is less than the length of t
240
241
6
    Trace("dt-rewrite-project") << "Rewrite project: " << in << std::endl;
242
12
    TupleProjectOp op = in.getOperator().getConst<TupleProjectOp>();
243
12
    std::vector<uint32_t> indices = op.getIndices();
244
12
    Node tuple = in[0];
245
12
    std::vector<TypeNode> tupleTypes = tuple.getType().getTupleTypes();
246
12
    std::vector<TypeNode> types;
247
12
    std::vector<Node> elements;
248
12
    for (uint32_t index : indices)
249
    {
250
12
      TypeNode type = tupleTypes[index];
251
6
      types.push_back(type);
252
    }
253
12
    TypeNode projectType = nm->mkTupleType(types);
254
6
    const DType& dt = projectType.getDType();
255
6
    elements.push_back(dt[0].getConstructor());
256
6
    const DType& tupleDType = tuple.getType().getDType();
257
6
    const DTypeConstructor& constructor = tupleDType[0];
258
12
    for (uint32_t index : indices)
259
    {
260
12
      Node selector = constructor[index].getSelector();
261
12
      Node element = nm->mkNode(kind::APPLY_SELECTOR, selector, tuple);
262
6
      elements.push_back(element);
263
    }
264
12
    Node ret = nm->mkNode(kind::APPLY_CONSTRUCTOR, elements);
265
266
12
    Trace("dt-rewrite-project")
267
6
        << "Rewrite project: " << in << " ... " << ret << std::endl;
268
6
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
269
  }
270
271
192746
  if (kind == kind::EQUAL)
272
  {
273
173380
    if (in[0] == in[1])
274
    {
275
4316
      return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
276
    }
277
303931
    std::vector<Node> rew;
278
169064
    if (utils::checkClash(in[0], in[1], rew))
279
    {
280
3668
      Trace("datatypes-rewrite")
281
1834
          << "Rewrite clashing equality " << in << " to false" << std::endl;
282
1834
      return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
283
      //}else if( rew.size()==1 && rew[0]!=in ){
284
      //  Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " <<
285
      //  rew[0] << std::endl;
286
      //  return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] );
287
    }
288
167230
    else if (in[1] < in[0])
289
    {
290
64726
      Node ins = nm->mkNode(in.getKind(), in[1], in[0]);
291
64726
      Trace("datatypes-rewrite")
292
32363
          << "Swap equality " << in << " to " << ins << std::endl;
293
32363
      return RewriteResponse(REWRITE_DONE, ins);
294
    }
295
269734
    Trace("datatypes-rewrite-debug")
296
269734
        << "Did not rewrite equality " << in << " " << in[0].getKind() << " "
297
134867
        << in[1].getKind() << std::endl;
298
  }
299
300
154233
  return RewriteResponse(REWRITE_DONE, in);
301
}
302
303
300729
RewriteResponse DatatypesRewriter::preRewrite(TNode in)
304
{
305
300729
  Trace("datatypes-rewrite-debug") << "pre-rewriting " << in << std::endl;
306
  // must prewrite to apply type ascriptions since rewriting does not preserve
307
  // types
308
300729
  if (in.getKind() == kind::APPLY_CONSTRUCTOR)
309
  {
310
138334
    TypeNode tn = in.getType();
311
312
    // check for parametric datatype constructors
313
    // to ensure a normal form, all parameteric datatype constructors must have
314
    // a type ascription
315
69216
    if (tn.isParametricDatatype())
316
    {
317
308
      if (in.getOperator().getKind() != kind::APPLY_TYPE_ASCRIPTION)
318
      {
319
196
        Trace("datatypes-rewrite-debug")
320
98
            << "Ascribing type to parametric datatype constructor " << in
321
98
            << std::endl;
322
196
        Node op = in.getOperator();
323
        // get the constructor object
324
98
        const DTypeConstructor& dtc = utils::datatypeOf(op)[utils::indexOf(op)];
325
        // create ascribed constructor type
326
        Node tc = NodeManager::currentNM()->mkConst(
327
196
            AscriptionType(dtc.getSpecializedConstructorType(tn)));
328
        Node op_new = NodeManager::currentNM()->mkNode(
329
196
            kind::APPLY_TYPE_ASCRIPTION, tc, op);
330
        // make new node
331
196
        std::vector<Node> children;
332
98
        children.push_back(op_new);
333
98
        children.insert(children.end(), in.begin(), in.end());
334
        Node inr =
335
196
            NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, children);
336
98
        Trace("datatypes-rewrite-debug") << "Created " << inr << std::endl;
337
98
        return RewriteResponse(REWRITE_DONE, inr);
338
      }
339
    }
340
  }
341
300631
  return RewriteResponse(REWRITE_DONE, in);
342
}
343
344
142684
RewriteResponse DatatypesRewriter::rewriteConstructor(TNode in)
345
{
346
142684
  if (in.isConst())
347
  {
348
123778
    Trace("datatypes-rewrite-debug") << "Normalizing constant " << in
349
61889
                                     << std::endl;
350
123778
    Node inn = normalizeConstant(in);
351
    // constant may be a subterm of another constant, so cannot assume that this
352
    // will succeed for codatatypes
353
    // Assert( !inn.isNull() );
354
61889
    if (!inn.isNull() && inn != in)
355
    {
356
16
      Trace("datatypes-rewrite") << "Normalized constant " << in << " -> "
357
8
                                 << inn << std::endl;
358
8
      return RewriteResponse(REWRITE_DONE, inn);
359
    }
360
61881
    return RewriteResponse(REWRITE_DONE, in);
361
  }
362
80795
  return RewriteResponse(REWRITE_DONE, in);
363
}
364
365
81418
RewriteResponse DatatypesRewriter::rewriteSelector(TNode in)
366
{
367
81418
  Kind k = in.getKind();
368
81418
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
369
  {
370
    // Have to be careful not to rewrite well-typed expressions
371
    // where the selector doesn't match the constructor,
372
    // e.g. "pred(zero)".
373
23523
    TypeNode tn = in.getType();
374
23523
    TypeNode argType = in[0].getType();
375
23523
    Node selector = in.getOperator();
376
23523
    TNode constructor = in[0].getOperator();
377
21641
    size_t constructorIndex = utils::indexOf(constructor);
378
21641
    const DType& dt = utils::datatypeOf(selector);
379
21641
    const DTypeConstructor& c = dt[constructorIndex];
380
43282
    Trace("datatypes-rewrite-debug") << "Rewriting collapsable selector : "
381
21641
                                     << in;
382
43282
    Trace("datatypes-rewrite-debug") << ", cindex = " << constructorIndex
383
21641
                                     << ", selector is " << selector
384
21641
                                     << std::endl;
385
    // The argument that the selector extracts, or -1 if the selector is
386
    // is wrongly applied.
387
21641
    int selectorIndex = -1;
388
21641
    if (k == kind::APPLY_SELECTOR_TOTAL)
389
    {
390
      // The argument index of internal selectors is obtained by
391
      // getSelectorIndexInternal.
392
14679
      selectorIndex = c.getSelectorIndexInternal(selector);
393
    }
394
    else
395
    {
396
      // The argument index of external selectors (applications of
397
      // APPLY_SELECTOR) is given by an attribute and obtained via indexOf below
398
      // The argument is only valid if it is the proper constructor.
399
6962
      selectorIndex = utils::indexOf(selector);
400
13924
      if (selectorIndex < 0
401
6962
          || selectorIndex >= static_cast<int>(c.getNumArgs()))
402
      {
403
864
        selectorIndex = -1;
404
      }
405
6098
      else if (c[selectorIndex].getSelector() != selector)
406
      {
407
1018
        selectorIndex = -1;
408
      }
409
    }
410
43282
    Trace("datatypes-rewrite-debug") << "Internal selector index is "
411
21641
                                     << selectorIndex << std::endl;
412
21641
    if (selectorIndex >= 0)
413
    {
414
19094
      Assert(selectorIndex < (int)c.getNumArgs());
415
19094
      if (dt.isCodatatype() && in[0][selectorIndex].isConst())
416
      {
417
        // must replace all debruijn indices with self
418
14
        Node sub = replaceDebruijn(in[0][selectorIndex], in[0], argType, 0);
419
28
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
420
14
                                   << "Rewrite trivial codatatype selector "
421
14
                                   << in << " to " << sub << std::endl;
422
14
        if (sub != in)
423
        {
424
14
          return RewriteResponse(REWRITE_AGAIN_FULL, sub);
425
        }
426
      }
427
      else
428
      {
429
38160
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
430
19080
                                   << "Rewrite trivial selector " << in
431
19080
                                   << std::endl;
432
19080
        return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]);
433
      }
434
    }
435
2547
    else if (k == kind::APPLY_SELECTOR_TOTAL)
436
    {
437
      // evaluates to the first ground value of type tn.
438
1330
      Node gt = tn.mkGroundValue();
439
665
      Assert(!gt.isNull());
440
665
      if (tn.isDatatype() && !tn.isInstantiatedDatatype())
441
      {
442
        gt = NodeManager::currentNM()->mkNode(
443
            kind::APPLY_TYPE_ASCRIPTION,
444
            NodeManager::currentNM()->mkConst(AscriptionType(tn)),
445
            gt);
446
      }
447
1330
      Trace("datatypes-rewrite")
448
665
          << "DatatypesRewriter::postRewrite: "
449
665
          << "Rewrite trivial selector " << in
450
665
          << " to distinguished ground term " << gt << std::endl;
451
665
      return RewriteResponse(REWRITE_DONE, gt);
452
    }
453
  }
454
61659
  return RewriteResponse(REWRITE_DONE, in);
455
}
456
457
35091
RewriteResponse DatatypesRewriter::rewriteTester(TNode in)
458
{
459
35091
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
460
  {
461
    bool result =
462
1815
        utils::indexOf(in.getOperator()) == utils::indexOf(in[0].getOperator());
463
3630
    Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
464
1815
                               << "Rewrite trivial tester " << in << " "
465
1815
                               << result << std::endl;
466
    return RewriteResponse(REWRITE_DONE,
467
1815
                           NodeManager::currentNM()->mkConst(result));
468
  }
469
33276
  const DType& dt = in[0].getType().getDType();
470
33276
  if (dt.getNumConstructors() == 1 && !dt.isSygus())
471
  {
472
    // only one constructor, so it must be
473
2744
    Trace("datatypes-rewrite")
474
1372
        << "DatatypesRewriter::postRewrite: "
475
2744
        << "only one ctor for " << dt.getName() << " and that is "
476
1372
        << dt[0].getName() << std::endl;
477
    return RewriteResponse(REWRITE_DONE,
478
1372
                           NodeManager::currentNM()->mkConst(true));
479
  }
480
  // could try dt.getNumConstructors()==2 && indexOf(in.getOperator())==1 ?
481
31904
  return RewriteResponse(REWRITE_DONE, in);
482
}
483
484
53
RewriteResponse DatatypesRewriter::rewriteUpdater(TNode in)
485
{
486
53
  Assert (in.getKind()==APPLY_UPDATER);
487
53
  if (in[0].getKind() == APPLY_CONSTRUCTOR)
488
  {
489
14
    Node op = in.getOperator();
490
7
    size_t cindex = utils::indexOf(in[0].getOperator());
491
7
    size_t cuindex = utils::cindexOf(op);
492
7
    if (cindex==cuindex)
493
    {
494
7
      NodeManager * nm = NodeManager::currentNM();
495
7
      size_t updateIndex = utils::indexOf(op);
496
14
      std::vector<Node> children(in[0].begin(), in[0].end());
497
7
      children[updateIndex] = in[1];
498
7
      children.insert(children.begin(),in[0].getOperator());
499
7
      return RewriteResponse(REWRITE_DONE, nm->mkNode(APPLY_CONSTRUCTOR, children));
500
    }
501
    return RewriteResponse(REWRITE_DONE, in[0]);
502
  }
503
46
  return RewriteResponse(REWRITE_DONE, in);
504
}
505
506
977
Node DatatypesRewriter::normalizeCodatatypeConstant(Node n)
507
{
508
977
  Trace("dt-nconst") << "Normalize " << n << std::endl;
509
1954
  std::map<Node, Node> rf;
510
1954
  std::vector<Node> sk;
511
1954
  std::vector<Node> rf_pending;
512
1954
  std::vector<Node> terms;
513
1954
  std::map<Node, bool> cdts;
514
1954
  Node s = collectRef(n, sk, rf, rf_pending, terms, cdts);
515
977
  if (!s.isNull())
516
  {
517
596
    Trace("dt-nconst") << "...symbolic normalized is : " << s << std::endl;
518
826
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
519
    {
520
460
      Trace("dt-nconst") << "  " << it->first << " = " << it->second
521
230
                         << std::endl;
522
    }
523
    // now run DFA minimization on term structure
524
1192
    Trace("dt-nconst") << "  " << terms.size()
525
596
                       << " total subterms :" << std::endl;
526
596
    int eqc_count = 0;
527
1192
    std::map<Node, int> eqc_op_map;
528
1192
    std::map<Node, int> eqc;
529
1192
    std::map<int, std::vector<Node> > eqc_nodes;
530
    // partition based on top symbol
531
2831
    for (unsigned j = 0, size = terms.size(); j < size; j++)
532
    {
533
4470
      Node t = terms[j];
534
2235
      Trace("dt-nconst") << "    " << t << ", cdt=" << cdts[t] << std::endl;
535
      int e;
536
2235
      if (cdts[t])
537
      {
538
1711
        Assert(t.getKind() == kind::APPLY_CONSTRUCTOR);
539
3422
        Node op = t.getOperator();
540
1711
        std::map<Node, int>::iterator it = eqc_op_map.find(op);
541
1711
        if (it == eqc_op_map.end())
542
        {
543
947
          e = eqc_count;
544
947
          eqc_op_map[op] = eqc_count;
545
947
          eqc_count++;
546
        }
547
        else
548
        {
549
764
          e = it->second;
550
        }
551
      }
552
      else
553
      {
554
524
        e = eqc_count;
555
524
        eqc_count++;
556
      }
557
2235
      eqc[t] = e;
558
2235
      eqc_nodes[e].push_back(t);
559
    }
560
    // partition until fixed point
561
596
    int eqc_curr = 0;
562
596
    bool success = true;
563
2896
    while (success)
564
    {
565
1150
      success = false;
566
1150
      int eqc_end = eqc_count;
567
6410
      while (eqc_curr < eqc_end)
568
      {
569
2630
        if (eqc_nodes[eqc_curr].size() > 1)
570
        {
571
          // look at all nodes in this equivalence class
572
661
          unsigned nchildren = eqc_nodes[eqc_curr][0].getNumChildren();
573
1322
          std::map<int, std::vector<Node> > prt;
574
1358
          for (unsigned j = 0; j < nchildren; j++)
575
          {
576
1267
            prt.clear();
577
            // partition based on children : for the first child that causes a
578
            // split, break
579
5064
            for (unsigned k = 0, size = eqc_nodes[eqc_curr].size(); k < size;
580
                 k++)
581
            {
582
7594
              Node t = eqc_nodes[eqc_curr][k];
583
3797
              Assert(t.getNumChildren() == nchildren);
584
7594
              Node tc = t[j];
585
              // refer to loops
586
3797
              std::map<Node, Node>::iterator itr = rf.find(tc);
587
3797
              if (itr != rf.end())
588
              {
589
95
                tc = itr->second;
590
              }
591
3797
              Assert(eqc.find(tc) != eqc.end());
592
3797
              prt[eqc[tc]].push_back(t);
593
            }
594
1267
            if (prt.size() > 1)
595
            {
596
570
              success = true;
597
570
              break;
598
            }
599
          }
600
          // move into new eqc(s)
601
1910
          for (const std::pair<const int, std::vector<Node> >& p : prt)
602
          {
603
1249
            int e = eqc_count;
604
3214
            for (unsigned j = 0, size = p.second.size(); j < size; j++)
605
            {
606
3930
              Node t = p.second[j];
607
1965
              eqc[t] = e;
608
1965
              eqc_nodes[e].push_back(t);
609
            }
610
1249
            eqc_count++;
611
          }
612
        }
613
2630
        eqc_nodes.erase(eqc_curr);
614
2630
        eqc_curr++;
615
      }
616
    }
617
    // add in already occurring loop variables
618
826
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
619
    {
620
460
      Trace("dt-nconst-debug") << "Mapping equivalence class of " << it->first
621
230
                               << " -> " << it->second << std::endl;
622
230
      Assert(eqc.find(it->second) != eqc.end());
623
230
      eqc[it->first] = eqc[it->second];
624
    }
625
    // we now have a partition of equivalent terms
626
596
    Trace("dt-nconst") << "Computed equivalence classes ids : " << std::endl;
627
3061
    for (std::map<Node, int>::iterator it = eqc.begin(); it != eqc.end(); ++it)
628
    {
629
4930
      Trace("dt-nconst") << "  " << it->first << " -> " << it->second
630
2465
                         << std::endl;
631
    }
632
    // traverse top-down based on equivalence class
633
1192
    std::map<int, int> eqc_stack;
634
596
    return normalizeCodatatypeConstantEqc(s, eqc_stack, eqc, 0);
635
  }
636
381
  Trace("dt-nconst") << "...invalid." << std::endl;
637
381
  return Node::null();
638
}
639
640
// normalize constant : apply to top-level codatatype constants
641
372030
Node DatatypesRewriter::normalizeConstant(Node n)
642
{
643
744060
  TypeNode tn = n.getType();
644
372030
  if (tn.isDatatype())
645
  {
646
334834
    if (tn.isCodatatype())
647
    {
648
212
      return normalizeCodatatypeConstant(n);
649
    }
650
    else
651
    {
652
669244
      std::vector<Node> children;
653
334622
      bool childrenChanged = false;
654
644715
      for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
655
      {
656
620186
        Node nc = normalizeConstant(n[i]);
657
310093
        children.push_back(nc);
658
310093
        childrenChanged = childrenChanged || nc != n[i];
659
      }
660
334622
      if (childrenChanged)
661
      {
662
        return NodeManager::currentNM()->mkNode(n.getKind(), children);
663
      }
664
    }
665
  }
666
371818
  return n;
667
}
668
669
4630
Node DatatypesRewriter::collectRef(Node n,
670
                                   std::vector<Node>& sk,
671
                                   std::map<Node, Node>& rf,
672
                                   std::vector<Node>& rf_pending,
673
                                   std::vector<Node>& terms,
674
                                   std::map<Node, bool>& cdts)
675
{
676
4630
  Assert(n.isConst());
677
9260
  TypeNode tn = n.getType();
678
9260
  Node ret = n;
679
4630
  bool isCdt = false;
680
4630
  if (tn.isDatatype())
681
  {
682
2942
    if (!tn.isCodatatype())
683
    {
684
      // nested datatype within codatatype : can be normalized independently
685
      // since all loops should be self-contained
686
48
      ret = normalizeConstant(n);
687
    }
688
    else
689
    {
690
2894
      isCdt = true;
691
2894
      if (n.getKind() == kind::APPLY_CONSTRUCTOR)
692
      {
693
2283
        sk.push_back(n);
694
2283
        rf_pending.push_back(Node::null());
695
4021
        std::vector<Node> children;
696
2283
        children.push_back(n.getOperator());
697
2283
        bool childChanged = false;
698
5391
        for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
699
        {
700
6761
          Node nc = collectRef(n[i], sk, rf, rf_pending, terms, cdts);
701
3653
          if (nc.isNull())
702
          {
703
545
            return Node::null();
704
          }
705
3108
          childChanged = nc != n[i] || childChanged;
706
3108
          children.push_back(nc);
707
        }
708
1738
        sk.pop_back();
709
1738
        if (childChanged)
710
        {
711
495
          ret = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR,
712
                                                 children);
713
495
          if (!rf_pending.back().isNull())
714
          {
715
230
            rf[rf_pending.back()] = ret;
716
          }
717
        }
718
        else
719
        {
720
1243
          Assert(rf_pending.back().isNull());
721
        }
722
1738
        rf_pending.pop_back();
723
      }
724
      else
725
      {
726
        // a loop
727
611
        const Integer& i = n.getConst<UninterpretedConstant>().getIndex();
728
611
        uint32_t index = i.toUnsignedInt();
729
611
        if (index >= sk.size())
730
        {
731
381
          return Node::null();
732
        }
733
230
        Assert(sk.size() == rf_pending.size());
734
460
        Node r = rf_pending[rf_pending.size() - 1 - index];
735
230
        if (r.isNull())
736
        {
737
460
          r = NodeManager::currentNM()->mkBoundVar(
738
460
              sk[rf_pending.size() - 1 - index].getType());
739
230
          rf_pending[rf_pending.size() - 1 - index] = r;
740
        }
741
230
        return r;
742
      }
743
    }
744
  }
745
6948
  Trace("dt-nconst-debug") << "Return term : " << ret << " from " << n
746
3474
                           << std::endl;
747
3474
  if (std::find(terms.begin(), terms.end(), ret) == terms.end())
748
  {
749
2640
    terms.push_back(ret);
750
2640
    Assert(ret.getType() == tn);
751
2640
    cdts[ret] = isCdt;
752
  }
753
3474
  return ret;
754
}
755
// eqc_stack stores depth
756
2865
Node DatatypesRewriter::normalizeCodatatypeConstantEqc(
757
    Node n, std::map<int, int>& eqc_stack, std::map<Node, int>& eqc, int depth)
758
{
759
5730
  Trace("dt-nconst-debug") << "normalizeCodatatypeConstantEqc: " << n
760
2865
                           << " depth=" << depth << std::endl;
761
2865
  if (eqc.find(n) != eqc.end())
762
  {
763
2815
    int e = eqc[n];
764
2815
    std::map<int, int>::iterator it = eqc_stack.find(e);
765
2815
    if (it != eqc_stack.end())
766
    {
767
230
      int debruijn = depth - it->second - 1;
768
      return NodeManager::currentNM()->mkConst(
769
230
          UninterpretedConstant(n.getType(), debruijn));
770
    }
771
4851
    std::vector<Node> children;
772
2585
    bool childChanged = false;
773
2585
    eqc_stack[e] = depth;
774
4854
    for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
775
    {
776
4538
      Node nc = normalizeCodatatypeConstantEqc(n[i], eqc_stack, eqc, depth + 1);
777
2269
      children.push_back(nc);
778
2269
      childChanged = childChanged || nc != n[i];
779
    }
780
2585
    eqc_stack.erase(e);
781
2585
    if (childChanged)
782
    {
783
319
      Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
784
319
      children.insert(children.begin(), n.getOperator());
785
319
      return NodeManager::currentNM()->mkNode(n.getKind(), children);
786
    }
787
  }
788
2316
  return n;
789
}
790
791
62
/**
 * Replace the back-references in n that point at depth `depth` with orig.
 *
 * An UNINTERPRETED_CONSTANT of type orig_tn whose index equals depth is
 * replaced by orig. Other leaves are returned unchanged. Each level of
 * recursion into children increments depth by one. An inner node is rebuilt
 * only when at least one of its children changed.
 *
 * @param n the term to traverse.
 * @param orig the term substituted for matching back-references.
 * @param orig_tn the type a back-reference must have to be replaced.
 * @param depth the index a back-reference must have at the current level.
 * @return n with matching back-references replaced.
 */
Node DatatypesRewriter::replaceDebruijn(Node n,
                                        Node orig,
                                        TypeNode orig_tn,
                                        unsigned depth)
{
  if (n.getKind() == kind::UNINTERPRETED_CONSTANT && n.getType() == orig_tn)
  {
    unsigned index =
        n.getConst<UninterpretedConstant>().getIndex().toUnsignedInt();
    return index == depth ? orig : n;
  }
  if (n.getNumChildren() == 0)
  {
    return n;
  }
  std::vector<Node> children;
  bool childChanged = false;
  for (const Node& child : n)
  {
    Node nc = replaceDebruijn(child, orig, orig_tn, depth + 1);
    childChanged = childChanged || nc != child;
    children.push_back(nc);
  }
  if (!childChanged)
  {
    return n;
  }
  // Only parameterized kinds (e.g. APPLY_CONSTRUCTOR) store their operator
  // as a leading child. hasOperator() also holds for non-parameterized
  // operator kinds, whose operator must not be passed to mkNode as an
  // argument, so test the metakind instead.
  if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
  {
    children.insert(children.begin(), n.getOperator());
  }
  return NodeManager::currentNM()->mkNode(n.getKind(), children);
}
826
827
84262
/**
 * Expand the definitions of partial datatype operators occurring at the top
 * of n.
 *
 * - APPLY_SELECTOR becomes APPLY_SELECTOR_TOTAL.
 *   - The selector used is the constructor's internal selector for the
 *     argument's type when options::dtSharedSelectors() holds. Otherwise the
 *     original selector is used.
 *   - Unless options::dtRewriteErrorSel() holds, the result is
 *     ite(tester(n[0]), total selector, f(n[0])). Here f is the
 *     SELECTOR_WRONG skolem function of the selector.
 * - APPLY_UPDATER becomes an application of the updater's constructor.
 *   - The argument at the update index is n[1]. Every other argument is the
 *     corresponding total selector applied to n[0].
 *   - If the datatype has more than one constructor, the result is
 *     ite(tester(n[0]), that application, n[0]).
 *
 * @return a TrustNode rewrite from n to its expansion (no proof generator),
 *   or the null TrustNode if n is not expanded.
 */
TrustNode DatatypesRewriter::expandDefinition(Node n)
{
  NodeManager* nm = NodeManager::currentNM();
  TypeNode tn = n.getType();
  Kind k = n.getKind();
  Node expanded;
  if (k == kind::APPLY_SELECTOR)
  {
    Node selector = n.getOperator();
    // APPLY_SELECTOR always applies to an external selector, cindexOf is
    // legal here
    size_t consIndex = utils::cindexOf(selector);
    const DType& dt = utils::datatypeOf(selector);
    const DTypeConstructor& cons = dt[consIndex];
    Node arg = n[0];
    TypeNode argType = arg.getType();
    Node selectorUse = selector;
    if (options::dtSharedSelectors())
    {
      size_t selectorIndex = utils::indexOf(selector);
      Trace("dt-expand") << "...selector index = " << selectorIndex
                         << std::endl;
      Assert(selectorIndex < cons.getNumArgs());
      selectorUse = cons.getSelectorInternal(argType, selectorIndex);
    }
    Node sel = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, selectorUse, arg);
    if (options::dtRewriteErrorSel())
    {
      expanded = sel;
    }
    else
    {
      // guard the total selector by the constructor's tester; the wrong
      // constructor case is an application of a selector-specific skolem
      Node isCons = nm->mkNode(APPLY_TESTER, cons.getTester(), arg);
      SkolemManager* sm = nm->getSkolemManager();
      TypeNode wrongType = nm->mkFunctionType(argType, tn);
      Node wrongFun =
          sm->mkSkolemFunction(SkolemFunId::SELECTOR_WRONG, wrongType, selector);
      Node wrongApp = nm->mkNode(kind::APPLY_UF, wrongFun, arg);
      expanded = nm->mkNode(kind::ITE, isCons, sel, wrongApp);
      Trace("dt-expand") << "Expand def : " << n << " to " << expanded
                         << std::endl;
    }
  }
  else if (k == kind::APPLY_UPDATER)
  {
    Assert(tn.isDatatype());
    const DType& dt = tn.getDType();
    Node updater = n.getOperator();
    size_t updateIndex = utils::indexOf(updater);
    size_t consIndex = utils::cindexOf(updater);
    const DTypeConstructor& cons = dt[consIndex];
    NodeBuilder nb(APPLY_CONSTRUCTOR);
    nb << cons.getConstructor();
    Trace("dt-expand") << "Expand updater " << n << std::endl;
    Trace("dt-expand") << "expr is " << n << std::endl;
    Trace("dt-expand") << "updateIndex is " << updateIndex << std::endl;
    Trace("dt-expand") << "t is " << tn << std::endl;
    for (size_t j = 0, nargs = cons.getNumArgs(); j < nargs; ++j)
    {
      // the updated field takes the new value; the others are copied over
      nb << (j == updateIndex
                 ? n[1]
                 : nm->mkNode(APPLY_SELECTOR_TOTAL,
                              cons.getSelectorInternal(tn, j),
                              n[0]));
    }
    expanded = nb;
    if (dt.getNumConstructors() > 1)
    {
      // must be the right constructor to update
      Node isCons = nm->mkNode(APPLY_TESTER, cons.getTester(), n[0]);
      expanded = nm->mkNode(ITE, isCons, expanded, n[0]);
    }
    Trace("dt-expand") << "return " << expanded << std::endl;
  }
  if (expanded.isNull() || expanded == n)
  {
    return TrustNode::null();
  }
  return TrustNode::mkTrustRewrite(n, expanded, nullptr);
}
920
921
}  // namespace datatypes
922
}  // namespace theory
923
29337
}  // namespace cvc5