GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/datatypes/datatypes_rewriter.cpp Lines: 480 512 93.8 %
Date: 2021-05-24 Branches: 1070 2319 46.1 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mudathir Mohamed, Mathias Preiner
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of rewriter for the theory of (co)inductive datatypes.
14
 */
15
16
#include "theory/datatypes/datatypes_rewriter.h"
17
18
#include "expr/dtype.h"
19
#include "expr/dtype_cons.h"
20
#include "expr/node_algorithm.h"
21
#include "expr/skolem_manager.h"
22
#include "expr/sygus_datatype.h"
23
#include "options/datatypes_options.h"
24
#include "theory/datatypes/sygus_datatype_utils.h"
25
#include "theory/datatypes/theory_datatypes_utils.h"
26
27
using namespace cvc5;
28
using namespace cvc5::kind;
29
30
namespace cvc5 {
31
namespace theory {
32
namespace datatypes {
33
34
479513
RewriteResponse DatatypesRewriter::postRewrite(TNode in)
35
{
36
479513
  Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl;
37
479513
  Kind kind = in.getKind();
38
479513
  NodeManager* nm = NodeManager::currentNM();
39
479513
  if (kind == kind::APPLY_CONSTRUCTOR)
40
  {
41
125081
    return rewriteConstructor(in);
42
  }
43
354432
  else if (kind == kind::APPLY_SELECTOR_TOTAL || kind == kind::APPLY_SELECTOR)
44
  {
45
68639
    return rewriteSelector(in);
46
  }
47
285793
  else if (kind == kind::APPLY_TESTER)
48
  {
49
31742
    return rewriteTester(in);
50
  }
51
254051
  else if (kind == APPLY_UPDATER)
52
  {
53
53
    return rewriteUpdater(in);
54
  }
55
253998
  else if (kind == kind::DT_SIZE)
56
  {
57
12943
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
58
    {
59
13610
      std::vector<Node> children;
60
18317
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
61
      {
62
11512
        if (in[0][i].getType().isDatatype())
63
        {
64
11423
          children.push_back(nm->mkNode(kind::DT_SIZE, in[0][i]));
65
        }
66
      }
67
13610
      TNode constructor = in[0].getOperator();
68
6805
      size_t constructorIndex = utils::indexOf(constructor);
69
6805
      const DType& dt = utils::datatypeOf(constructor);
70
6805
      const DTypeConstructor& c = dt[constructorIndex];
71
6805
      unsigned weight = c.getWeight();
72
6805
      children.push_back(nm->mkConst(Rational(weight)));
73
      Node res =
74
13610
          children.size() == 1 ? children[0] : nm->mkNode(kind::PLUS, children);
75
13610
      Trace("datatypes-rewrite")
76
6805
          << "DatatypesRewriter::postRewrite: rewrite size " << in << " to "
77
6805
          << res << std::endl;
78
6805
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
79
    }
80
  }
81
241055
  else if (kind == kind::DT_HEIGHT_BOUND)
82
  {
83
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
84
    {
85
      std::vector<Node> children;
86
      Node res;
87
      Rational r = in[1].getConst<Rational>();
88
      Rational rmo = Rational(r - Rational(1));
89
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
90
      {
91
        if (in[0][i].getType().isDatatype())
92
        {
93
          if (r.isZero())
94
          {
95
            res = nm->mkConst(false);
96
            break;
97
          }
98
          children.push_back(
99
              nm->mkNode(kind::DT_HEIGHT_BOUND, in[0][i], nm->mkConst(rmo)));
100
        }
101
      }
102
      if (res.isNull())
103
      {
104
        res = children.size() == 0
105
                  ? nm->mkConst(true)
106
                  : (children.size() == 1 ? children[0]
107
                                          : nm->mkNode(kind::AND, children));
108
      }
109
      Trace("datatypes-rewrite")
110
          << "DatatypesRewriter::postRewrite: rewrite height " << in << " to "
111
          << res << std::endl;
112
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
113
    }
114
  }
115
241055
  else if (kind == kind::DT_SIZE_BOUND)
116
  {
117
    if (in[0].isConst())
118
    {
119
      Node res = nm->mkNode(kind::LEQ, nm->mkNode(kind::DT_SIZE, in[0]), in[1]);
120
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
121
    }
122
  }
123
241055
  else if (kind == DT_SYGUS_EVAL)
124
  {
125
    // sygus evaluation function
126
80745
    Node ev = in[0];
127
68122
    if (ev.getKind() == APPLY_CONSTRUCTOR)
128
    {
129
55499
      Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
130
55499
      Trace("dt-sygus-util") << "Type is " << in.getType() << std::endl;
131
110998
      std::vector<Node> args;
132
511900
      for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++)
133
      {
134
456401
        args.push_back(in[j]);
135
      }
136
110998
      Node ret = utils::sygusToBuiltinEval(ev, args);
137
55499
      Trace("dt-sygus-util") << "...got " << ret << "\n";
138
55499
      Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl;
139
55499
      Assert(in.getType().isComparableTo(ret.getType()));
140
55499
      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
141
    }
142
  }
143
172933
  else if (kind == MATCH)
144
  {
145
12
    Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl;
146
24
    Node h = in[0];
147
24
    std::vector<Node> cases;
148
24
    std::vector<Node> rets;
149
24
    TypeNode t = h.getType();
150
12
    const DType& dt = t.getDType();
151
40
    for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++)
152
    {
153
56
      Node c = in[k];
154
56
      Node cons;
155
28
      Kind ck = c.getKind();
156
28
      if (ck == MATCH_CASE)
157
      {
158
20
        Assert(c[0].getKind() == APPLY_CONSTRUCTOR);
159
20
        cons = c[0].getOperator();
160
      }
161
8
      else if (ck == MATCH_BIND_CASE)
162
      {
163
8
        if (c[1].getKind() == APPLY_CONSTRUCTOR)
164
        {
165
4
          cons = c[1].getOperator();
166
        }
167
      }
168
      else
169
      {
170
        AlwaysAssert(false);
171
      }
172
28
      size_t cindex = 0;
173
      // cons is null in the default case
174
28
      if (!cons.isNull())
175
      {
176
24
        cindex = utils::indexOf(cons);
177
      }
178
56
      Node body;
179
28
      if (ck == MATCH_CASE)
180
      {
181
20
        body = c[1];
182
      }
183
8
      else if (ck == MATCH_BIND_CASE)
184
      {
185
16
        std::vector<Node> vars;
186
16
        std::vector<Node> subs;
187
8
        if (cons.isNull())
188
        {
189
4
          Assert(c[1].getKind() == BOUND_VARIABLE);
190
4
          vars.push_back(c[1]);
191
4
          subs.push_back(h);
192
        }
193
        else
194
        {
195
12
          for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++)
196
          {
197
8
            vars.push_back(c[0][i]);
198
            Node sc = nm->mkNode(
199
16
                APPLY_SELECTOR_TOTAL, dt[cindex].getSelectorInternal(t, i), h);
200
8
            subs.push_back(sc);
201
          }
202
        }
203
8
        body =
204
16
            c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
205
      }
206
28
      if (!cons.isNull())
207
      {
208
24
        cases.push_back(utils::mkTester(h, cindex, dt));
209
      }
210
      else
211
      {
212
        // variables have no constraints
213
4
        cases.push_back(nm->mkConst(true));
214
      }
215
28
      rets.push_back(body);
216
    }
217
12
    Assert(!cases.empty());
218
    // now make the ITE
219
12
    std::reverse(cases.begin(), cases.end());
220
12
    std::reverse(rets.begin(), rets.end());
221
24
    Node ret = rets[0];
222
12
    AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors());
223
28
    for (unsigned i = 1, ncases = cases.size(); i < ncases; i++)
224
    {
225
16
      ret = nm->mkNode(ITE, cases[i], rets[i], ret);
226
    }
227
24
    Trace("dt-rewrite-match")
228
12
        << "Rewrite match: " << in << " ... " << ret << std::endl;
229
12
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
230
  }
231
172921
  else if (kind == TUPLE_PROJECT)
232
  {
233
    // returns a tuple that represents
234
    // (mkTuple ((_ tupSel i_1) t) ... ((_ tupSel i_n) t))
235
    // where each i_j is less than the length of t
236
237
6
    Trace("dt-rewrite-project") << "Rewrite project: " << in << std::endl;
238
12
    TupleProjectOp op = in.getOperator().getConst<TupleProjectOp>();
239
12
    std::vector<uint32_t> indices = op.getIndices();
240
12
    Node tuple = in[0];
241
12
    std::vector<TypeNode> tupleTypes = tuple.getType().getTupleTypes();
242
12
    std::vector<TypeNode> types;
243
12
    std::vector<Node> elements;
244
12
    for (uint32_t index : indices)
245
    {
246
12
      TypeNode type = tupleTypes[index];
247
6
      types.push_back(type);
248
    }
249
12
    TypeNode projectType = nm->mkTupleType(types);
250
6
    const DType& dt = projectType.getDType();
251
6
    elements.push_back(dt[0].getConstructor());
252
6
    const DType& tupleDType = tuple.getType().getDType();
253
6
    const DTypeConstructor& constructor = tupleDType[0];
254
12
    for (uint32_t index : indices)
255
    {
256
12
      Node selector = constructor[index].getSelector();
257
12
      Node element = nm->mkNode(kind::APPLY_SELECTOR, selector, tuple);
258
6
      elements.push_back(element);
259
    }
260
12
    Node ret = nm->mkNode(kind::APPLY_CONSTRUCTOR, elements);
261
262
12
    Trace("dt-rewrite-project")
263
6
        << "Rewrite project: " << in << " ... " << ret << std::endl;
264
6
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
265
  }
266
267
191676
  if (kind == kind::EQUAL)
268
  {
269
171949
    if (in[0] == in[1])
270
    {
271
4129
      return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
272
    }
273
300288
    std::vector<Node> rew;
274
167820
    if (utils::checkClash(in[0], in[1], rew))
275
    {
276
3332
      Trace("datatypes-rewrite")
277
1666
          << "Rewrite clashing equality " << in << " to false" << std::endl;
278
1666
      return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
279
      //}else if( rew.size()==1 && rew[0]!=in ){
280
      //  Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " <<
281
      //  rew[0] << std::endl;
282
      //  return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] );
283
    }
284
166154
    else if (in[1] < in[0])
285
    {
286
67372
      Node ins = nm->mkNode(in.getKind(), in[1], in[0]);
287
67372
      Trace("datatypes-rewrite")
288
33686
          << "Swap equality " << in << " to " << ins << std::endl;
289
33686
      return RewriteResponse(REWRITE_DONE, ins);
290
    }
291
264936
    Trace("datatypes-rewrite-debug")
292
264936
        << "Did not rewrite equality " << in << " " << in[0].getKind() << " "
293
132468
        << in[1].getKind() << std::endl;
294
  }
295
296
152195
  return RewriteResponse(REWRITE_DONE, in);
297
}
298
299
278307
RewriteResponse DatatypesRewriter::preRewrite(TNode in)
300
{
301
278307
  Trace("datatypes-rewrite-debug") << "pre-rewriting " << in << std::endl;
302
  // must prewrite to apply type ascriptions since rewriting does not preserve
303
  // types
304
278307
  if (in.getKind() == kind::APPLY_CONSTRUCTOR)
305
  {
306
120919
    TypeNode tn = in.getType();
307
308
    // check for parametric datatype constructors
309
    // to ensure a normal form, all parameteric datatype constructors must have
310
    // a type ascription
311
60512
    if (tn.isParametricDatatype())
312
    {
313
326
      if (in.getOperator().getKind() != kind::APPLY_TYPE_ASCRIPTION)
314
      {
315
210
        Trace("datatypes-rewrite-debug")
316
105
            << "Ascribing type to parametric datatype constructor " << in
317
105
            << std::endl;
318
210
        Node op = in.getOperator();
319
        // get the constructor object
320
105
        const DTypeConstructor& dtc = utils::datatypeOf(op)[utils::indexOf(op)];
321
        // create ascribed constructor type
322
        Node tc = NodeManager::currentNM()->mkConst(
323
210
            AscriptionType(dtc.getSpecializedConstructorType(tn)));
324
        Node op_new = NodeManager::currentNM()->mkNode(
325
210
            kind::APPLY_TYPE_ASCRIPTION, tc, op);
326
        // make new node
327
210
        std::vector<Node> children;
328
105
        children.push_back(op_new);
329
105
        children.insert(children.end(), in.begin(), in.end());
330
        Node inr =
331
210
            NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, children);
332
105
        Trace("datatypes-rewrite-debug") << "Created " << inr << std::endl;
333
105
        return RewriteResponse(REWRITE_DONE, inr);
334
      }
335
    }
336
  }
337
278202
  return RewriteResponse(REWRITE_DONE, in);
338
}
339
340
125081
RewriteResponse DatatypesRewriter::rewriteConstructor(TNode in)
341
{
342
125081
  if (in.isConst())
343
  {
344
116212
    Trace("datatypes-rewrite-debug") << "Normalizing constant " << in
345
58106
                                     << std::endl;
346
116212
    Node inn = normalizeConstant(in);
347
    // constant may be a subterm of another constant, so cannot assume that this
348
    // will succeed for codatatypes
349
    // Assert( !inn.isNull() );
350
58106
    if (!inn.isNull() && inn != in)
351
    {
352
16
      Trace("datatypes-rewrite") << "Normalized constant " << in << " -> "
353
8
                                 << inn << std::endl;
354
8
      return RewriteResponse(REWRITE_DONE, inn);
355
    }
356
58098
    return RewriteResponse(REWRITE_DONE, in);
357
  }
358
66975
  return RewriteResponse(REWRITE_DONE, in);
359
}
360
361
68639
RewriteResponse DatatypesRewriter::rewriteSelector(TNode in)
362
{
363
68639
  Kind k = in.getKind();
364
68639
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
365
  {
366
    // Have to be careful not to rewrite well-typed expressions
367
    // where the selector doesn't match the constructor,
368
    // e.g. "pred(zero)".
369
20790
    TypeNode tn = in.getType();
370
20790
    TypeNode argType = in[0].getType();
371
20790
    Node selector = in.getOperator();
372
20790
    TNode constructor = in[0].getOperator();
373
18840
    size_t constructorIndex = utils::indexOf(constructor);
374
18840
    const DType& dt = utils::datatypeOf(selector);
375
18840
    const DTypeConstructor& c = dt[constructorIndex];
376
37680
    Trace("datatypes-rewrite-debug") << "Rewriting collapsable selector : "
377
18840
                                     << in;
378
37680
    Trace("datatypes-rewrite-debug") << ", cindex = " << constructorIndex
379
18840
                                     << ", selector is " << selector
380
18840
                                     << std::endl;
381
    // The argument that the selector extracts, or -1 if the selector is
382
    // is wrongly applied.
383
18840
    int selectorIndex = -1;
384
18840
    if (k == kind::APPLY_SELECTOR_TOTAL)
385
    {
386
      // The argument index of internal selectors is obtained by
387
      // getSelectorIndexInternal.
388
13872
      selectorIndex = c.getSelectorIndexInternal(selector);
389
    }
390
    else
391
    {
392
      // The argument index of external selectors (applications of
393
      // APPLY_SELECTOR) is given by an attribute and obtained via indexOf below
394
      // The argument is only valid if it is the proper constructor.
395
4968
      selectorIndex = utils::indexOf(selector);
396
9936
      if (selectorIndex < 0
397
4968
          || selectorIndex >= static_cast<int>(c.getNumArgs()))
398
      {
399
940
        selectorIndex = -1;
400
      }
401
4028
      else if (c[selectorIndex].getSelector() != selector)
402
      {
403
1010
        selectorIndex = -1;
404
      }
405
    }
406
37680
    Trace("datatypes-rewrite-debug") << "Internal selector index is "
407
18840
                                     << selectorIndex << std::endl;
408
18840
    if (selectorIndex >= 0)
409
    {
410
16237
      Assert(selectorIndex < (int)c.getNumArgs());
411
16237
      if (dt.isCodatatype() && in[0][selectorIndex].isConst())
412
      {
413
        // must replace all debruijn indices with self
414
14
        Node sub = replaceDebruijn(in[0][selectorIndex], in[0], argType, 0);
415
28
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
416
14
                                   << "Rewrite trivial codatatype selector "
417
14
                                   << in << " to " << sub << std::endl;
418
14
        if (sub != in)
419
        {
420
14
          return RewriteResponse(REWRITE_AGAIN_FULL, sub);
421
        }
422
      }
423
      else
424
      {
425
32446
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
426
16223
                                   << "Rewrite trivial selector " << in
427
16223
                                   << std::endl;
428
16223
        return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]);
429
      }
430
    }
431
2603
    else if (k == kind::APPLY_SELECTOR_TOTAL)
432
    {
433
      // evaluates to the first ground value of type tn.
434
1306
      Node gt = tn.mkGroundValue();
435
653
      Assert(!gt.isNull());
436
653
      if (tn.isDatatype() && !tn.isInstantiatedDatatype())
437
      {
438
        gt = NodeManager::currentNM()->mkNode(
439
            kind::APPLY_TYPE_ASCRIPTION,
440
            NodeManager::currentNM()->mkConst(AscriptionType(tn)),
441
            gt);
442
      }
443
1306
      Trace("datatypes-rewrite")
444
653
          << "DatatypesRewriter::postRewrite: "
445
653
          << "Rewrite trivial selector " << in
446
653
          << " to distinguished ground term " << gt << std::endl;
447
653
      return RewriteResponse(REWRITE_DONE, gt);
448
    }
449
  }
450
51749
  return RewriteResponse(REWRITE_DONE, in);
451
}
452
453
31742
RewriteResponse DatatypesRewriter::rewriteTester(TNode in)
454
{
455
31742
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
456
  {
457
    bool result =
458
1920
        utils::indexOf(in.getOperator()) == utils::indexOf(in[0].getOperator());
459
3840
    Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
460
1920
                               << "Rewrite trivial tester " << in << " "
461
1920
                               << result << std::endl;
462
    return RewriteResponse(REWRITE_DONE,
463
1920
                           NodeManager::currentNM()->mkConst(result));
464
  }
465
29822
  const DType& dt = in[0].getType().getDType();
466
29822
  if (dt.getNumConstructors() == 1 && !dt.isSygus())
467
  {
468
    // only one constructor, so it must be
469
1484
    Trace("datatypes-rewrite")
470
742
        << "DatatypesRewriter::postRewrite: "
471
1484
        << "only one ctor for " << dt.getName() << " and that is "
472
742
        << dt[0].getName() << std::endl;
473
    return RewriteResponse(REWRITE_DONE,
474
742
                           NodeManager::currentNM()->mkConst(true));
475
  }
476
  // could try dt.getNumConstructors()==2 && indexOf(in.getOperator())==1 ?
477
29080
  return RewriteResponse(REWRITE_DONE, in);
478
}
479
480
53
RewriteResponse DatatypesRewriter::rewriteUpdater(TNode in)
481
{
482
53
  Assert (in.getKind()==APPLY_UPDATER);
483
53
  if (in[0].getKind() == APPLY_CONSTRUCTOR)
484
  {
485
14
    Node op = in.getOperator();
486
7
    size_t cindex = utils::indexOf(in[0].getOperator());
487
7
    size_t cuindex = utils::cindexOf(op);
488
7
    if (cindex==cuindex)
489
    {
490
7
      NodeManager * nm = NodeManager::currentNM();
491
7
      size_t updateIndex = utils::indexOf(op);
492
14
      std::vector<Node> children(in[0].begin(), in[0].end());
493
7
      children[updateIndex] = in[1];
494
7
      children.insert(children.begin(),in[0].getOperator());
495
7
      return RewriteResponse(REWRITE_DONE, nm->mkNode(APPLY_CONSTRUCTOR, children));
496
    }
497
    return RewriteResponse(REWRITE_DONE, in[0]);
498
  }
499
46
  return RewriteResponse(REWRITE_DONE, in);
500
}
501
502
977
Node DatatypesRewriter::normalizeCodatatypeConstant(Node n)
503
{
504
977
  Trace("dt-nconst") << "Normalize " << n << std::endl;
505
1954
  std::map<Node, Node> rf;
506
1954
  std::vector<Node> sk;
507
1954
  std::vector<Node> rf_pending;
508
1954
  std::vector<Node> terms;
509
1954
  std::map<Node, bool> cdts;
510
1954
  Node s = collectRef(n, sk, rf, rf_pending, terms, cdts);
511
977
  if (!s.isNull())
512
  {
513
596
    Trace("dt-nconst") << "...symbolic normalized is : " << s << std::endl;
514
826
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
515
    {
516
460
      Trace("dt-nconst") << "  " << it->first << " = " << it->second
517
230
                         << std::endl;
518
    }
519
    // now run DFA minimization on term structure
520
1192
    Trace("dt-nconst") << "  " << terms.size()
521
596
                       << " total subterms :" << std::endl;
522
596
    int eqc_count = 0;
523
1192
    std::map<Node, int> eqc_op_map;
524
1192
    std::map<Node, int> eqc;
525
1192
    std::map<int, std::vector<Node> > eqc_nodes;
526
    // partition based on top symbol
527
2831
    for (unsigned j = 0, size = terms.size(); j < size; j++)
528
    {
529
4470
      Node t = terms[j];
530
2235
      Trace("dt-nconst") << "    " << t << ", cdt=" << cdts[t] << std::endl;
531
      int e;
532
2235
      if (cdts[t])
533
      {
534
1711
        Assert(t.getKind() == kind::APPLY_CONSTRUCTOR);
535
3422
        Node op = t.getOperator();
536
1711
        std::map<Node, int>::iterator it = eqc_op_map.find(op);
537
1711
        if (it == eqc_op_map.end())
538
        {
539
947
          e = eqc_count;
540
947
          eqc_op_map[op] = eqc_count;
541
947
          eqc_count++;
542
        }
543
        else
544
        {
545
764
          e = it->second;
546
        }
547
      }
548
      else
549
      {
550
524
        e = eqc_count;
551
524
        eqc_count++;
552
      }
553
2235
      eqc[t] = e;
554
2235
      eqc_nodes[e].push_back(t);
555
    }
556
    // partition until fixed point
557
596
    int eqc_curr = 0;
558
596
    bool success = true;
559
2896
    while (success)
560
    {
561
1150
      success = false;
562
1150
      int eqc_end = eqc_count;
563
6410
      while (eqc_curr < eqc_end)
564
      {
565
2630
        if (eqc_nodes[eqc_curr].size() > 1)
566
        {
567
          // look at all nodes in this equivalence class
568
661
          unsigned nchildren = eqc_nodes[eqc_curr][0].getNumChildren();
569
1322
          std::map<int, std::vector<Node> > prt;
570
1358
          for (unsigned j = 0; j < nchildren; j++)
571
          {
572
1267
            prt.clear();
573
            // partition based on children : for the first child that causes a
574
            // split, break
575
5064
            for (unsigned k = 0, size = eqc_nodes[eqc_curr].size(); k < size;
576
                 k++)
577
            {
578
7594
              Node t = eqc_nodes[eqc_curr][k];
579
3797
              Assert(t.getNumChildren() == nchildren);
580
7594
              Node tc = t[j];
581
              // refer to loops
582
3797
              std::map<Node, Node>::iterator itr = rf.find(tc);
583
3797
              if (itr != rf.end())
584
              {
585
95
                tc = itr->second;
586
              }
587
3797
              Assert(eqc.find(tc) != eqc.end());
588
3797
              prt[eqc[tc]].push_back(t);
589
            }
590
1267
            if (prt.size() > 1)
591
            {
592
570
              success = true;
593
570
              break;
594
            }
595
          }
596
          // move into new eqc(s)
597
1910
          for (const std::pair<const int, std::vector<Node> >& p : prt)
598
          {
599
1249
            int e = eqc_count;
600
3214
            for (unsigned j = 0, size = p.second.size(); j < size; j++)
601
            {
602
3930
              Node t = p.second[j];
603
1965
              eqc[t] = e;
604
1965
              eqc_nodes[e].push_back(t);
605
            }
606
1249
            eqc_count++;
607
          }
608
        }
609
2630
        eqc_nodes.erase(eqc_curr);
610
2630
        eqc_curr++;
611
      }
612
    }
613
    // add in already occurring loop variables
614
826
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
615
    {
616
460
      Trace("dt-nconst-debug") << "Mapping equivalence class of " << it->first
617
230
                               << " -> " << it->second << std::endl;
618
230
      Assert(eqc.find(it->second) != eqc.end());
619
230
      eqc[it->first] = eqc[it->second];
620
    }
621
    // we now have a partition of equivalent terms
622
596
    Trace("dt-nconst") << "Computed equivalence classes ids : " << std::endl;
623
3061
    for (std::map<Node, int>::iterator it = eqc.begin(); it != eqc.end(); ++it)
624
    {
625
4930
      Trace("dt-nconst") << "  " << it->first << " -> " << it->second
626
2465
                         << std::endl;
627
    }
628
    // traverse top-down based on equivalence class
629
1192
    std::map<int, int> eqc_stack;
630
596
    return normalizeCodatatypeConstantEqc(s, eqc_stack, eqc, 0);
631
  }
632
381
  Trace("dt-nconst") << "...invalid." << std::endl;
633
381
  return Node::null();
634
}
635
636
// normalize constant : apply to top-level codatatype constants
637
366389
Node DatatypesRewriter::normalizeConstant(Node n)
638
{
639
732778
  TypeNode tn = n.getType();
640
366389
  if (tn.isDatatype())
641
  {
642
332495
    if (tn.isCodatatype())
643
    {
644
212
      return normalizeCodatatypeConstant(n);
645
    }
646
    else
647
    {
648
664566
      std::vector<Node> children;
649
332283
      bool childrenChanged = false;
650
640518
      for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
651
      {
652
616470
        Node nc = normalizeConstant(n[i]);
653
308235
        children.push_back(nc);
654
308235
        childrenChanged = childrenChanged || nc != n[i];
655
      }
656
332283
      if (childrenChanged)
657
      {
658
        return NodeManager::currentNM()->mkNode(n.getKind(), children);
659
      }
660
    }
661
  }
662
366177
  return n;
663
}
664
665
4630
Node DatatypesRewriter::collectRef(Node n,
666
                                   std::vector<Node>& sk,
667
                                   std::map<Node, Node>& rf,
668
                                   std::vector<Node>& rf_pending,
669
                                   std::vector<Node>& terms,
670
                                   std::map<Node, bool>& cdts)
671
{
672
4630
  Assert(n.isConst());
673
9260
  TypeNode tn = n.getType();
674
9260
  Node ret = n;
675
4630
  bool isCdt = false;
676
4630
  if (tn.isDatatype())
677
  {
678
2942
    if (!tn.isCodatatype())
679
    {
680
      // nested datatype within codatatype : can be normalized independently
681
      // since all loops should be self-contained
682
48
      ret = normalizeConstant(n);
683
    }
684
    else
685
    {
686
2894
      isCdt = true;
687
2894
      if (n.getKind() == kind::APPLY_CONSTRUCTOR)
688
      {
689
2283
        sk.push_back(n);
690
2283
        rf_pending.push_back(Node::null());
691
4021
        std::vector<Node> children;
692
2283
        children.push_back(n.getOperator());
693
2283
        bool childChanged = false;
694
5391
        for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
695
        {
696
6761
          Node nc = collectRef(n[i], sk, rf, rf_pending, terms, cdts);
697
3653
          if (nc.isNull())
698
          {
699
545
            return Node::null();
700
          }
701
3108
          childChanged = nc != n[i] || childChanged;
702
3108
          children.push_back(nc);
703
        }
704
1738
        sk.pop_back();
705
1738
        if (childChanged)
706
        {
707
495
          ret = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR,
708
                                                 children);
709
495
          if (!rf_pending.back().isNull())
710
          {
711
230
            rf[rf_pending.back()] = ret;
712
          }
713
        }
714
        else
715
        {
716
1243
          Assert(rf_pending.back().isNull());
717
        }
718
1738
        rf_pending.pop_back();
719
      }
720
      else
721
      {
722
        // a loop
723
611
        const Integer& i = n.getConst<UninterpretedConstant>().getIndex();
724
611
        uint32_t index = i.toUnsignedInt();
725
611
        if (index >= sk.size())
726
        {
727
381
          return Node::null();
728
        }
729
230
        Assert(sk.size() == rf_pending.size());
730
460
        Node r = rf_pending[rf_pending.size() - 1 - index];
731
230
        if (r.isNull())
732
        {
733
460
          r = NodeManager::currentNM()->mkBoundVar(
734
460
              sk[rf_pending.size() - 1 - index].getType());
735
230
          rf_pending[rf_pending.size() - 1 - index] = r;
736
        }
737
230
        return r;
738
      }
739
    }
740
  }
741
6948
  Trace("dt-nconst-debug") << "Return term : " << ret << " from " << n
742
3474
                           << std::endl;
743
3474
  if (std::find(terms.begin(), terms.end(), ret) == terms.end())
744
  {
745
2640
    terms.push_back(ret);
746
2640
    Assert(ret.getType() == tn);
747
2640
    cdts[ret] = isCdt;
748
  }
749
3474
  return ret;
750
}
751
// eqc_stack stores depth
752
2865
Node DatatypesRewriter::normalizeCodatatypeConstantEqc(
753
    Node n, std::map<int, int>& eqc_stack, std::map<Node, int>& eqc, int depth)
754
{
755
5730
  Trace("dt-nconst-debug") << "normalizeCodatatypeConstantEqc: " << n
756
2865
                           << " depth=" << depth << std::endl;
757
2865
  if (eqc.find(n) != eqc.end())
758
  {
759
2815
    int e = eqc[n];
760
2815
    std::map<int, int>::iterator it = eqc_stack.find(e);
761
2815
    if (it != eqc_stack.end())
762
    {
763
230
      int debruijn = depth - it->second - 1;
764
      return NodeManager::currentNM()->mkConst(
765
230
          UninterpretedConstant(n.getType(), debruijn));
766
    }
767
4851
    std::vector<Node> children;
768
2585
    bool childChanged = false;
769
2585
    eqc_stack[e] = depth;
770
4854
    for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
771
    {
772
4538
      Node nc = normalizeCodatatypeConstantEqc(n[i], eqc_stack, eqc, depth + 1);
773
2269
      children.push_back(nc);
774
2269
      childChanged = childChanged || nc != n[i];
775
    }
776
2585
    eqc_stack.erase(e);
777
2585
    if (childChanged)
778
    {
779
319
      Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
780
319
      children.insert(children.begin(), n.getOperator());
781
319
      return NodeManager::currentNM()->mkNode(n.getKind(), children);
782
    }
783
  }
784
2316
  return n;
785
}
786
787
62
Node DatatypesRewriter::replaceDebruijn(Node n,
788
                                        Node orig,
789
                                        TypeNode orig_tn,
790
                                        unsigned depth)
791
{
792
62
  if (n.getKind() == kind::UNINTERPRETED_CONSTANT && n.getType() == orig_tn)
793
  {
794
    unsigned index =
795
1
        n.getConst<UninterpretedConstant>().getIndex().toUnsignedInt();
796
1
    if (index == depth)
797
    {
798
1
      return orig;
799
    }
800
  }
801
61
  else if (n.getNumChildren() > 0)
802
  {
803
48
    std::vector<Node> children;
804
24
    bool childChanged = false;
805
72
    for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
806
    {
807
96
      Node nc = replaceDebruijn(n[i], orig, orig_tn, depth + 1);
808
48
      children.push_back(nc);
809
48
      childChanged = childChanged || nc != n[i];
810
    }
811
24
    if (childChanged)
812
    {
813
      if (n.hasOperator())
814
      {
815
        children.insert(children.begin(), n.getOperator());
816
      }
817
      return NodeManager::currentNM()->mkNode(n.getKind(), children);
818
    }
819
  }
820
61
  return n;
821
}
822
823
76555
TrustNode DatatypesRewriter::expandDefinition(Node n)
824
{
825
76555
  NodeManager* nm = NodeManager::currentNM();
826
153110
  TypeNode tn = n.getType();
827
153110
  Node ret;
828
76555
  switch (n.getKind())
829
  {
830
3866
    case kind::APPLY_SELECTOR:
831
    {
832
7732
      Node selector = n.getOperator();
833
      // APPLY_SELECTOR always applies to an external selector, cindexOf is
834
      // legal here
835
3866
      size_t cindex = utils::cindexOf(selector);
836
3866
      const DType& dt = utils::datatypeOf(selector);
837
3866
      const DTypeConstructor& c = dt[cindex];
838
7732
      Node selector_use;
839
7732
      TypeNode ndt = n[0].getType();
840
684638
      if (options::dtSharedSelectors())
841
      {
842
3866
        size_t selectorIndex = utils::indexOf(selector);
843
7732
        Trace("dt-expand") << "...selector index = " << selectorIndex
844
3866
                           << std::endl;
845
3866
        Assert(selectorIndex < c.getNumArgs());
846
3866
        selector_use = c.getSelectorInternal(ndt, selectorIndex);
847
      }
848
      else
849
      {
850
        selector_use = selector;
851
      }
852
7732
      Node sel = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, selector_use, n[0]);
853
3866
      if (options::dtRewriteErrorSel())
854
      {
855
323
        ret = sel;
856
      }
857
      else
858
      {
859
7086
        Node tester = c.getTester();
860
7086
        Node tst = nm->mkNode(APPLY_TESTER, tester, n[0]);
861
3543
        SkolemManager* sm = nm->getSkolemManager();
862
7086
        TypeNode tnw = nm->mkFunctionType(ndt, n.getType());
863
        Node f =
864
7086
            sm->mkSkolemFunction(SkolemFunId::SELECTOR_WRONG, tnw, selector);
865
7086
        Node sk = nm->mkNode(kind::APPLY_UF, f, n[0]);
866
3543
        ret = nm->mkNode(kind::ITE, tst, sel, sk);
867
7086
        Trace("dt-expand") << "Expand def : " << n << " to " << ret
868
3543
                           << std::endl;
869
3866
      }
870
    }
871
3866
    break;
872
22
    case APPLY_UPDATER:
873
    {
874
22
      Assert(tn.isDatatype());
875
22
      const DType& dt = tn.getDType();
876
44
      Node op = n.getOperator();
877
22
      size_t updateIndex = utils::indexOf(op);
878
22
      size_t cindex = utils::cindexOf(op);
879
22
      const DTypeConstructor& dc = dt[cindex];
880
44
      NodeBuilder b(APPLY_CONSTRUCTOR);
881
22
      b << dc.getConstructor();
882
22
      Trace("dt-expand") << "Expand updater " << n << std::endl;
883
22
      Trace("dt-expand") << "expr is " << n << std::endl;
884
22
      Trace("dt-expand") << "updateIndex is " << updateIndex << std::endl;
885
22
      Trace("dt-expand") << "t is " << tn << std::endl;
886
64
      for (size_t i = 0, size = dc.getNumArgs(); i < size; ++i)
887
      {
888
42
        if (i == updateIndex)
889
        {
890
22
          b << n[1];
891
        }
892
        else
893
        {
894
100
          b << nm->mkNode(
895
80
              APPLY_SELECTOR_TOTAL, dc.getSelectorInternal(tn, i), n[0]);
896
        }
897
      }
898
22
      ret = b;
899
22
      if (dt.getNumConstructors() > 1)
900
      {
901
        // must be the right constructor to update
902
20
        Node tester = nm->mkNode(APPLY_TESTER, dc.getTester(), n[0]);
903
10
        ret = nm->mkNode(ITE, tester, ret, n[0]);
904
      }
905
44
      Trace("dt-expand") << "return " << ret << std::endl;
906
    }
907
22
    break;
908
72667
    default: break;
909
  }
910
76555
  if (!ret.isNull() && n != ret)
911
  {
912
3888
    return TrustNode::mkTrustRewrite(n, ret, nullptr);
913
  }
914
72667
  return TrustNode::null();
915
}
916
917
}  // namespace datatypes
918
}  // namespace theory
919
708963
}  // namespace cvc5