GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/datatypes/datatypes_rewriter.cpp Lines: 549 583 94.2 %
Date: 2021-11-07 Branches: 1201 2649 45.3 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mudathir Mohamed, Mathias Preiner
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of rewriter for the theory of (co)inductive datatypes.
14
 */
15
16
#include "theory/datatypes/datatypes_rewriter.h"
17
18
#include "expr/ascription_type.h"
19
#include "expr/codatatype_bound_variable.h"
20
#include "expr/dtype.h"
21
#include "expr/dtype_cons.h"
22
#include "expr/node_algorithm.h"
23
#include "expr/skolem_manager.h"
24
#include "expr/sygus_datatype.h"
25
#include "options/datatypes_options.h"
26
#include "theory/datatypes/sygus_datatype_utils.h"
27
#include "theory/datatypes/theory_datatypes_utils.h"
28
#include "theory/datatypes/tuple_project_op.h"
29
#include "util/rational.h"
30
31
using namespace cvc5;
32
using namespace cvc5::kind;
33
34
namespace cvc5 {
35
namespace theory {
36
namespace datatypes {
37
38
15273
DatatypesRewriter::DatatypesRewriter(Evaluator* sygusEval)
39
15273
    : d_sygusEval(sygusEval)
40
{
41
15273
}
42
43
610405
RewriteResponse DatatypesRewriter::postRewrite(TNode in)
44
{
45
610405
  Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl;
46
610405
  Kind kind = in.getKind();
47
610405
  NodeManager* nm = NodeManager::currentNM();
48
610405
  if (kind == kind::APPLY_CONSTRUCTOR)
49
  {
50
174541
    return rewriteConstructor(in);
51
  }
52
435864
  else if (kind == kind::APPLY_SELECTOR_TOTAL || kind == kind::APPLY_SELECTOR)
53
  {
54
93705
    return rewriteSelector(in);
55
  }
56
342159
  else if (kind == kind::APPLY_TESTER)
57
  {
58
42994
    return rewriteTester(in);
59
  }
60
299165
  else if (kind == APPLY_UPDATER)
61
  {
62
75
    return rewriteUpdater(in);
63
  }
64
299090
  else if (kind == kind::DT_SIZE)
65
  {
66
26983
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
67
    {
68
30826
      std::vector<Node> children;
69
42579
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
70
      {
71
27166
        if (in[0][i].getType().isDatatype())
72
        {
73
26999
          children.push_back(nm->mkNode(kind::DT_SIZE, in[0][i]));
74
        }
75
      }
76
30826
      TNode constructor = in[0].getOperator();
77
15413
      size_t constructorIndex = utils::indexOf(constructor);
78
15413
      const DType& dt = utils::datatypeOf(constructor);
79
15413
      const DTypeConstructor& c = dt[constructorIndex];
80
15413
      unsigned weight = c.getWeight();
81
15413
      children.push_back(nm->mkConst(Rational(weight)));
82
      Node res =
83
30826
          children.size() == 1 ? children[0] : nm->mkNode(kind::PLUS, children);
84
30826
      Trace("datatypes-rewrite")
85
15413
          << "DatatypesRewriter::postRewrite: rewrite size " << in << " to "
86
15413
          << res << std::endl;
87
15413
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
88
    }
89
  }
90
272107
  else if (kind == kind::DT_HEIGHT_BOUND)
91
  {
92
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
93
    {
94
      std::vector<Node> children;
95
      Node res;
96
      Rational r = in[1].getConst<Rational>();
97
      Rational rmo = Rational(r - Rational(1));
98
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
99
      {
100
        if (in[0][i].getType().isDatatype())
101
        {
102
          if (r.isZero())
103
          {
104
            res = nm->mkConst(false);
105
            break;
106
          }
107
          children.push_back(
108
              nm->mkNode(kind::DT_HEIGHT_BOUND, in[0][i], nm->mkConst(rmo)));
109
        }
110
      }
111
      if (res.isNull())
112
      {
113
        res = children.size() == 0
114
                  ? nm->mkConst(true)
115
                  : (children.size() == 1 ? children[0]
116
                                          : nm->mkNode(kind::AND, children));
117
      }
118
      Trace("datatypes-rewrite")
119
          << "DatatypesRewriter::postRewrite: rewrite height " << in << " to "
120
          << res << std::endl;
121
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
122
    }
123
  }
124
272107
  else if (kind == kind::DT_SIZE_BOUND)
125
  {
126
    if (in[0].isConst())
127
    {
128
      Node res = nm->mkNode(kind::LEQ, nm->mkNode(kind::DT_SIZE, in[0]), in[1]);
129
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
130
    }
131
  }
132
272107
  else if (kind == DT_SYGUS_EVAL)
133
  {
134
    // sygus evaluation function
135
132660
    Node ev = in[0];
136
110914
    if (ev.getKind() == APPLY_CONSTRUCTOR)
137
    {
138
89168
      Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
139
89168
      Trace("dt-sygus-util") << "Type is " << in.getType() << std::endl;
140
178336
      std::vector<Node> args;
141
390401
      for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++)
142
      {
143
301233
        args.push_back(in[j]);
144
      }
145
178336
      Node ret = sygusToBuiltinEval(ev, args);
146
89168
      Trace("dt-sygus-util") << "...got " << ret << "\n";
147
89168
      Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl;
148
89168
      Assert(in.getType().isComparableTo(ret.getType()));
149
89168
      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
150
    }
151
  }
152
161193
  else if (kind == MATCH)
153
  {
154
12
    Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl;
155
24
    Node h = in[0];
156
24
    std::vector<Node> cases;
157
24
    std::vector<Node> rets;
158
24
    TypeNode t = h.getType();
159
12
    const DType& dt = t.getDType();
160
40
    for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++)
161
    {
162
56
      Node c = in[k];
163
56
      Node cons;
164
28
      Kind ck = c.getKind();
165
28
      if (ck == MATCH_CASE)
166
      {
167
20
        Assert(c[0].getKind() == APPLY_CONSTRUCTOR);
168
20
        cons = c[0].getOperator();
169
      }
170
8
      else if (ck == MATCH_BIND_CASE)
171
      {
172
8
        if (c[1].getKind() == APPLY_CONSTRUCTOR)
173
        {
174
4
          cons = c[1].getOperator();
175
        }
176
      }
177
      else
178
      {
179
        AlwaysAssert(false);
180
      }
181
28
      size_t cindex = 0;
182
      // cons is null in the default case
183
28
      if (!cons.isNull())
184
      {
185
24
        cindex = utils::indexOf(cons);
186
      }
187
56
      Node body;
188
28
      if (ck == MATCH_CASE)
189
      {
190
20
        body = c[1];
191
      }
192
8
      else if (ck == MATCH_BIND_CASE)
193
      {
194
16
        std::vector<Node> vars;
195
16
        std::vector<Node> subs;
196
8
        if (cons.isNull())
197
        {
198
4
          Assert(c[1].getKind() == BOUND_VARIABLE);
199
4
          vars.push_back(c[1]);
200
4
          subs.push_back(h);
201
        }
202
        else
203
        {
204
12
          for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++)
205
          {
206
8
            vars.push_back(c[0][i]);
207
            Node sc =
208
16
                nm->mkNode(APPLY_SELECTOR, dt[cindex][i].getSelector(), h);
209
8
            subs.push_back(sc);
210
          }
211
        }
212
8
        body =
213
16
            c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
214
      }
215
28
      if (!cons.isNull())
216
      {
217
24
        cases.push_back(utils::mkTester(h, cindex, dt));
218
      }
219
      else
220
      {
221
        // variables have no constraints
222
4
        cases.push_back(nm->mkConst(true));
223
      }
224
28
      rets.push_back(body);
225
    }
226
12
    Assert(!cases.empty());
227
    // now make the ITE
228
12
    std::reverse(cases.begin(), cases.end());
229
12
    std::reverse(rets.begin(), rets.end());
230
24
    Node ret = rets[0];
231
12
    AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors());
232
28
    for (unsigned i = 1, ncases = cases.size(); i < ncases; i++)
233
    {
234
16
      ret = nm->mkNode(ITE, cases[i], rets[i], ret);
235
    }
236
24
    Trace("dt-rewrite-match")
237
12
        << "Rewrite match: " << in << " ... " << ret << std::endl;
238
12
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
239
  }
240
161181
  else if (kind == TUPLE_PROJECT)
241
  {
242
    // returns a tuple that represents
243
    // (mkTuple ((_ tupSel i_1) t) ... ((_ tupSel i_n) t))
244
    // where each i_j is less than the length of t
245
246
6
    Trace("dt-rewrite-project") << "Rewrite project: " << in << std::endl;
247
12
    TupleProjectOp op = in.getOperator().getConst<TupleProjectOp>();
248
12
    std::vector<uint32_t> indices = op.getIndices();
249
12
    Node tuple = in[0];
250
12
    std::vector<TypeNode> tupleTypes = tuple.getType().getTupleTypes();
251
12
    std::vector<TypeNode> types;
252
12
    std::vector<Node> elements;
253
12
    for (uint32_t index : indices)
254
    {
255
12
      TypeNode type = tupleTypes[index];
256
6
      types.push_back(type);
257
    }
258
12
    TypeNode projectType = nm->mkTupleType(types);
259
6
    const DType& dt = projectType.getDType();
260
6
    elements.push_back(dt[0].getConstructor());
261
6
    const DType& tupleDType = tuple.getType().getDType();
262
6
    const DTypeConstructor& constructor = tupleDType[0];
263
12
    for (uint32_t index : indices)
264
    {
265
12
      Node selector = constructor[index].getSelector();
266
12
      Node element = nm->mkNode(kind::APPLY_SELECTOR, selector, tuple);
267
6
      elements.push_back(element);
268
    }
269
12
    Node ret = nm->mkNode(kind::APPLY_CONSTRUCTOR, elements);
270
271
12
    Trace("dt-rewrite-project")
272
6
        << "Rewrite project: " << in << " ... " << ret << std::endl;
273
6
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
274
  }
275
276
194491
  if (kind == kind::EQUAL)
277
  {
278
159443
    if (in[0] == in[1])
279
    {
280
2998
      return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
281
    }
282
277340
    std::vector<Node> rew;
283
156445
    if (utils::checkClash(in[0], in[1], rew))
284
    {
285
3546
      Trace("datatypes-rewrite")
286
1773
          << "Rewrite clashing equality " << in << " to false" << std::endl;
287
1773
      return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
288
      //}else if( rew.size()==1 && rew[0]!=in ){
289
      //  Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " <<
290
      //  rew[0] << std::endl;
291
      //  return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] );
292
    }
293
154672
    else if (in[1] < in[0])
294
    {
295
67554
      Node ins = nm->mkNode(in.getKind(), in[1], in[0]);
296
67554
      Trace("datatypes-rewrite")
297
33777
          << "Swap equality " << in << " to " << ins << std::endl;
298
33777
      return RewriteResponse(REWRITE_DONE, ins);
299
    }
300
241790
    Trace("datatypes-rewrite-debug")
301
241790
        << "Did not rewrite equality " << in << " " << in[0].getKind() << " "
302
120895
        << in[1].getKind() << std::endl;
303
  }
304
305
155943
  return RewriteResponse(REWRITE_DONE, in);
306
}
307
308
365373
RewriteResponse DatatypesRewriter::preRewrite(TNode in)
309
{
310
365373
  Trace("datatypes-rewrite-debug") << "pre-rewriting " << in << std::endl;
311
  // must prewrite to apply type ascriptions since rewriting does not preserve
312
  // types
313
365373
  if (in.getKind() == kind::APPLY_CONSTRUCTOR)
314
  {
315
164347
    TypeNode tn = in.getType();
316
317
    // check for parametric datatype constructors
318
    // to ensure a normal form, all parameteric datatype constructors must have
319
    // a type ascription
320
82191
    if (tn.isParametricDatatype())
321
    {
322
274
      if (in.getOperator().getKind() != kind::APPLY_TYPE_ASCRIPTION)
323
      {
324
70
        Trace("datatypes-rewrite-debug")
325
35
            << "Ascribing type to parametric datatype constructor " << in
326
35
            << std::endl;
327
70
        Node op = in.getOperator();
328
        // get the constructor object
329
35
        const DTypeConstructor& dtc = utils::datatypeOf(op)[utils::indexOf(op)];
330
        // create ascribed constructor type
331
        Node tc = NodeManager::currentNM()->mkConst(
332
70
            AscriptionType(dtc.getSpecializedConstructorType(tn)));
333
        Node op_new = NodeManager::currentNM()->mkNode(
334
70
            kind::APPLY_TYPE_ASCRIPTION, tc, op);
335
        // make new node
336
70
        std::vector<Node> children;
337
35
        children.push_back(op_new);
338
35
        children.insert(children.end(), in.begin(), in.end());
339
        Node inr =
340
70
            NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, children);
341
35
        Trace("datatypes-rewrite-debug") << "Created " << inr << std::endl;
342
35
        return RewriteResponse(REWRITE_DONE, inr);
343
      }
344
    }
345
  }
346
365338
  return RewriteResponse(REWRITE_DONE, in);
347
}
348
349
174541
RewriteResponse DatatypesRewriter::rewriteConstructor(TNode in)
350
{
351
174541
  if (in.isConst())
352
  {
353
165752
    Trace("datatypes-rewrite-debug") << "Normalizing constant " << in
354
82876
                                     << std::endl;
355
165752
    Node inn = normalizeConstant(in);
356
    // constant may be a subterm of another constant, so cannot assume that this
357
    // will succeed for codatatypes
358
    // Assert( !inn.isNull() );
359
82876
    if (!inn.isNull() && inn != in)
360
    {
361
12
      Trace("datatypes-rewrite") << "Normalized constant " << in << " -> "
362
6
                                 << inn << std::endl;
363
6
      return RewriteResponse(REWRITE_DONE, inn);
364
    }
365
82870
    return RewriteResponse(REWRITE_DONE, in);
366
  }
367
91665
  return RewriteResponse(REWRITE_DONE, in);
368
}
369
370
93705
RewriteResponse DatatypesRewriter::rewriteSelector(TNode in)
371
{
372
93705
  Kind k = in.getKind();
373
93705
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
374
  {
375
    // Have to be careful not to rewrite well-typed expressions
376
    // where the selector doesn't match the constructor,
377
    // e.g. "pred(zero)".
378
32235
    TypeNode tn = in.getType();
379
32235
    TypeNode argType = in[0].getType();
380
32235
    Node selector = in.getOperator();
381
32235
    TNode constructor = in[0].getOperator();
382
29839
    size_t constructorIndex = utils::indexOf(constructor);
383
29839
    const DType& dt = utils::datatypeOf(selector);
384
29839
    const DTypeConstructor& c = dt[constructorIndex];
385
59678
    Trace("datatypes-rewrite-debug") << "Rewriting collapsable selector : "
386
29839
                                     << in;
387
59678
    Trace("datatypes-rewrite-debug") << ", cindex = " << constructorIndex
388
29839
                                     << ", selector is " << selector
389
29839
                                     << std::endl;
390
    // The argument that the selector extracts, or -1 if the selector is
391
    // is wrongly applied.
392
29839
    int selectorIndex = -1;
393
29839
    if (k == kind::APPLY_SELECTOR_TOTAL)
394
    {
395
      // The argument index of internal selectors is obtained by
396
      // getSelectorIndexInternal.
397
24647
      selectorIndex = c.getSelectorIndexInternal(selector);
398
    }
399
    else
400
    {
401
      // The argument index of external selectors (applications of
402
      // APPLY_SELECTOR) is given by an attribute and obtained via indexOf below
403
      // The argument is only valid if it is the proper constructor.
404
5192
      selectorIndex = utils::indexOf(selector);
405
10384
      if (selectorIndex < 0
406
5192
          || selectorIndex >= static_cast<int>(c.getNumArgs()))
407
      {
408
894
        selectorIndex = -1;
409
      }
410
4298
      else if (c[selectorIndex].getSelector() != selector)
411
      {
412
1502
        selectorIndex = -1;
413
      }
414
    }
415
59678
    Trace("datatypes-rewrite-debug") << "Internal selector index is "
416
29839
                                     << selectorIndex << std::endl;
417
29839
    if (selectorIndex >= 0)
418
    {
419
25736
      Assert(selectorIndex < (int)c.getNumArgs());
420
25736
      if (dt.isCodatatype() && in[0][selectorIndex].isConst())
421
      {
422
        // must replace all debruijn indices with self
423
14
        Node sub = replaceDebruijn(in[0][selectorIndex], in[0], argType, 0);
424
28
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
425
14
                                   << "Rewrite trivial codatatype selector "
426
14
                                   << in << " to " << sub << std::endl;
427
14
        if (sub != in)
428
        {
429
14
          return RewriteResponse(REWRITE_AGAIN_FULL, sub);
430
        }
431
      }
432
      else
433
      {
434
51444
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
435
25722
                                   << "Rewrite trivial selector " << in
436
25722
                                   << std::endl;
437
25722
        return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]);
438
      }
439
    }
440
4103
    else if (k == kind::APPLY_SELECTOR_TOTAL)
441
    {
442
      // evaluates to the first ground value of type tn.
443
3414
      Node gt = tn.mkGroundValue();
444
1707
      Assert(!gt.isNull());
445
1707
      if (tn.isDatatype() && !tn.isInstantiatedDatatype())
446
      {
447
        gt = NodeManager::currentNM()->mkNode(
448
            kind::APPLY_TYPE_ASCRIPTION,
449
            NodeManager::currentNM()->mkConst(AscriptionType(tn)),
450
            gt);
451
      }
452
3414
      Trace("datatypes-rewrite")
453
1707
          << "DatatypesRewriter::postRewrite: "
454
1707
          << "Rewrite trivial selector " << in
455
1707
          << " to distinguished ground term " << gt << std::endl;
456
1707
      return RewriteResponse(REWRITE_DONE, gt);
457
    }
458
  }
459
66262
  return RewriteResponse(REWRITE_DONE, in);
460
}
461
462
42994
RewriteResponse DatatypesRewriter::rewriteTester(TNode in)
463
{
464
42994
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
465
  {
466
    bool result =
467
2096
        utils::indexOf(in.getOperator()) == utils::indexOf(in[0].getOperator());
468
4192
    Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
469
2096
                               << "Rewrite trivial tester " << in << " "
470
2096
                               << result << std::endl;
471
    return RewriteResponse(REWRITE_DONE,
472
2096
                           NodeManager::currentNM()->mkConst(result));
473
  }
474
40898
  const DType& dt = in[0].getType().getDType();
475
40898
  if (dt.getNumConstructors() == 1 && !dt.isSygus())
476
  {
477
    // only one constructor, so it must be
478
2484
    Trace("datatypes-rewrite")
479
1242
        << "DatatypesRewriter::postRewrite: "
480
2484
        << "only one ctor for " << dt.getName() << " and that is "
481
1242
        << dt[0].getName() << std::endl;
482
    return RewriteResponse(REWRITE_DONE,
483
1242
                           NodeManager::currentNM()->mkConst(true));
484
  }
485
  // could try dt.getNumConstructors()==2 && indexOf(in.getOperator())==1 ?
486
39656
  return RewriteResponse(REWRITE_DONE, in);
487
}
488
489
75
RewriteResponse DatatypesRewriter::rewriteUpdater(TNode in)
490
{
491
75
  Assert (in.getKind()==APPLY_UPDATER);
492
75
  if (in[0].getKind() == APPLY_CONSTRUCTOR)
493
  {
494
18
    Node op = in.getOperator();
495
9
    size_t cindex = utils::indexOf(in[0].getOperator());
496
9
    size_t cuindex = utils::cindexOf(op);
497
9
    if (cindex==cuindex)
498
    {
499
9
      NodeManager * nm = NodeManager::currentNM();
500
9
      size_t updateIndex = utils::indexOf(op);
501
18
      std::vector<Node> children(in[0].begin(), in[0].end());
502
9
      children[updateIndex] = in[1];
503
9
      children.insert(children.begin(),in[0].getOperator());
504
9
      return RewriteResponse(REWRITE_DONE, nm->mkNode(APPLY_CONSTRUCTOR, children));
505
    }
506
    return RewriteResponse(REWRITE_DONE, in[0]);
507
  }
508
66
  return RewriteResponse(REWRITE_DONE, in);
509
}
510
511
1537
Node DatatypesRewriter::normalizeCodatatypeConstant(Node n)
512
{
513
1537
  Trace("dt-nconst") << "Normalize " << n << std::endl;
514
3074
  std::map<Node, Node> rf;
515
3074
  std::vector<Node> sk;
516
3074
  std::vector<Node> rf_pending;
517
3074
  std::vector<Node> terms;
518
3074
  std::map<Node, bool> cdts;
519
3074
  Node s = collectRef(n, sk, rf, rf_pending, terms, cdts);
520
1537
  if (!s.isNull())
521
  {
522
632
    Trace("dt-nconst") << "...symbolic normalized is : " << s << std::endl;
523
904
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
524
    {
525
544
      Trace("dt-nconst") << "  " << it->first << " = " << it->second
526
272
                         << std::endl;
527
    }
528
    // now run DFA minimization on term structure
529
1264
    Trace("dt-nconst") << "  " << terms.size()
530
632
                       << " total subterms :" << std::endl;
531
632
    int eqc_count = 0;
532
1264
    std::map<Node, int> eqc_op_map;
533
1264
    std::map<Node, int> eqc;
534
1264
    std::map<int, std::vector<Node> > eqc_nodes;
535
    // partition based on top symbol
536
3295
    for (unsigned j = 0, size = terms.size(); j < size; j++)
537
    {
538
5326
      Node t = terms[j];
539
2663
      Trace("dt-nconst") << "    " << t << ", cdt=" << cdts[t] << std::endl;
540
      int e;
541
2663
      if (cdts[t])
542
      {
543
2087
        Assert(t.getKind() == kind::APPLY_CONSTRUCTOR);
544
4174
        Node op = t.getOperator();
545
2087
        std::map<Node, int>::iterator it = eqc_op_map.find(op);
546
2087
        if (it == eqc_op_map.end())
547
        {
548
993
          e = eqc_count;
549
993
          eqc_op_map[op] = eqc_count;
550
993
          eqc_count++;
551
        }
552
        else
553
        {
554
1094
          e = it->second;
555
        }
556
      }
557
      else
558
      {
559
576
        e = eqc_count;
560
576
        eqc_count++;
561
      }
562
2663
      eqc[t] = e;
563
2663
      eqc_nodes[e].push_back(t);
564
    }
565
    // partition until fixed point
566
632
    int eqc_curr = 0;
567
632
    bool success = true;
568
3412
    while (success)
569
    {
570
1390
      success = false;
571
1390
      int eqc_end = eqc_count;
572
7670
      while (eqc_curr < eqc_end)
573
      {
574
3140
        if (eqc_nodes[eqc_curr].size() > 1)
575
        {
576
          // look at all nodes in this equivalence class
577
893
          unsigned nchildren = eqc_nodes[eqc_curr][0].getNumChildren();
578
1786
          std::map<int, std::vector<Node> > prt;
579
1832
          for (unsigned j = 0; j < nchildren; j++)
580
          {
581
1715
            prt.clear();
582
            // partition based on children : for the first child that causes a
583
            // split, break
584
7600
            for (unsigned k = 0, size = eqc_nodes[eqc_curr].size(); k < size;
585
                 k++)
586
            {
587
11770
              Node t = eqc_nodes[eqc_curr][k];
588
5885
              Assert(t.getNumChildren() == nchildren);
589
11770
              Node tc = t[j];
590
              // refer to loops
591
5885
              std::map<Node, Node>::iterator itr = rf.find(tc);
592
5885
              if (itr != rf.end())
593
              {
594
125
                tc = itr->second;
595
              }
596
5885
              Assert(eqc.find(tc) != eqc.end());
597
5885
              prt[eqc[tc]].push_back(t);
598
            }
599
1715
            if (prt.size() > 1)
600
            {
601
776
              success = true;
602
776
              break;
603
            }
604
          }
605
          // move into new eqc(s)
606
2580
          for (const std::pair<const int, std::vector<Node> >& p : prt)
607
          {
608
1687
            int e = eqc_count;
609
4712
            for (unsigned j = 0, size = p.second.size(); j < size; j++)
610
            {
611
6050
              Node t = p.second[j];
612
3025
              eqc[t] = e;
613
3025
              eqc_nodes[e].push_back(t);
614
            }
615
1687
            eqc_count++;
616
          }
617
        }
618
3140
        eqc_nodes.erase(eqc_curr);
619
3140
        eqc_curr++;
620
      }
621
    }
622
    // add in already occurring loop variables
623
904
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
624
    {
625
544
      Trace("dt-nconst-debug") << "Mapping equivalence class of " << it->first
626
272
                               << " -> " << it->second << std::endl;
627
272
      Assert(eqc.find(it->second) != eqc.end());
628
272
      eqc[it->first] = eqc[it->second];
629
    }
630
    // we now have a partition of equivalent terms
631
632
    Trace("dt-nconst") << "Computed equivalence classes ids : " << std::endl;
632
3567
    for (std::map<Node, int>::iterator it = eqc.begin(); it != eqc.end(); ++it)
633
    {
634
5870
      Trace("dt-nconst") << "  " << it->first << " -> " << it->second
635
2935
                         << std::endl;
636
    }
637
    // traverse top-down based on equivalence class
638
1264
    std::map<int, int> eqc_stack;
639
632
    return normalizeCodatatypeConstantEqc(s, eqc_stack, eqc, 0);
640
  }
641
905
  Trace("dt-nconst") << "...invalid." << std::endl;
642
905
  return Node::null();
643
}
644
645
// normalize constant : apply to top-level codatatype constants
646
476209
Node DatatypesRewriter::normalizeConstant(Node n)
647
{
648
952418
  TypeNode tn = n.getType();
649
476209
  if (tn.isDatatype())
650
  {
651
445792
    if (tn.isCodatatype())
652
    {
653
202
      return normalizeCodatatypeConstant(n);
654
    }
655
    else
656
    {
657
891180
      std::vector<Node> children;
658
445590
      bool childrenChanged = false;
659
838881
      for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
660
      {
661
786582
        Node nc = normalizeConstant(n[i]);
662
393291
        children.push_back(nc);
663
393291
        childrenChanged = childrenChanged || nc != n[i];
664
      }
665
445590
      if (childrenChanged)
666
      {
667
        return NodeManager::currentNM()->mkNode(n.getKind(), children);
668
      }
669
    }
670
  }
671
476007
  return n;
672
}
673
674
7758
Node DatatypesRewriter::collectRef(Node n,
675
                                   std::vector<Node>& sk,
676
                                   std::map<Node, Node>& rf,
677
                                   std::vector<Node>& rf_pending,
678
                                   std::vector<Node>& terms,
679
                                   std::map<Node, bool>& cdts)
680
{
681
7758
  Assert(n.isConst());
682
15516
  TypeNode tn = n.getType();
683
15516
  Node ret = n;
684
7758
  bool isCdt = false;
685
7758
  if (tn.isDatatype())
686
  {
687
4788
    if (!tn.isCodatatype())
688
    {
689
      // nested datatype within codatatype : can be normalized independently
690
      // since all loops should be self-contained
691
42
      ret = normalizeConstant(n);
692
    }
693
    else
694
    {
695
4746
      isCdt = true;
696
4746
      if (n.getKind() == kind::APPLY_CONSTRUCTOR)
697
      {
698
3569
        sk.push_back(n);
699
3569
        rf_pending.push_back(Node::null());
700
5683
        std::vector<Node> children;
701
3569
        children.push_back(n.getOperator());
702
3569
        bool childChanged = false;
703
8335
        for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
704
        {
705
10987
          Node nc = collectRef(n[i], sk, rf, rf_pending, terms, cdts);
706
6221
          if (nc.isNull())
707
          {
708
1455
            return Node::null();
709
          }
710
4766
          childChanged = nc != n[i] || childChanged;
711
4766
          children.push_back(nc);
712
        }
713
2114
        sk.pop_back();
714
2114
        if (childChanged)
715
        {
716
679
          ret = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR,
717
                                                 children);
718
679
          if (!rf_pending.back().isNull())
719
          {
720
272
            rf[rf_pending.back()] = ret;
721
          }
722
        }
723
        else
724
        {
725
1435
          Assert(rf_pending.back().isNull());
726
        }
727
2114
        rf_pending.pop_back();
728
      }
729
      else
730
      {
731
        // a loop
732
1177
        const Integer& i = n.getConst<CodatatypeBoundVariable>().getIndex();
733
1177
        uint32_t index = i.toUnsignedInt();
734
1177
        if (index >= sk.size())
735
        {
736
905
          return Node::null();
737
        }
738
272
        Assert(sk.size() == rf_pending.size());
739
544
        Node r = rf_pending[rf_pending.size() - 1 - index];
740
272
        if (r.isNull())
741
        {
742
544
          r = NodeManager::currentNM()->mkBoundVar(
743
544
              sk[rf_pending.size() - 1 - index].getType());
744
272
          rf_pending[rf_pending.size() - 1 - index] = r;
745
        }
746
272
        return r;
747
      }
748
    }
749
  }
750
10252
  Trace("dt-nconst-debug") << "Return term : " << ret << " from " << n
751
5126
                           << std::endl;
752
5126
  if (std::find(terms.begin(), terms.end(), ret) == terms.end())
753
  {
754
3592
    terms.push_back(ret);
755
3592
    Assert(ret.getType() == tn);
756
3592
    cdts[ret] = isCdt;
757
  }
758
5126
  return ret;
759
}
760
// eqc_stack stores depth
761
3389
Node DatatypesRewriter::normalizeCodatatypeConstantEqc(
762
    Node n, std::map<int, int>& eqc_stack, std::map<Node, int>& eqc, int depth)
763
{
764
6778
  Trace("dt-nconst-debug") << "normalizeCodatatypeConstantEqc: " << n
765
3389
                           << " depth=" << depth << std::endl;
766
3389
  if (eqc.find(n) != eqc.end())
767
  {
768
3351
    int e = eqc[n];
769
3351
    std::map<int, int>::iterator it = eqc_stack.find(e);
770
3351
    if (it != eqc_stack.end())
771
    {
772
272
      int debruijn = depth - it->second - 1;
773
      return NodeManager::currentNM()->mkConst(
774
272
          CodatatypeBoundVariable(n.getType(), debruijn));
775
    }
776
5779
    std::vector<Node> children;
777
3079
    bool childChanged = false;
778
3079
    eqc_stack[e] = depth;
779
5836
    for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
780
    {
781
5514
      Node nc = normalizeCodatatypeConstantEqc(n[i], eqc_stack, eqc, depth + 1);
782
2757
      children.push_back(nc);
783
2757
      childChanged = childChanged || nc != n[i];
784
    }
785
3079
    eqc_stack.erase(e);
786
3079
    if (childChanged)
787
    {
788
379
      Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
789
379
      children.insert(children.begin(), n.getOperator());
790
379
      return NodeManager::currentNM()->mkNode(n.getKind(), children);
791
    }
792
  }
793
2738
  return n;
794
}
795
796
66
Node DatatypesRewriter::replaceDebruijn(Node n,
797
                                        Node orig,
798
                                        TypeNode orig_tn,
799
                                        unsigned depth)
800
{
801
66
  if (n.getKind() == kind::CODATATYPE_BOUND_VARIABLE && n.getType() == orig_tn)
802
  {
803
    unsigned index =
804
        n.getConst<CodatatypeBoundVariable>().getIndex().toUnsignedInt();
805
    if (index == depth)
806
    {
807
      return orig;
808
    }
809
  }
810
66
  else if (n.getNumChildren() > 0)
811
  {
812
52
    std::vector<Node> children;
813
26
    bool childChanged = false;
814
78
    for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
815
    {
816
104
      Node nc = replaceDebruijn(n[i], orig, orig_tn, depth + 1);
817
52
      children.push_back(nc);
818
52
      childChanged = childChanged || nc != n[i];
819
    }
820
26
    if (childChanged)
821
    {
822
      if (n.hasOperator())
823
      {
824
        children.insert(children.begin(), n.getOperator());
825
      }
826
      return NodeManager::currentNM()->mkNode(n.getKind(), children);
827
    }
828
  }
829
66
  return n;
830
}
831
832
5122
Node DatatypesRewriter::expandApplySelector(Node n)
833
{
834
5122
  Assert(n.getKind() == APPLY_SELECTOR);
835
10244
  Node selector = n.getOperator();
836
  // APPLY_SELECTOR always applies to an external selector, cindexOf is
837
  // legal here
838
5122
  size_t cindex = utils::cindexOf(selector);
839
5122
  const DType& dt = utils::datatypeOf(selector);
840
5122
  const DTypeConstructor& c = dt[cindex];
841
10244
  Node selector_use;
842
10244
  TypeNode ndt = n[0].getType();
843
5122
  if (options::dtSharedSelectors())
844
  {
845
209
    size_t selectorIndex = utils::indexOf(selector);
846
209
    Trace("dt-expand") << "...selector index = " << selectorIndex << std::endl;
847
209
    Assert(selectorIndex < c.getNumArgs());
848
209
    selector_use = c.getSelectorInternal(ndt, selectorIndex);
849
  }
850
  else
851
  {
852
4913
    selector_use = selector;
853
  }
854
5122
  NodeManager* nm = NodeManager::currentNM();
855
10244
  Node sel = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, selector_use, n[0]);
856
5122
  if (options::dtRewriteErrorSel())
857
  {
858
36
    return sel;
859
  }
860
10172
  Node tester = c.getTester();
861
10172
  Node tst = nm->mkNode(APPLY_TESTER, tester, n[0]);
862
5086
  SkolemManager* sm = nm->getSkolemManager();
863
10172
  TypeNode tnw = nm->mkFunctionType(ndt, n.getType());
864
10172
  Node f = sm->mkSkolemFunction(SkolemFunId::SELECTOR_WRONG, tnw, selector);
865
10172
  Node sk = nm->mkNode(kind::APPLY_UF, f, n[0]);
866
10172
  Node ret = nm->mkNode(kind::ITE, tst, sel, sk);
867
5086
  Trace("dt-expand") << "Expand def : " << n << " to " << ret << std::endl;
868
5086
  return ret;
869
}
870
871
78845
TrustNode DatatypesRewriter::expandDefinition(Node n)
872
{
873
78845
  NodeManager* nm = NodeManager::currentNM();
874
157690
  TypeNode tn = n.getType();
875
157690
  Node ret;
876
78845
  switch (n.getKind())
877
  {
878
4809
    case kind::APPLY_SELECTOR:
879
    {
880
4809
      ret = expandApplySelector(n);
881
    }
882
4809
    break;
883
26
    case APPLY_UPDATER:
884
    {
885
26
      Assert(tn.isDatatype());
886
26
      const DType& dt = tn.getDType();
887
52
      Node op = n.getOperator();
888
26
      size_t updateIndex = utils::indexOf(op);
889
26
      size_t cindex = utils::cindexOf(op);
890
26
      const DTypeConstructor& dc = dt[cindex];
891
52
      NodeBuilder b(APPLY_CONSTRUCTOR);
892
26
      b << dc.getConstructor();
893
26
      Trace("dt-expand") << "Expand updater " << n << std::endl;
894
26
      Trace("dt-expand") << "expr is " << n << std::endl;
895
26
      Trace("dt-expand") << "updateIndex is " << updateIndex << std::endl;
896
26
      Trace("dt-expand") << "t is " << tn << std::endl;
897
76
      for (size_t i = 0, size = dc.getNumArgs(); i < size; ++i)
898
      {
899
50
        if (i == updateIndex)
900
        {
901
26
          b << n[1];
902
        }
903
        else
904
        {
905
120
          b << nm->mkNode(
906
96
              APPLY_SELECTOR_TOTAL, dc.getSelectorInternal(tn, i), n[0]);
907
        }
908
      }
909
26
      ret = b;
910
26
      if (dt.getNumConstructors() > 1)
911
      {
912
        // must be the right constructor to update
913
20
        Node tester = nm->mkNode(APPLY_TESTER, dc.getTester(), n[0]);
914
10
        ret = nm->mkNode(ITE, tester, ret, n[0]);
915
      }
916
52
      Trace("dt-expand") << "return " << ret << std::endl;
917
    }
918
26
    break;
919
74010
    default: break;
920
  }
921
78845
  if (!ret.isNull() && n != ret)
922
  {
923
4835
    return TrustNode::mkTrustRewrite(n, ret, nullptr);
924
  }
925
74010
  return TrustNode::null();
926
}
927
928
89168
Node DatatypesRewriter::sygusToBuiltinEval(Node n,
929
                                           const std::vector<Node>& args)
930
{
931
89168
  Assert(d_sygusEval != nullptr);
932
89168
  NodeManager* nm = NodeManager::currentNM();
933
  // constant arguments?
934
89168
  bool constArgs = true;
935
370510
  for (const Node& a : args)
936
  {
937
287881
    if (!a.isConst())
938
    {
939
6539
      constArgs = false;
940
6539
      break;
941
    }
942
  }
943
178336
  std::vector<Node> eargs;
944
89168
  bool svarsInit = false;
945
178336
  std::vector<Node> svars;
946
178336
  std::unordered_map<TNode, Node> visited;
947
89168
  std::unordered_map<TNode, Node>::iterator it;
948
178336
  std::vector<TNode> visit;
949
178336
  TNode cur;
950
  unsigned index;
951
89168
  visit.push_back(n);
952
242797
  do
953
  {
954
331965
    cur = visit.back();
955
331965
    visit.pop_back();
956
331965
    it = visited.find(cur);
957
331965
    if (it == visited.end())
958
    {
959
492398
      TypeNode tn = cur.getType();
960
246199
      if (!tn.isDatatype() || !tn.getDType().isSygus())
961
      {
962
114
        visited[cur] = cur;
963
      }
964
246085
      else if (cur.isConst())
965
      {
966
        // convert to builtin term
967
234800
        Node bt = utils::sygusToBuiltin(cur);
968
        // run the evaluator if possible
969
117400
        if (!svarsInit)
970
        {
971
88319
          svarsInit = true;
972
176638
          TypeNode type = cur.getType();
973
176638
          Node varList = type.getDType().getSygusVarList();
974
387560
          for (const Node& v : varList)
975
          {
976
299241
            svars.push_back(v);
977
          }
978
        }
979
117400
        Assert(args.size() == svars.size());
980
        // try evaluation if we have constant arguments
981
        Node ret =
982
234800
            constArgs ? d_sygusEval->eval(bt, svars, args) : Node::null();
983
117400
        if (ret.isNull())
984
        {
985
          // if evaluation was not available, use a substitution
986
9914
          ret = bt.substitute(
987
              svars.begin(), svars.end(), args.begin(), args.end());
988
        }
989
117400
        visited[cur] = ret;
990
      }
991
      else
992
      {
993
128685
        if (cur.getKind() == APPLY_CONSTRUCTOR)
994
        {
995
80424
          visited[cur] = Node::null();
996
80424
          visit.push_back(cur);
997
242797
          for (const Node& cn : cur)
998
          {
999
162373
            visit.push_back(cn);
1000
          }
1001
        }
1002
        else
1003
        {
1004
          // it is the evaluation of this term on the arguments
1005
48261
          if (eargs.empty())
1006
          {
1007
42777
            eargs.push_back(cur);
1008
42777
            eargs.insert(eargs.end(), args.begin(), args.end());
1009
          }
1010
          else
1011
          {
1012
5484
            eargs[0] = cur;
1013
          }
1014
48261
          visited[cur] = nm->mkNode(DT_SYGUS_EVAL, eargs);
1015
        }
1016
      }
1017
    }
1018
85766
    else if (it->second.isNull())
1019
    {
1020
160848
      Node ret = cur;
1021
80424
      Assert(cur.getKind() == APPLY_CONSTRUCTOR);
1022
80424
      const DType& dt = cur.getType().getDType();
1023
      // non sygus-datatype terms are also themselves
1024
80424
      if (dt.isSygus())
1025
      {
1026
160848
        std::vector<Node> children;
1027
242797
        for (const Node& cn : cur)
1028
        {
1029
162373
          it = visited.find(cn);
1030
162373
          Assert(it != visited.end());
1031
162373
          Assert(!it->second.isNull());
1032
162373
          children.push_back(it->second);
1033
        }
1034
80424
        index = utils::indexOf(cur.getOperator());
1035
        // apply to children, which constructs the builtin term
1036
80424
        ret = utils::mkSygusTerm(dt, index, children);
1037
        // now apply it to arguments in args
1038
80424
        ret = utils::applySygusArgs(dt, dt[index].getSygusOp(), ret, args);
1039
      }
1040
80424
      visited[cur] = ret;
1041
    }
1042
331965
  } while (!visit.empty());
1043
89168
  Assert(visited.find(n) != visited.end());
1044
89168
  Assert(!visited.find(n)->second.isNull());
1045
178336
  return visited[n];
1046
}
1047
1048
}  // namespace datatypes
1049
}  // namespace theory
1050
31137
}  // namespace cvc5