GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/datatypes/datatypes_rewriter.cpp Lines: 483 514 94.0 %
Date: 2021-09-17 Branches: 1076 2333 46.1 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mudathir Mohamed, Mathias Preiner
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of rewriter for the theory of (co)inductive datatypes.
14
 */
15
16
#include "theory/datatypes/datatypes_rewriter.h"
17
18
#include "expr/ascription_type.h"
19
#include "expr/dtype.h"
20
#include "expr/dtype_cons.h"
21
#include "expr/node_algorithm.h"
22
#include "expr/skolem_manager.h"
23
#include "expr/sygus_datatype.h"
24
#include "expr/uninterpreted_constant.h"
25
#include "options/datatypes_options.h"
26
#include "theory/datatypes/sygus_datatype_utils.h"
27
#include "theory/datatypes/theory_datatypes_utils.h"
28
#include "theory/datatypes/tuple_project_op.h"
29
#include "util/rational.h"
30
31
using namespace cvc5;
32
using namespace cvc5::kind;
33
34
namespace cvc5 {
35
namespace theory {
36
namespace datatypes {
37
38
462867
RewriteResponse DatatypesRewriter::postRewrite(TNode in)
39
{
40
462867
  Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl;
41
462867
  Kind kind = in.getKind();
42
462867
  NodeManager* nm = NodeManager::currentNM();
43
462867
  if (kind == kind::APPLY_CONSTRUCTOR)
44
  {
45
124069
    return rewriteConstructor(in);
46
  }
47
338798
  else if (kind == kind::APPLY_SELECTOR_TOTAL || kind == kind::APPLY_SELECTOR)
48
  {
49
72787
    return rewriteSelector(in);
50
  }
51
266011
  else if (kind == kind::APPLY_TESTER)
52
  {
53
34379
    return rewriteTester(in);
54
  }
55
231632
  else if (kind == APPLY_UPDATER)
56
  {
57
53
    return rewriteUpdater(in);
58
  }
59
231579
  else if (kind == kind::DT_SIZE)
60
  {
61
13831
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
62
    {
63
13130
      std::vector<Node> children;
64
17469
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
65
      {
66
10904
        if (in[0][i].getType().isDatatype())
67
        {
68
10815
          children.push_back(nm->mkNode(kind::DT_SIZE, in[0][i]));
69
        }
70
      }
71
13130
      TNode constructor = in[0].getOperator();
72
6565
      size_t constructorIndex = utils::indexOf(constructor);
73
6565
      const DType& dt = utils::datatypeOf(constructor);
74
6565
      const DTypeConstructor& c = dt[constructorIndex];
75
6565
      unsigned weight = c.getWeight();
76
6565
      children.push_back(nm->mkConst(Rational(weight)));
77
      Node res =
78
13130
          children.size() == 1 ? children[0] : nm->mkNode(kind::PLUS, children);
79
13130
      Trace("datatypes-rewrite")
80
6565
          << "DatatypesRewriter::postRewrite: rewrite size " << in << " to "
81
6565
          << res << std::endl;
82
6565
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
83
    }
84
  }
85
217748
  else if (kind == kind::DT_HEIGHT_BOUND)
86
  {
87
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
88
    {
89
      std::vector<Node> children;
90
      Node res;
91
      Rational r = in[1].getConst<Rational>();
92
      Rational rmo = Rational(r - Rational(1));
93
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
94
      {
95
        if (in[0][i].getType().isDatatype())
96
        {
97
          if (r.isZero())
98
          {
99
            res = nm->mkConst(false);
100
            break;
101
          }
102
          children.push_back(
103
              nm->mkNode(kind::DT_HEIGHT_BOUND, in[0][i], nm->mkConst(rmo)));
104
        }
105
      }
106
      if (res.isNull())
107
      {
108
        res = children.size() == 0
109
                  ? nm->mkConst(true)
110
                  : (children.size() == 1 ? children[0]
111
                                          : nm->mkNode(kind::AND, children));
112
      }
113
      Trace("datatypes-rewrite")
114
          << "DatatypesRewriter::postRewrite: rewrite height " << in << " to "
115
          << res << std::endl;
116
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
117
    }
118
  }
119
217748
  else if (kind == kind::DT_SIZE_BOUND)
120
  {
121
    if (in[0].isConst())
122
    {
123
      Node res = nm->mkNode(kind::LEQ, nm->mkNode(kind::DT_SIZE, in[0]), in[1]);
124
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
125
    }
126
  }
127
217748
  else if (kind == DT_SYGUS_EVAL)
128
  {
129
    // sygus evaluation function
130
61125
    Node ev = in[0];
131
51191
    if (ev.getKind() == APPLY_CONSTRUCTOR)
132
    {
133
41257
      Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
134
41257
      Trace("dt-sygus-util") << "Type is " << in.getType() << std::endl;
135
82514
      std::vector<Node> args;
136
191271
      for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++)
137
      {
138
150014
        args.push_back(in[j]);
139
      }
140
82514
      Node ret = utils::sygusToBuiltinEval(ev, args);
141
41257
      Trace("dt-sygus-util") << "...got " << ret << "\n";
142
41257
      Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl;
143
41257
      Assert(in.getType().isComparableTo(ret.getType()));
144
41257
      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
145
    }
146
  }
147
166557
  else if (kind == MATCH)
148
  {
149
12
    Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl;
150
24
    Node h = in[0];
151
24
    std::vector<Node> cases;
152
24
    std::vector<Node> rets;
153
24
    TypeNode t = h.getType();
154
12
    const DType& dt = t.getDType();
155
40
    for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++)
156
    {
157
56
      Node c = in[k];
158
56
      Node cons;
159
28
      Kind ck = c.getKind();
160
28
      if (ck == MATCH_CASE)
161
      {
162
20
        Assert(c[0].getKind() == APPLY_CONSTRUCTOR);
163
20
        cons = c[0].getOperator();
164
      }
165
8
      else if (ck == MATCH_BIND_CASE)
166
      {
167
8
        if (c[1].getKind() == APPLY_CONSTRUCTOR)
168
        {
169
4
          cons = c[1].getOperator();
170
        }
171
      }
172
      else
173
      {
174
        AlwaysAssert(false);
175
      }
176
28
      size_t cindex = 0;
177
      // cons is null in the default case
178
28
      if (!cons.isNull())
179
      {
180
24
        cindex = utils::indexOf(cons);
181
      }
182
56
      Node body;
183
28
      if (ck == MATCH_CASE)
184
      {
185
20
        body = c[1];
186
      }
187
8
      else if (ck == MATCH_BIND_CASE)
188
      {
189
16
        std::vector<Node> vars;
190
16
        std::vector<Node> subs;
191
8
        if (cons.isNull())
192
        {
193
4
          Assert(c[1].getKind() == BOUND_VARIABLE);
194
4
          vars.push_back(c[1]);
195
4
          subs.push_back(h);
196
        }
197
        else
198
        {
199
12
          for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++)
200
          {
201
8
            vars.push_back(c[0][i]);
202
            Node sc =
203
16
                nm->mkNode(APPLY_SELECTOR, dt[cindex][i].getSelector(), h);
204
8
            subs.push_back(sc);
205
          }
206
        }
207
8
        body =
208
16
            c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
209
      }
210
28
      if (!cons.isNull())
211
      {
212
24
        cases.push_back(utils::mkTester(h, cindex, dt));
213
      }
214
      else
215
      {
216
        // variables have no constraints
217
4
        cases.push_back(nm->mkConst(true));
218
      }
219
28
      rets.push_back(body);
220
    }
221
12
    Assert(!cases.empty());
222
    // now make the ITE
223
12
    std::reverse(cases.begin(), cases.end());
224
12
    std::reverse(rets.begin(), rets.end());
225
24
    Node ret = rets[0];
226
12
    AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors());
227
28
    for (unsigned i = 1, ncases = cases.size(); i < ncases; i++)
228
    {
229
16
      ret = nm->mkNode(ITE, cases[i], rets[i], ret);
230
    }
231
24
    Trace("dt-rewrite-match")
232
12
        << "Rewrite match: " << in << " ... " << ret << std::endl;
233
12
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
234
  }
235
166545
  else if (kind == TUPLE_PROJECT)
236
  {
237
    // returns a tuple that represents
238
    // (mkTuple ((_ tupSel i_1) t) ... ((_ tupSel i_n) t))
239
    // where each i_j is less than the length of t
240
241
6
    Trace("dt-rewrite-project") << "Rewrite project: " << in << std::endl;
242
12
    TupleProjectOp op = in.getOperator().getConst<TupleProjectOp>();
243
12
    std::vector<uint32_t> indices = op.getIndices();
244
12
    Node tuple = in[0];
245
12
    std::vector<TypeNode> tupleTypes = tuple.getType().getTupleTypes();
246
12
    std::vector<TypeNode> types;
247
12
    std::vector<Node> elements;
248
12
    for (uint32_t index : indices)
249
    {
250
12
      TypeNode type = tupleTypes[index];
251
6
      types.push_back(type);
252
    }
253
12
    TypeNode projectType = nm->mkTupleType(types);
254
6
    const DType& dt = projectType.getDType();
255
6
    elements.push_back(dt[0].getConstructor());
256
6
    const DType& tupleDType = tuple.getType().getDType();
257
6
    const DTypeConstructor& constructor = tupleDType[0];
258
12
    for (uint32_t index : indices)
259
    {
260
12
      Node selector = constructor[index].getSelector();
261
12
      Node element = nm->mkNode(kind::APPLY_SELECTOR, selector, tuple);
262
6
      elements.push_back(element);
263
    }
264
12
    Node ret = nm->mkNode(kind::APPLY_CONSTRUCTOR, elements);
265
266
12
    Trace("dt-rewrite-project")
267
6
        << "Rewrite project: " << in << " ... " << ret << std::endl;
268
6
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
269
  }
270
271
183739
  if (kind == kind::EQUAL)
272
  {
273
165523
    if (in[0] == in[1])
274
    {
275
2109
      return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
276
    }
277
293706
    std::vector<Node> rew;
278
163414
    if (utils::checkClash(in[0], in[1], rew))
279
    {
280
3500
      Trace("datatypes-rewrite")
281
1750
          << "Rewrite clashing equality " << in << " to false" << std::endl;
282
1750
      return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
283
      //}else if( rew.size()==1 && rew[0]!=in ){
284
      //  Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " <<
285
      //  rew[0] << std::endl;
286
      //  return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] );
287
    }
288
161664
    else if (in[1] < in[0])
289
    {
290
62744
      Node ins = nm->mkNode(in.getKind(), in[1], in[0]);
291
62744
      Trace("datatypes-rewrite")
292
31372
          << "Swap equality " << in << " to " << ins << std::endl;
293
31372
      return RewriteResponse(REWRITE_DONE, ins);
294
    }
295
260584
    Trace("datatypes-rewrite-debug")
296
260584
        << "Did not rewrite equality " << in << " " << in[0].getKind() << " "
297
130292
        << in[1].getKind() << std::endl;
298
  }
299
300
148508
  return RewriteResponse(REWRITE_DONE, in);
301
}
302
303
261718
RewriteResponse DatatypesRewriter::preRewrite(TNode in)
304
{
305
261718
  Trace("datatypes-rewrite-debug") << "pre-rewriting " << in << std::endl;
306
  // must prewrite to apply type ascriptions since rewriting does not preserve
307
  // types
308
261718
  if (in.getKind() == kind::APPLY_CONSTRUCTOR)
309
  {
310
119639
    TypeNode tn = in.getType();
311
312
    // check for parametric datatype constructors
313
    // to ensure a normal form, all parameteric datatype constructors must have
314
    // a type ascription
315
59836
    if (tn.isParametricDatatype())
316
    {
317
266
      if (in.getOperator().getKind() != kind::APPLY_TYPE_ASCRIPTION)
318
      {
319
66
        Trace("datatypes-rewrite-debug")
320
33
            << "Ascribing type to parametric datatype constructor " << in
321
33
            << std::endl;
322
66
        Node op = in.getOperator();
323
        // get the constructor object
324
33
        const DTypeConstructor& dtc = utils::datatypeOf(op)[utils::indexOf(op)];
325
        // create ascribed constructor type
326
        Node tc = NodeManager::currentNM()->mkConst(
327
66
            AscriptionType(dtc.getSpecializedConstructorType(tn)));
328
        Node op_new = NodeManager::currentNM()->mkNode(
329
66
            kind::APPLY_TYPE_ASCRIPTION, tc, op);
330
        // make new node
331
66
        std::vector<Node> children;
332
33
        children.push_back(op_new);
333
33
        children.insert(children.end(), in.begin(), in.end());
334
        Node inr =
335
66
            NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, children);
336
33
        Trace("datatypes-rewrite-debug") << "Created " << inr << std::endl;
337
33
        return RewriteResponse(REWRITE_DONE, inr);
338
      }
339
    }
340
  }
341
261685
  return RewriteResponse(REWRITE_DONE, in);
342
}
343
344
124069
RewriteResponse DatatypesRewriter::rewriteConstructor(TNode in)
345
{
346
124069
  if (in.isConst())
347
  {
348
112142
    Trace("datatypes-rewrite-debug") << "Normalizing constant " << in
349
56071
                                     << std::endl;
350
112142
    Node inn = normalizeConstant(in);
351
    // constant may be a subterm of another constant, so cannot assume that this
352
    // will succeed for codatatypes
353
    // Assert( !inn.isNull() );
354
56071
    if (!inn.isNull() && inn != in)
355
    {
356
16
      Trace("datatypes-rewrite") << "Normalized constant " << in << " -> "
357
8
                                 << inn << std::endl;
358
8
      return RewriteResponse(REWRITE_DONE, inn);
359
    }
360
56063
    return RewriteResponse(REWRITE_DONE, in);
361
  }
362
67998
  return RewriteResponse(REWRITE_DONE, in);
363
}
364
365
72787
RewriteResponse DatatypesRewriter::rewriteSelector(TNode in)
366
{
367
72787
  Kind k = in.getKind();
368
72787
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
369
  {
370
    // Have to be careful not to rewrite well-typed expressions
371
    // where the selector doesn't match the constructor,
372
    // e.g. "pred(zero)".
373
20354
    TypeNode tn = in.getType();
374
20354
    TypeNode argType = in[0].getType();
375
20354
    Node selector = in.getOperator();
376
20354
    TNode constructor = in[0].getOperator();
377
18474
    size_t constructorIndex = utils::indexOf(constructor);
378
18474
    const DType& dt = utils::datatypeOf(selector);
379
18474
    const DTypeConstructor& c = dt[constructorIndex];
380
36948
    Trace("datatypes-rewrite-debug") << "Rewriting collapsable selector : "
381
18474
                                     << in;
382
36948
    Trace("datatypes-rewrite-debug") << ", cindex = " << constructorIndex
383
18474
                                     << ", selector is " << selector
384
18474
                                     << std::endl;
385
    // The argument that the selector extracts, or -1 if the selector is
386
    // is wrongly applied.
387
18474
    int selectorIndex = -1;
388
18474
    if (k == kind::APPLY_SELECTOR_TOTAL)
389
    {
390
      // The argument index of internal selectors is obtained by
391
      // getSelectorIndexInternal.
392
14120
      selectorIndex = c.getSelectorIndexInternal(selector);
393
    }
394
    else
395
    {
396
      // The argument index of external selectors (applications of
397
      // APPLY_SELECTOR) is given by an attribute and obtained via indexOf below
398
      // The argument is only valid if it is the proper constructor.
399
4354
      selectorIndex = utils::indexOf(selector);
400
8708
      if (selectorIndex < 0
401
4354
          || selectorIndex >= static_cast<int>(c.getNumArgs()))
402
      {
403
880
        selectorIndex = -1;
404
      }
405
3474
      else if (c[selectorIndex].getSelector() != selector)
406
      {
407
1000
        selectorIndex = -1;
408
      }
409
    }
410
36948
    Trace("datatypes-rewrite-debug") << "Internal selector index is "
411
18474
                                     << selectorIndex << std::endl;
412
18474
    if (selectorIndex >= 0)
413
    {
414
15668
      Assert(selectorIndex < (int)c.getNumArgs());
415
15668
      if (dt.isCodatatype() && in[0][selectorIndex].isConst())
416
      {
417
        // must replace all debruijn indices with self
418
14
        Node sub = replaceDebruijn(in[0][selectorIndex], in[0], argType, 0);
419
28
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
420
14
                                   << "Rewrite trivial codatatype selector "
421
14
                                   << in << " to " << sub << std::endl;
422
14
        if (sub != in)
423
        {
424
14
          return RewriteResponse(REWRITE_AGAIN_FULL, sub);
425
        }
426
      }
427
      else
428
      {
429
31308
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
430
15654
                                   << "Rewrite trivial selector " << in
431
15654
                                   << std::endl;
432
15654
        return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]);
433
      }
434
    }
435
2806
    else if (k == kind::APPLY_SELECTOR_TOTAL)
436
    {
437
      // evaluates to the first ground value of type tn.
438
1852
      Node gt = tn.mkGroundValue();
439
926
      Assert(!gt.isNull());
440
926
      if (tn.isDatatype() && !tn.isInstantiatedDatatype())
441
      {
442
        gt = NodeManager::currentNM()->mkNode(
443
            kind::APPLY_TYPE_ASCRIPTION,
444
            NodeManager::currentNM()->mkConst(AscriptionType(tn)),
445
            gt);
446
      }
447
1852
      Trace("datatypes-rewrite")
448
926
          << "DatatypesRewriter::postRewrite: "
449
926
          << "Rewrite trivial selector " << in
450
926
          << " to distinguished ground term " << gt << std::endl;
451
926
      return RewriteResponse(REWRITE_DONE, gt);
452
    }
453
  }
454
56193
  return RewriteResponse(REWRITE_DONE, in);
455
}
456
457
34379
RewriteResponse DatatypesRewriter::rewriteTester(TNode in)
458
{
459
34379
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
460
  {
461
    bool result =
462
1767
        utils::indexOf(in.getOperator()) == utils::indexOf(in[0].getOperator());
463
3534
    Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
464
1767
                               << "Rewrite trivial tester " << in << " "
465
1767
                               << result << std::endl;
466
    return RewriteResponse(REWRITE_DONE,
467
1767
                           NodeManager::currentNM()->mkConst(result));
468
  }
469
32612
  const DType& dt = in[0].getType().getDType();
470
32612
  if (dt.getNumConstructors() == 1 && !dt.isSygus())
471
  {
472
    // only one constructor, so it must be
473
2348
    Trace("datatypes-rewrite")
474
1174
        << "DatatypesRewriter::postRewrite: "
475
2348
        << "only one ctor for " << dt.getName() << " and that is "
476
1174
        << dt[0].getName() << std::endl;
477
    return RewriteResponse(REWRITE_DONE,
478
1174
                           NodeManager::currentNM()->mkConst(true));
479
  }
480
  // could try dt.getNumConstructors()==2 && indexOf(in.getOperator())==1 ?
481
31438
  return RewriteResponse(REWRITE_DONE, in);
482
}
483
484
53
RewriteResponse DatatypesRewriter::rewriteUpdater(TNode in)
485
{
486
53
  Assert (in.getKind()==APPLY_UPDATER);
487
53
  if (in[0].getKind() == APPLY_CONSTRUCTOR)
488
  {
489
14
    Node op = in.getOperator();
490
7
    size_t cindex = utils::indexOf(in[0].getOperator());
491
7
    size_t cuindex = utils::cindexOf(op);
492
7
    if (cindex==cuindex)
493
    {
494
7
      NodeManager * nm = NodeManager::currentNM();
495
7
      size_t updateIndex = utils::indexOf(op);
496
14
      std::vector<Node> children(in[0].begin(), in[0].end());
497
7
      children[updateIndex] = in[1];
498
7
      children.insert(children.begin(),in[0].getOperator());
499
7
      return RewriteResponse(REWRITE_DONE, nm->mkNode(APPLY_CONSTRUCTOR, children));
500
    }
501
    return RewriteResponse(REWRITE_DONE, in[0]);
502
  }
503
46
  return RewriteResponse(REWRITE_DONE, in);
504
}
505
506
813
Node DatatypesRewriter::normalizeCodatatypeConstant(Node n)
507
{
508
813
  Trace("dt-nconst") << "Normalize " << n << std::endl;
509
1626
  std::map<Node, Node> rf;
510
1626
  std::vector<Node> sk;
511
1626
  std::vector<Node> rf_pending;
512
1626
  std::vector<Node> terms;
513
1626
  std::map<Node, bool> cdts;
514
1626
  Node s = collectRef(n, sk, rf, rf_pending, terms, cdts);
515
813
  if (!s.isNull())
516
  {
517
526
    Trace("dt-nconst") << "...symbolic normalized is : " << s << std::endl;
518
730
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
519
    {
520
408
      Trace("dt-nconst") << "  " << it->first << " = " << it->second
521
204
                         << std::endl;
522
    }
523
    // now run DFA minimization on term structure
524
1052
    Trace("dt-nconst") << "  " << terms.size()
525
526
                       << " total subterms :" << std::endl;
526
526
    int eqc_count = 0;
527
1052
    std::map<Node, int> eqc_op_map;
528
1052
    std::map<Node, int> eqc;
529
1052
    std::map<int, std::vector<Node> > eqc_nodes;
530
    // partition based on top symbol
531
2449
    for (unsigned j = 0, size = terms.size(); j < size; j++)
532
    {
533
3846
      Node t = terms[j];
534
1923
      Trace("dt-nconst") << "    " << t << ", cdt=" << cdts[t] << std::endl;
535
      int e;
536
1923
      if (cdts[t])
537
      {
538
1467
        Assert(t.getKind() == kind::APPLY_CONSTRUCTOR);
539
2934
        Node op = t.getOperator();
540
1467
        std::map<Node, int>::iterator it = eqc_op_map.find(op);
541
1467
        if (it == eqc_op_map.end())
542
        {
543
835
          e = eqc_count;
544
835
          eqc_op_map[op] = eqc_count;
545
835
          eqc_count++;
546
        }
547
        else
548
        {
549
632
          e = it->second;
550
        }
551
      }
552
      else
553
      {
554
456
        e = eqc_count;
555
456
        eqc_count++;
556
      }
557
1923
      eqc[t] = e;
558
1923
      eqc_nodes[e].push_back(t);
559
    }
560
    // partition until fixed point
561
526
    int eqc_curr = 0;
562
526
    bool success = true;
563
2506
    while (success)
564
    {
565
990
      success = false;
566
990
      int eqc_end = eqc_count;
567
5530
      while (eqc_curr < eqc_end)
568
      {
569
2270
        if (eqc_nodes[eqc_curr].size() > 1)
570
        {
571
          // look at all nodes in this equivalence class
572
553
          unsigned nchildren = eqc_nodes[eqc_curr][0].getNumChildren();
573
1106
          std::map<int, std::vector<Node> > prt;
574
1124
          for (unsigned j = 0; j < nchildren; j++)
575
          {
576
1051
            prt.clear();
577
            // partition based on children : for the first child that causes a
578
            // split, break
579
4168
            for (unsigned k = 0, size = eqc_nodes[eqc_curr].size(); k < size;
580
                 k++)
581
            {
582
6234
              Node t = eqc_nodes[eqc_curr][k];
583
3117
              Assert(t.getNumChildren() == nchildren);
584
6234
              Node tc = t[j];
585
              // refer to loops
586
3117
              std::map<Node, Node>::iterator itr = rf.find(tc);
587
3117
              if (itr != rf.end())
588
              {
589
77
                tc = itr->second;
590
              }
591
3117
              Assert(eqc.find(tc) != eqc.end());
592
3117
              prt[eqc[tc]].push_back(t);
593
            }
594
1051
            if (prt.size() > 1)
595
            {
596
480
              success = true;
597
480
              break;
598
            }
599
          }
600
          // move into new eqc(s)
601
1604
          for (const std::pair<const int, std::vector<Node> >& p : prt)
602
          {
603
1051
            int e = eqc_count;
604
2676
            for (unsigned j = 0, size = p.second.size(); j < size; j++)
605
            {
606
3250
              Node t = p.second[j];
607
1625
              eqc[t] = e;
608
1625
              eqc_nodes[e].push_back(t);
609
            }
610
1051
            eqc_count++;
611
          }
612
        }
613
2270
        eqc_nodes.erase(eqc_curr);
614
2270
        eqc_curr++;
615
      }
616
    }
617
    // add in already occurring loop variables
618
730
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
619
    {
620
408
      Trace("dt-nconst-debug") << "Mapping equivalence class of " << it->first
621
204
                               << " -> " << it->second << std::endl;
622
204
      Assert(eqc.find(it->second) != eqc.end());
623
204
      eqc[it->first] = eqc[it->second];
624
    }
625
    // we now have a partition of equivalent terms
626
526
    Trace("dt-nconst") << "Computed equivalence classes ids : " << std::endl;
627
2653
    for (std::map<Node, int>::iterator it = eqc.begin(); it != eqc.end(); ++it)
628
    {
629
4254
      Trace("dt-nconst") << "  " << it->first << " -> " << it->second
630
2127
                         << std::endl;
631
    }
632
    // traverse top-down based on equivalence class
633
1052
    std::map<int, int> eqc_stack;
634
526
    return normalizeCodatatypeConstantEqc(s, eqc_stack, eqc, 0);
635
  }
636
287
  Trace("dt-nconst") << "...invalid." << std::endl;
637
287
  return Node::null();
638
}
639
640
// normalize constant : apply to top-level codatatype constants
641
347200
Node DatatypesRewriter::normalizeConstant(Node n)
642
{
643
694400
  TypeNode tn = n.getType();
644
347200
  if (tn.isDatatype())
645
  {
646
319019
    if (tn.isCodatatype())
647
    {
648
212
      return normalizeCodatatypeConstant(n);
649
    }
650
    else
651
    {
652
637614
      std::vector<Node> children;
653
318807
      bool childrenChanged = false;
654
609888
      for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
655
      {
656
582162
        Node nc = normalizeConstant(n[i]);
657
291081
        children.push_back(nc);
658
291081
        childrenChanged = childrenChanged || nc != n[i];
659
      }
660
318807
      if (childrenChanged)
661
      {
662
        return NodeManager::currentNM()->mkNode(n.getKind(), children);
663
      }
664
    }
665
  }
666
346988
  return n;
667
}
668
669
3794
Node DatatypesRewriter::collectRef(Node n,
670
                                   std::vector<Node>& sk,
671
                                   std::map<Node, Node>& rf,
672
                                   std::vector<Node>& rf_pending,
673
                                   std::vector<Node>& terms,
674
                                   std::map<Node, bool>& cdts)
675
{
676
3794
  Assert(n.isConst());
677
7588
  TypeNode tn = n.getType();
678
7588
  Node ret = n;
679
3794
  bool isCdt = false;
680
3794
  if (tn.isDatatype())
681
  {
682
2442
    if (!tn.isCodatatype())
683
    {
684
      // nested datatype within codatatype : can be normalized independently
685
      // since all loops should be self-contained
686
48
      ret = normalizeConstant(n);
687
    }
688
    else
689
    {
690
2394
      isCdt = true;
691
2394
      if (n.getKind() == kind::APPLY_CONSTRUCTOR)
692
      {
693
1903
        sk.push_back(n);
694
1903
        rf_pending.push_back(Node::null());
695
3397
        std::vector<Node> children;
696
1903
        children.push_back(n.getOperator());
697
1903
        bool childChanged = false;
698
4475
        for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
699
        {
700
5553
          Node nc = collectRef(n[i], sk, rf, rf_pending, terms, cdts);
701
2981
          if (nc.isNull())
702
          {
703
409
            return Node::null();
704
          }
705
2572
          childChanged = nc != n[i] || childChanged;
706
2572
          children.push_back(nc);
707
        }
708
1494
        sk.pop_back();
709
1494
        if (childChanged)
710
        {
711
427
          ret = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR,
712
                                                 children);
713
427
          if (!rf_pending.back().isNull())
714
          {
715
204
            rf[rf_pending.back()] = ret;
716
          }
717
        }
718
        else
719
        {
720
1067
          Assert(rf_pending.back().isNull());
721
        }
722
1494
        rf_pending.pop_back();
723
      }
724
      else
725
      {
726
        // a loop
727
491
        const Integer& i = n.getConst<UninterpretedConstant>().getIndex();
728
491
        uint32_t index = i.toUnsignedInt();
729
491
        if (index >= sk.size())
730
        {
731
287
          return Node::null();
732
        }
733
204
        Assert(sk.size() == rf_pending.size());
734
408
        Node r = rf_pending[rf_pending.size() - 1 - index];
735
204
        if (r.isNull())
736
        {
737
408
          r = NodeManager::currentNM()->mkBoundVar(
738
408
              sk[rf_pending.size() - 1 - index].getType());
739
204
          rf_pending[rf_pending.size() - 1 - index] = r;
740
        }
741
204
        return r;
742
      }
743
    }
744
  }
745
5788
  Trace("dt-nconst-debug") << "Return term : " << ret << " from " << n
746
2894
                           << std::endl;
747
2894
  if (std::find(terms.begin(), terms.end(), ret) == terms.end())
748
  {
749
2234
    terms.push_back(ret);
750
2234
    Assert(ret.getType() == tn);
751
2234
    cdts[ret] = isCdt;
752
  }
753
2894
  return ret;
754
}
755
// eqc_stack stores depth
756
2479
Node DatatypesRewriter::normalizeCodatatypeConstantEqc(
757
    Node n, std::map<int, int>& eqc_stack, std::map<Node, int>& eqc, int depth)
758
{
759
4958
  Trace("dt-nconst-debug") << "normalizeCodatatypeConstantEqc: " << n
760
2479
                           << " depth=" << depth << std::endl;
761
2479
  if (eqc.find(n) != eqc.end())
762
  {
763
2429
    int e = eqc[n];
764
2429
    std::map<int, int>::iterator it = eqc_stack.find(e);
765
2429
    if (it != eqc_stack.end())
766
    {
767
204
      int debruijn = depth - it->second - 1;
768
      return NodeManager::currentNM()->mkConst(
769
204
          UninterpretedConstant(n.getType(), debruijn));
770
    }
771
4157
    std::vector<Node> children;
772
2225
    bool childChanged = false;
773
2225
    eqc_stack[e] = depth;
774
4178
    for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
775
    {
776
3906
      Node nc = normalizeCodatatypeConstantEqc(n[i], eqc_stack, eqc, depth + 1);
777
1953
      children.push_back(nc);
778
1953
      childChanged = childChanged || nc != n[i];
779
    }
780
2225
    eqc_stack.erase(e);
781
2225
    if (childChanged)
782
    {
783
293
      Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
784
293
      children.insert(children.begin(), n.getOperator());
785
293
      return NodeManager::currentNM()->mkNode(n.getKind(), children);
786
    }
787
  }
788
1982
  return n;
789
}
790
791
62
// Replaces, within n, every UNINTERPRETED_CONSTANT of type orig_tn whose index
// equals the current depth by orig. The depth is incremented by one for each
// level of nesting while descending into children. Terms are rebuilt only
// when some child actually changed; otherwise n itself is returned.
Node DatatypesRewriter::replaceDebruijn(Node n,
                                        Node orig,
                                        TypeNode orig_tn,
                                        unsigned depth)
{
  const bool isCandidateRef =
      n.getKind() == kind::UNINTERPRETED_CONSTANT && n.getType() == orig_tn;
  if (isCandidateRef)
  {
    const unsigned refIndex =
        n.getConst<UninterpretedConstant>().getIndex().toUnsignedInt();
    // only a reference pointing exactly at this depth is substituted
    return refIndex == depth ? orig : n;
  }
  if (n.getNumChildren() == 0)
  {
    // leaf that is not a matching reference: unchanged
    return n;
  }
  bool anyChanged = false;
  std::vector<Node> newChildren;
  for (const Node& child : n)
  {
    Node replaced = replaceDebruijn(child, orig, orig_tn, depth + 1);
    newChildren.push_back(replaced);
    anyChanged = anyChanged || replaced != child;
  }
  if (!anyChanged)
  {
    return n;
  }
  if (n.hasOperator())
  {
    newChildren.insert(newChildren.begin(), n.getOperator());
  }
  return NodeManager::currentNM()->mkNode(n.getKind(), newChildren);
}
826
827
4303
// Expands an APPLY_SELECTOR term n into an APPLY_SELECTOR_TOTAL application.
// When options::dtSharedSelectors() is set, the constructor's internal
// selector for the argument type is used instead of the external one.
// When options::dtRewriteErrorSel() is set, the total selector application is
// returned directly. Otherwise the result is
//   (ite (tester n[0]) (total-selector n[0]) (wrong-fun n[0]))
// where wrong-fun is the SELECTOR_WRONG skolem function for the selector.
Node DatatypesRewriter::expandApplySelector(Node n)
{
  Assert(n.getKind() == APPLY_SELECTOR);
  Node selector = n.getOperator();
  // The operator of an APPLY_SELECTOR is always an external selector, so
  // calling cindexOf on it is legal.
  const size_t consIndex = utils::cindexOf(selector);
  const DType& dt = utils::datatypeOf(selector);
  const DTypeConstructor& cons = dt[consIndex];
  TypeNode argType = n[0].getType();
  Node selectorToApply = selector;
  if (options::dtSharedSelectors())
  {
    const size_t selectorIndex = utils::indexOf(selector);
    Trace("dt-expand") << "...selector index = " << selectorIndex << std::endl;
    Assert(selectorIndex < cons.getNumArgs());
    selectorToApply = cons.getSelectorInternal(argType, selectorIndex);
  }
  NodeManager* nm = NodeManager::currentNM();
  Node totalApp =
      nm->mkNode(kind::APPLY_SELECTOR_TOTAL, selectorToApply, n[0]);
  if (options::dtRewriteErrorSel())
  {
    return totalApp;
  }
  // Guard the total selector with the constructor's tester; otherwise fall
  // back to the SELECTOR_WRONG skolem function applied to the argument.
  Node isCons = nm->mkNode(APPLY_TESTER, cons.getTester(), n[0]);
  SkolemManager* sm = nm->getSkolemManager();
  TypeNode wrongFunType = nm->mkFunctionType(argType, n.getType());
  Node wrongFun =
      sm->mkSkolemFunction(SkolemFunId::SELECTOR_WRONG, wrongFunType, selector);
  Node wrongApp = nm->mkNode(kind::APPLY_UF, wrongFun, n[0]);
  Node expanded = nm->mkNode(kind::ITE, isCons, totalApp, wrongApp);
  Trace("dt-expand") << "Expand def : " << n << " to " << expanded << std::endl;
  return expanded;
}
865
866
82203
// Expands datatype operators that have a definitional expansion:
//  - APPLY_SELECTOR is delegated to expandApplySelector.
//  - APPLY_UPDATER (update field `updateIndex` of n[0] to n[1]) becomes a
//    constructor application that copies every other field of n[0] via
//    APPLY_SELECTOR_TOTAL. When the datatype has more than one constructor,
//    it is wrapped in an ITE on the constructor's tester that returns n[0]
//    unchanged otherwise.
// Returns a TrustNode rewrite n -> expansion when the expansion differs from
// n, and TrustNode::null() otherwise (including for all other kinds).
TrustNode DatatypesRewriter::expandDefinition(Node n)
{
  NodeManager* nm = NodeManager::currentNM();
  TypeNode tn = n.getType();
  Node ret;
  const Kind nk = n.getKind();
  if (nk == kind::APPLY_SELECTOR)
  {
    ret = expandApplySelector(n);
  }
  else if (nk == APPLY_UPDATER)
  {
    Assert(tn.isDatatype());
    const DType& dt = tn.getDType();
    Node updater = n.getOperator();
    const size_t updateIndex = utils::indexOf(updater);
    const size_t consIndex = utils::cindexOf(updater);
    const DTypeConstructor& cons = dt[consIndex];
    NodeBuilder builder(APPLY_CONSTRUCTOR);
    builder << cons.getConstructor();
    Trace("dt-expand") << "Expand updater " << n << std::endl;
    Trace("dt-expand") << "expr is " << n << std::endl;
    Trace("dt-expand") << "updateIndex is " << updateIndex << std::endl;
    Trace("dt-expand") << "t is " << tn << std::endl;
    const size_t numArgs = cons.getNumArgs();
    for (size_t i = 0; i < numArgs; ++i)
    {
      if (i == updateIndex)
      {
        // the updated field receives the new value
        builder << n[1];
        continue;
      }
      // every other field is copied from the original term
      builder << nm->mkNode(
          APPLY_SELECTOR_TOTAL, cons.getSelectorInternal(tn, i), n[0]);
    }
    ret = builder;
    if (dt.getNumConstructors() > 1)
    {
      // the update only applies when n[0] was built with this constructor
      Node isCons = nm->mkNode(APPLY_TESTER, cons.getTester(), n[0]);
      ret = nm->mkNode(ITE, isCons, ret, n[0]);
    }
    Trace("dt-expand") << "return " << ret << std::endl;
  }
  if (ret.isNull() || n == ret)
  {
    return TrustNode::null();
  }
  return TrustNode::mkTrustRewrite(n, ret, nullptr);
}
922
923
}  // namespace datatypes
924
}  // namespace theory
925
29577
}  // namespace cvc5