GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/datatypes/datatypes_rewriter.cpp Lines: 552 583 94.7 %
Date: 2021-11-05 Branches: 1207 2649 45.6 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mudathir Mohamed, Mathias Preiner
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of rewriter for the theory of (co)inductive datatypes.
14
 */
15
16
#include "theory/datatypes/datatypes_rewriter.h"
17
18
#include "expr/ascription_type.h"
19
#include "expr/codatatype_bound_variable.h"
20
#include "expr/dtype.h"
21
#include "expr/dtype_cons.h"
22
#include "expr/node_algorithm.h"
23
#include "expr/skolem_manager.h"
24
#include "expr/sygus_datatype.h"
25
#include "options/datatypes_options.h"
26
#include "theory/datatypes/sygus_datatype_utils.h"
27
#include "theory/datatypes/theory_datatypes_utils.h"
28
#include "theory/datatypes/tuple_project_op.h"
29
#include "util/rational.h"
30
31
using namespace cvc5;
32
using namespace cvc5::kind;
33
34
namespace cvc5 {
35
namespace theory {
36
namespace datatypes {
37
38
15271
DatatypesRewriter::DatatypesRewriter(Evaluator* sygusEval)
39
15271
    : d_sygusEval(sygusEval)
40
{
41
15271
}
42
43
622194
RewriteResponse DatatypesRewriter::postRewrite(TNode in)
44
{
45
622194
  Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl;
46
622194
  Kind kind = in.getKind();
47
622194
  NodeManager* nm = NodeManager::currentNM();
48
622194
  if (kind == kind::APPLY_CONSTRUCTOR)
49
  {
50
177405
    return rewriteConstructor(in);
51
  }
52
444789
  else if (kind == kind::APPLY_SELECTOR_TOTAL || kind == kind::APPLY_SELECTOR)
53
  {
54
93933
    return rewriteSelector(in);
55
  }
56
350856
  else if (kind == kind::APPLY_TESTER)
57
  {
58
43440
    return rewriteTester(in);
59
  }
60
307416
  else if (kind == APPLY_UPDATER)
61
  {
62
63
    return rewriteUpdater(in);
63
  }
64
307353
  else if (kind == kind::DT_SIZE)
65
  {
66
26888
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
67
    {
68
30700
      std::vector<Node> children;
69
42420
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
70
      {
71
27070
        if (in[0][i].getType().isDatatype())
72
        {
73
26903
          children.push_back(nm->mkNode(kind::DT_SIZE, in[0][i]));
74
        }
75
      }
76
30700
      TNode constructor = in[0].getOperator();
77
15350
      size_t constructorIndex = utils::indexOf(constructor);
78
15350
      const DType& dt = utils::datatypeOf(constructor);
79
15350
      const DTypeConstructor& c = dt[constructorIndex];
80
15350
      unsigned weight = c.getWeight();
81
15350
      children.push_back(nm->mkConst(Rational(weight)));
82
      Node res =
83
30700
          children.size() == 1 ? children[0] : nm->mkNode(kind::PLUS, children);
84
30700
      Trace("datatypes-rewrite")
85
15350
          << "DatatypesRewriter::postRewrite: rewrite size " << in << " to "
86
15350
          << res << std::endl;
87
15350
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
88
    }
89
  }
90
280465
  else if (kind == kind::DT_HEIGHT_BOUND)
91
  {
92
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
93
    {
94
      std::vector<Node> children;
95
      Node res;
96
      Rational r = in[1].getConst<Rational>();
97
      Rational rmo = Rational(r - Rational(1));
98
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
99
      {
100
        if (in[0][i].getType().isDatatype())
101
        {
102
          if (r.isZero())
103
          {
104
            res = nm->mkConst(false);
105
            break;
106
          }
107
          children.push_back(
108
              nm->mkNode(kind::DT_HEIGHT_BOUND, in[0][i], nm->mkConst(rmo)));
109
        }
110
      }
111
      if (res.isNull())
112
      {
113
        res = children.size() == 0
114
                  ? nm->mkConst(true)
115
                  : (children.size() == 1 ? children[0]
116
                                          : nm->mkNode(kind::AND, children));
117
      }
118
      Trace("datatypes-rewrite")
119
          << "DatatypesRewriter::postRewrite: rewrite height " << in << " to "
120
          << res << std::endl;
121
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
122
    }
123
  }
124
280465
  else if (kind == kind::DT_SIZE_BOUND)
125
  {
126
    if (in[0].isConst())
127
    {
128
      Node res = nm->mkNode(kind::LEQ, nm->mkNode(kind::DT_SIZE, in[0]), in[1]);
129
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
130
    }
131
  }
132
280465
  else if (kind == DT_SYGUS_EVAL)
133
  {
134
    // sygus evaluation function
135
133396
    Node ev = in[0];
136
111490
    if (ev.getKind() == APPLY_CONSTRUCTOR)
137
    {
138
89584
      Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
139
89584
      Trace("dt-sygus-util") << "Type is " << in.getType() << std::endl;
140
179168
      std::vector<Node> args;
141
391413
      for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++)
142
      {
143
301829
        args.push_back(in[j]);
144
      }
145
179168
      Node ret = sygusToBuiltinEval(ev, args);
146
89584
      Trace("dt-sygus-util") << "...got " << ret << "\n";
147
89584
      Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl;
148
89584
      Assert(in.getType().isComparableTo(ret.getType()));
149
89584
      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
150
    }
151
  }
152
168975
  else if (kind == MATCH)
153
  {
154
12
    Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl;
155
24
    Node h = in[0];
156
24
    std::vector<Node> cases;
157
24
    std::vector<Node> rets;
158
24
    TypeNode t = h.getType();
159
12
    const DType& dt = t.getDType();
160
40
    for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++)
161
    {
162
56
      Node c = in[k];
163
56
      Node cons;
164
28
      Kind ck = c.getKind();
165
28
      if (ck == MATCH_CASE)
166
      {
167
20
        Assert(c[0].getKind() == APPLY_CONSTRUCTOR);
168
20
        cons = c[0].getOperator();
169
      }
170
8
      else if (ck == MATCH_BIND_CASE)
171
      {
172
8
        if (c[1].getKind() == APPLY_CONSTRUCTOR)
173
        {
174
4
          cons = c[1].getOperator();
175
        }
176
      }
177
      else
178
      {
179
        AlwaysAssert(false);
180
      }
181
28
      size_t cindex = 0;
182
      // cons is null in the default case
183
28
      if (!cons.isNull())
184
      {
185
24
        cindex = utils::indexOf(cons);
186
      }
187
56
      Node body;
188
28
      if (ck == MATCH_CASE)
189
      {
190
20
        body = c[1];
191
      }
192
8
      else if (ck == MATCH_BIND_CASE)
193
      {
194
16
        std::vector<Node> vars;
195
16
        std::vector<Node> subs;
196
8
        if (cons.isNull())
197
        {
198
4
          Assert(c[1].getKind() == BOUND_VARIABLE);
199
4
          vars.push_back(c[1]);
200
4
          subs.push_back(h);
201
        }
202
        else
203
        {
204
12
          for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++)
205
          {
206
8
            vars.push_back(c[0][i]);
207
            Node sc =
208
16
                nm->mkNode(APPLY_SELECTOR, dt[cindex][i].getSelector(), h);
209
8
            subs.push_back(sc);
210
          }
211
        }
212
8
        body =
213
16
            c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
214
      }
215
28
      if (!cons.isNull())
216
      {
217
24
        cases.push_back(utils::mkTester(h, cindex, dt));
218
      }
219
      else
220
      {
221
        // variables have no constraints
222
4
        cases.push_back(nm->mkConst(true));
223
      }
224
28
      rets.push_back(body);
225
    }
226
12
    Assert(!cases.empty());
227
    // now make the ITE
228
12
    std::reverse(cases.begin(), cases.end());
229
12
    std::reverse(rets.begin(), rets.end());
230
24
    Node ret = rets[0];
231
12
    AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors());
232
28
    for (unsigned i = 1, ncases = cases.size(); i < ncases; i++)
233
    {
234
16
      ret = nm->mkNode(ITE, cases[i], rets[i], ret);
235
    }
236
24
    Trace("dt-rewrite-match")
237
12
        << "Rewrite match: " << in << " ... " << ret << std::endl;
238
12
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
239
  }
240
168963
  else if (kind == TUPLE_PROJECT)
241
  {
242
    // returns a tuple that represents
243
    // (mkTuple ((_ tupSel i_1) t) ... ((_ tupSel i_n) t))
244
    // where each i_j is less than the length of t
245
246
6
    Trace("dt-rewrite-project") << "Rewrite project: " << in << std::endl;
247
12
    TupleProjectOp op = in.getOperator().getConst<TupleProjectOp>();
248
12
    std::vector<uint32_t> indices = op.getIndices();
249
12
    Node tuple = in[0];
250
12
    std::vector<TypeNode> tupleTypes = tuple.getType().getTupleTypes();
251
12
    std::vector<TypeNode> types;
252
12
    std::vector<Node> elements;
253
12
    for (uint32_t index : indices)
254
    {
255
12
      TypeNode type = tupleTypes[index];
256
6
      types.push_back(type);
257
    }
258
12
    TypeNode projectType = nm->mkTupleType(types);
259
6
    const DType& dt = projectType.getDType();
260
6
    elements.push_back(dt[0].getConstructor());
261
6
    const DType& tupleDType = tuple.getType().getDType();
262
6
    const DTypeConstructor& constructor = tupleDType[0];
263
12
    for (uint32_t index : indices)
264
    {
265
12
      Node selector = constructor[index].getSelector();
266
12
      Node element = nm->mkNode(kind::APPLY_SELECTOR, selector, tuple);
267
6
      elements.push_back(element);
268
    }
269
12
    Node ret = nm->mkNode(kind::APPLY_CONSTRUCTOR, elements);
270
271
12
    Trace("dt-rewrite-project")
272
6
        << "Rewrite project: " << in << " ... " << ret << std::endl;
273
6
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
274
  }
275
276
202401
  if (kind == kind::EQUAL)
277
  {
278
167221
    if (in[0] == in[1])
279
    {
280
2876
      return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
281
    }
282
293808
    std::vector<Node> rew;
283
164345
    if (utils::checkClash(in[0], in[1], rew))
284
    {
285
3542
      Trace("datatypes-rewrite")
286
1771
          << "Rewrite clashing equality " << in << " to false" << std::endl;
287
1771
      return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
288
      //}else if( rew.size()==1 && rew[0]!=in ){
289
      //  Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " <<
290
      //  rew[0] << std::endl;
291
      //  return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] );
292
    }
293
162574
    else if (in[1] < in[0])
294
    {
295
66222
      Node ins = nm->mkNode(in.getKind(), in[1], in[0]);
296
66222
      Trace("datatypes-rewrite")
297
33111
          << "Swap equality " << in << " to " << ins << std::endl;
298
33111
      return RewriteResponse(REWRITE_DONE, ins);
299
    }
300
258926
    Trace("datatypes-rewrite-debug")
301
258926
        << "Did not rewrite equality " << in << " " << in[0].getKind() << " "
302
129463
        << in[1].getKind() << std::endl;
303
  }
304
305
164643
  return RewriteResponse(REWRITE_DONE, in);
306
}
307
308
371514
RewriteResponse DatatypesRewriter::preRewrite(TNode in)
309
{
310
371514
  Trace("datatypes-rewrite-debug") << "pre-rewriting " << in << std::endl;
311
  // must prewrite to apply type ascriptions since rewriting does not preserve
312
  // types
313
371514
  if (in.getKind() == kind::APPLY_CONSTRUCTOR)
314
  {
315
167483
    TypeNode tn = in.getType();
316
317
    // check for parametric datatype constructors
318
    // to ensure a normal form, all parameteric datatype constructors must have
319
    // a type ascription
320
83759
    if (tn.isParametricDatatype())
321
    {
322
274
      if (in.getOperator().getKind() != kind::APPLY_TYPE_ASCRIPTION)
323
      {
324
70
        Trace("datatypes-rewrite-debug")
325
35
            << "Ascribing type to parametric datatype constructor " << in
326
35
            << std::endl;
327
70
        Node op = in.getOperator();
328
        // get the constructor object
329
35
        const DTypeConstructor& dtc = utils::datatypeOf(op)[utils::indexOf(op)];
330
        // create ascribed constructor type
331
        Node tc = NodeManager::currentNM()->mkConst(
332
70
            AscriptionType(dtc.getSpecializedConstructorType(tn)));
333
        Node op_new = NodeManager::currentNM()->mkNode(
334
70
            kind::APPLY_TYPE_ASCRIPTION, tc, op);
335
        // make new node
336
70
        std::vector<Node> children;
337
35
        children.push_back(op_new);
338
35
        children.insert(children.end(), in.begin(), in.end());
339
        Node inr =
340
70
            NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, children);
341
35
        Trace("datatypes-rewrite-debug") << "Created " << inr << std::endl;
342
35
        return RewriteResponse(REWRITE_DONE, inr);
343
      }
344
    }
345
  }
346
371479
  return RewriteResponse(REWRITE_DONE, in);
347
}
348
349
177405
RewriteResponse DatatypesRewriter::rewriteConstructor(TNode in)
350
{
351
177405
  if (in.isConst())
352
  {
353
164880
    Trace("datatypes-rewrite-debug") << "Normalizing constant " << in
354
82440
                                     << std::endl;
355
164880
    Node inn = normalizeConstant(in);
356
    // constant may be a subterm of another constant, so cannot assume that this
357
    // will succeed for codatatypes
358
    // Assert( !inn.isNull() );
359
82440
    if (!inn.isNull() && inn != in)
360
    {
361
16
      Trace("datatypes-rewrite") << "Normalized constant " << in << " -> "
362
8
                                 << inn << std::endl;
363
8
      return RewriteResponse(REWRITE_DONE, inn);
364
    }
365
82432
    return RewriteResponse(REWRITE_DONE, in);
366
  }
367
94965
  return RewriteResponse(REWRITE_DONE, in);
368
}
369
370
93933
RewriteResponse DatatypesRewriter::rewriteSelector(TNode in)
371
{
372
93933
  Kind k = in.getKind();
373
93933
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
374
  {
375
    // Have to be careful not to rewrite well-typed expressions
376
    // where the selector doesn't match the constructor,
377
    // e.g. "pred(zero)".
378
31906
    TypeNode tn = in.getType();
379
31906
    TypeNode argType = in[0].getType();
380
31906
    Node selector = in.getOperator();
381
31906
    TNode constructor = in[0].getOperator();
382
29510
    size_t constructorIndex = utils::indexOf(constructor);
383
29510
    const DType& dt = utils::datatypeOf(selector);
384
29510
    const DTypeConstructor& c = dt[constructorIndex];
385
59020
    Trace("datatypes-rewrite-debug") << "Rewriting collapsable selector : "
386
29510
                                     << in;
387
59020
    Trace("datatypes-rewrite-debug") << ", cindex = " << constructorIndex
388
29510
                                     << ", selector is " << selector
389
29510
                                     << std::endl;
390
    // The argument that the selector extracts, or -1 if the selector is
391
    // is wrongly applied.
392
29510
    int selectorIndex = -1;
393
29510
    if (k == kind::APPLY_SELECTOR_TOTAL)
394
    {
395
      // The argument index of internal selectors is obtained by
396
      // getSelectorIndexInternal.
397
24318
      selectorIndex = c.getSelectorIndexInternal(selector);
398
    }
399
    else
400
    {
401
      // The argument index of external selectors (applications of
402
      // APPLY_SELECTOR) is given by an attribute and obtained via indexOf below
403
      // The argument is only valid if it is the proper constructor.
404
5192
      selectorIndex = utils::indexOf(selector);
405
10384
      if (selectorIndex < 0
406
5192
          || selectorIndex >= static_cast<int>(c.getNumArgs()))
407
      {
408
894
        selectorIndex = -1;
409
      }
410
4298
      else if (c[selectorIndex].getSelector() != selector)
411
      {
412
1502
        selectorIndex = -1;
413
      }
414
    }
415
59020
    Trace("datatypes-rewrite-debug") << "Internal selector index is "
416
29510
                                     << selectorIndex << std::endl;
417
29510
    if (selectorIndex >= 0)
418
    {
419
25473
      Assert(selectorIndex < (int)c.getNumArgs());
420
25473
      if (dt.isCodatatype() && in[0][selectorIndex].isConst())
421
      {
422
        // must replace all debruijn indices with self
423
14
        Node sub = replaceDebruijn(in[0][selectorIndex], in[0], argType, 0);
424
28
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
425
14
                                   << "Rewrite trivial codatatype selector "
426
14
                                   << in << " to " << sub << std::endl;
427
14
        if (sub != in)
428
        {
429
14
          return RewriteResponse(REWRITE_AGAIN_FULL, sub);
430
        }
431
      }
432
      else
433
      {
434
50918
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
435
25459
                                   << "Rewrite trivial selector " << in
436
25459
                                   << std::endl;
437
25459
        return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]);
438
      }
439
    }
440
4037
    else if (k == kind::APPLY_SELECTOR_TOTAL)
441
    {
442
      // evaluates to the first ground value of type tn.
443
3282
      Node gt = tn.mkGroundValue();
444
1641
      Assert(!gt.isNull());
445
1641
      if (tn.isDatatype() && !tn.isInstantiatedDatatype())
446
      {
447
        gt = NodeManager::currentNM()->mkNode(
448
            kind::APPLY_TYPE_ASCRIPTION,
449
            NodeManager::currentNM()->mkConst(AscriptionType(tn)),
450
            gt);
451
      }
452
3282
      Trace("datatypes-rewrite")
453
1641
          << "DatatypesRewriter::postRewrite: "
454
1641
          << "Rewrite trivial selector " << in
455
1641
          << " to distinguished ground term " << gt << std::endl;
456
1641
      return RewriteResponse(REWRITE_DONE, gt);
457
    }
458
  }
459
66819
  return RewriteResponse(REWRITE_DONE, in);
460
}
461
462
43440
RewriteResponse DatatypesRewriter::rewriteTester(TNode in)
463
{
464
43440
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
465
  {
466
    bool result =
467
2094
        utils::indexOf(in.getOperator()) == utils::indexOf(in[0].getOperator());
468
4188
    Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
469
2094
                               << "Rewrite trivial tester " << in << " "
470
2094
                               << result << std::endl;
471
    return RewriteResponse(REWRITE_DONE,
472
2094
                           NodeManager::currentNM()->mkConst(result));
473
  }
474
41346
  const DType& dt = in[0].getType().getDType();
475
41346
  if (dt.getNumConstructors() == 1 && !dt.isSygus())
476
  {
477
    // only one constructor, so it must be
478
2472
    Trace("datatypes-rewrite")
479
1236
        << "DatatypesRewriter::postRewrite: "
480
2472
        << "only one ctor for " << dt.getName() << " and that is "
481
1236
        << dt[0].getName() << std::endl;
482
    return RewriteResponse(REWRITE_DONE,
483
1236
                           NodeManager::currentNM()->mkConst(true));
484
  }
485
  // could try dt.getNumConstructors()==2 && indexOf(in.getOperator())==1 ?
486
40110
  return RewriteResponse(REWRITE_DONE, in);
487
}
488
489
63
RewriteResponse DatatypesRewriter::rewriteUpdater(TNode in)
490
{
491
63
  Assert (in.getKind()==APPLY_UPDATER);
492
63
  if (in[0].getKind() == APPLY_CONSTRUCTOR)
493
  {
494
18
    Node op = in.getOperator();
495
9
    size_t cindex = utils::indexOf(in[0].getOperator());
496
9
    size_t cuindex = utils::cindexOf(op);
497
9
    if (cindex==cuindex)
498
    {
499
9
      NodeManager * nm = NodeManager::currentNM();
500
9
      size_t updateIndex = utils::indexOf(op);
501
18
      std::vector<Node> children(in[0].begin(), in[0].end());
502
9
      children[updateIndex] = in[1];
503
9
      children.insert(children.begin(),in[0].getOperator());
504
9
      return RewriteResponse(REWRITE_DONE, nm->mkNode(APPLY_CONSTRUCTOR, children));
505
    }
506
    return RewriteResponse(REWRITE_DONE, in[0]);
507
  }
508
54
  return RewriteResponse(REWRITE_DONE, in);
509
}
510
511
810
Node DatatypesRewriter::normalizeCodatatypeConstant(Node n)
512
{
513
810
  Trace("dt-nconst") << "Normalize " << n << std::endl;
514
1620
  std::map<Node, Node> rf;
515
1620
  std::vector<Node> sk;
516
1620
  std::vector<Node> rf_pending;
517
1620
  std::vector<Node> terms;
518
1620
  std::map<Node, bool> cdts;
519
1620
  Node s = collectRef(n, sk, rf, rf_pending, terms, cdts);
520
810
  if (!s.isNull())
521
  {
522
522
    Trace("dt-nconst") << "...symbolic normalized is : " << s << std::endl;
523
740
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
524
    {
525
436
      Trace("dt-nconst") << "  " << it->first << " = " << it->second
526
218
                         << std::endl;
527
    }
528
    // now run DFA minimization on term structure
529
1044
    Trace("dt-nconst") << "  " << terms.size()
530
522
                       << " total subterms :" << std::endl;
531
522
    int eqc_count = 0;
532
1044
    std::map<Node, int> eqc_op_map;
533
1044
    std::map<Node, int> eqc;
534
1044
    std::map<int, std::vector<Node> > eqc_nodes;
535
    // partition based on top symbol
536
2450
    for (unsigned j = 0, size = terms.size(); j < size; j++)
537
    {
538
3856
      Node t = terms[j];
539
1928
      Trace("dt-nconst") << "    " << t << ", cdt=" << cdts[t] << std::endl;
540
      int e;
541
1928
      if (cdts[t])
542
      {
543
1461
        Assert(t.getKind() == kind::APPLY_CONSTRUCTOR);
544
2922
        Node op = t.getOperator();
545
1461
        std::map<Node, int>::iterator it = eqc_op_map.find(op);
546
1461
        if (it == eqc_op_map.end())
547
        {
548
825
          e = eqc_count;
549
825
          eqc_op_map[op] = eqc_count;
550
825
          eqc_count++;
551
        }
552
        else
553
        {
554
636
          e = it->second;
555
        }
556
      }
557
      else
558
      {
559
467
        e = eqc_count;
560
467
        eqc_count++;
561
      }
562
1928
      eqc[t] = e;
563
1928
      eqc_nodes[e].push_back(t);
564
    }
565
    // partition until fixed point
566
522
    int eqc_curr = 0;
567
522
    bool success = true;
568
2500
    while (success)
569
    {
570
989
      success = false;
571
989
      int eqc_end = eqc_count;
572
5543
      while (eqc_curr < eqc_end)
573
      {
574
2277
        if (eqc_nodes[eqc_curr].size() > 1)
575
        {
576
          // look at all nodes in this equivalence class
577
557
          unsigned nchildren = eqc_nodes[eqc_curr][0].getNumChildren();
578
1114
          std::map<int, std::vector<Node> > prt;
579
1130
          for (unsigned j = 0; j < nchildren; j++)
580
          {
581
1056
            prt.clear();
582
            // partition based on children : for the first child that causes a
583
            // split, break
584
4183
            for (unsigned k = 0, size = eqc_nodes[eqc_curr].size(); k < size;
585
                 k++)
586
            {
587
6254
              Node t = eqc_nodes[eqc_curr][k];
588
3127
              Assert(t.getNumChildren() == nchildren);
589
6254
              Node tc = t[j];
590
              // refer to loops
591
3127
              std::map<Node, Node>::iterator itr = rf.find(tc);
592
3127
              if (itr != rf.end())
593
              {
594
78
                tc = itr->second;
595
              }
596
3127
              Assert(eqc.find(tc) != eqc.end());
597
3127
              prt[eqc[tc]].push_back(t);
598
            }
599
1056
            if (prt.size() > 1)
600
            {
601
483
              success = true;
602
483
              break;
603
            }
604
          }
605
          // move into new eqc(s)
606
1615
          for (const std::pair<const int, std::vector<Node> >& p : prt)
607
          {
608
1058
            int e = eqc_count;
609
2691
            for (unsigned j = 0, size = p.second.size(); j < size; j++)
610
            {
611
3266
              Node t = p.second[j];
612
1633
              eqc[t] = e;
613
1633
              eqc_nodes[e].push_back(t);
614
            }
615
1058
            eqc_count++;
616
          }
617
        }
618
2277
        eqc_nodes.erase(eqc_curr);
619
2277
        eqc_curr++;
620
      }
621
    }
622
    // add in already occurring loop variables
623
740
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
624
    {
625
436
      Trace("dt-nconst-debug") << "Mapping equivalence class of " << it->first
626
218
                               << " -> " << it->second << std::endl;
627
218
      Assert(eqc.find(it->second) != eqc.end());
628
218
      eqc[it->first] = eqc[it->second];
629
    }
630
    // we now have a partition of equivalent terms
631
522
    Trace("dt-nconst") << "Computed equivalence classes ids : " << std::endl;
632
2668
    for (std::map<Node, int>::iterator it = eqc.begin(); it != eqc.end(); ++it)
633
    {
634
4292
      Trace("dt-nconst") << "  " << it->first << " -> " << it->second
635
2146
                         << std::endl;
636
    }
637
    // traverse top-down based on equivalence class
638
1044
    std::map<int, int> eqc_stack;
639
522
    return normalizeCodatatypeConstantEqc(s, eqc_stack, eqc, 0);
640
  }
641
288
  Trace("dt-nconst") << "...invalid." << std::endl;
642
288
  return Node::null();
643
}
644
645
// normalize constant : apply to top-level codatatype constants
646
473777
Node DatatypesRewriter::normalizeConstant(Node n)
647
{
648
947554
  TypeNode tn = n.getType();
649
473777
  if (tn.isDatatype())
650
  {
651
443530
    if (tn.isCodatatype())
652
    {
653
198
      return normalizeCodatatypeConstant(n);
654
    }
655
    else
656
    {
657
886664
      std::vector<Node> children;
658
443332
      bool childrenChanged = false;
659
834627
      for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
660
      {
661
782590
        Node nc = normalizeConstant(n[i]);
662
391295
        children.push_back(nc);
663
391295
        childrenChanged = childrenChanged || nc != n[i];
664
      }
665
443332
      if (childrenChanged)
666
      {
667
        return NodeManager::currentNM()->mkNode(n.getKind(), children);
668
      }
669
    }
670
  }
671
473579
  return n;
672
}
673
674
3817
Node DatatypesRewriter::collectRef(Node n,
675
                                   std::vector<Node>& sk,
676
                                   std::map<Node, Node>& rf,
677
                                   std::vector<Node>& rf_pending,
678
                                   std::vector<Node>& terms,
679
                                   std::map<Node, bool>& cdts)
680
{
681
3817
  Assert(n.isConst());
682
7634
  TypeNode tn = n.getType();
683
7634
  Node ret = n;
684
3817
  bool isCdt = false;
685
3817
  if (tn.isDatatype())
686
  {
687
2446
    if (!tn.isCodatatype())
688
    {
689
      // nested datatype within codatatype : can be normalized independently
690
      // since all loops should be self-contained
691
42
      ret = normalizeConstant(n);
692
    }
693
    else
694
    {
695
2404
      isCdt = true;
696
2404
      if (n.getKind() == kind::APPLY_CONSTRUCTOR)
697
      {
698
1898
        sk.push_back(n);
699
1898
        rf_pending.push_back(Node::null());
700
3386
        std::vector<Node> children;
701
1898
        children.push_back(n.getOperator());
702
1898
        bool childChanged = false;
703
4495
        for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
704
        {
705
5604
          Node nc = collectRef(n[i], sk, rf, rf_pending, terms, cdts);
706
3007
          if (nc.isNull())
707
          {
708
410
            return Node::null();
709
          }
710
2597
          childChanged = nc != n[i] || childChanged;
711
2597
          children.push_back(nc);
712
        }
713
1488
        sk.pop_back();
714
1488
        if (childChanged)
715
        {
716
445
          ret = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR,
717
                                                 children);
718
445
          if (!rf_pending.back().isNull())
719
          {
720
218
            rf[rf_pending.back()] = ret;
721
          }
722
        }
723
        else
724
        {
725
1043
          Assert(rf_pending.back().isNull());
726
        }
727
1488
        rf_pending.pop_back();
728
      }
729
      else
730
      {
731
        // a loop
732
506
        const Integer& i = n.getConst<CodatatypeBoundVariable>().getIndex();
733
506
        uint32_t index = i.toUnsignedInt();
734
506
        if (index >= sk.size())
735
        {
736
288
          return Node::null();
737
        }
738
218
        Assert(sk.size() == rf_pending.size());
739
436
        Node r = rf_pending[rf_pending.size() - 1 - index];
740
218
        if (r.isNull())
741
        {
742
436
          r = NodeManager::currentNM()->mkBoundVar(
743
436
              sk[rf_pending.size() - 1 - index].getType());
744
218
          rf_pending[rf_pending.size() - 1 - index] = r;
745
        }
746
218
        return r;
747
      }
748
    }
749
  }
750
5802
  Trace("dt-nconst-debug") << "Return term : " << ret << " from " << n
751
2901
                           << std::endl;
752
2901
  if (std::find(terms.begin(), terms.end(), ret) == terms.end())
753
  {
754
2240
    terms.push_back(ret);
755
2240
    Assert(ret.getType() == tn);
756
2240
    cdts[ret] = isCdt;
757
  }
758
2901
  return ret;
759
}
760
// eqc_stack stores depth
761
2485
/**
 * Rebuild n so that back-references to an enclosing term of the same
 * equivalence class become de Bruijn-indexed codatatype bound variables.
 *
 * @param n the term to normalize
 * @param eqc_stack maps the id of each equivalence class that is currently
 * open on the path from the root to n to the depth at which it was entered
 * @param eqc maps terms to the id of their equivalence class; terms that do
 * not occur in this map are returned unchanged
 * @param depth the depth of n below the root of the traversal
 * @return the normalized term
 */
Node DatatypesRewriter::normalizeCodatatypeConstantEqc(
    Node n, std::map<int, int>& eqc_stack, std::map<Node, int>& eqc, int depth)
{
  Trace("dt-nconst-debug") << "normalizeCodatatypeConstantEqc: " << n
                           << " depth=" << depth << std::endl;
  std::map<Node, int>::iterator itClass = eqc.find(n);
  if (itClass == eqc.end())
  {
    // no equivalence class is recorded for n, nothing to do
    return n;
  }
  const int classId = itClass->second;
  std::map<int, int>::iterator itOpen = eqc_stack.find(classId);
  if (itOpen != eqc_stack.end())
  {
    // the class of n is already open above us: refer back to it using the
    // number of levels between that occurrence and this one
    int debruijn = depth - itOpen->second - 1;
    return NodeManager::currentNM()->mkConst(
        CodatatypeBoundVariable(n.getType(), debruijn));
  }
  // the class is open while the children are processed
  eqc_stack[classId] = depth;
  std::vector<Node> newChildren;
  bool anyChildDiffers = false;
  for (const Node& child : n)
  {
    Node normChild =
        normalizeCodatatypeConstantEqc(child, eqc_stack, eqc, depth + 1);
    newChildren.push_back(normChild);
    anyChildDiffers = anyChildDiffers || normChild != child;
  }
  eqc_stack.erase(classId);
  if (!anyChildDiffers)
  {
    return n;
  }
  // only constructor applications are expected to have children that change
  Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
  newChildren.insert(newChildren.begin(), n.getOperator());
  return NodeManager::currentNM()->mkNode(n.getKind(), newChildren);
}
795
796
62
/**
 * Substitute orig for the codatatype bound variable of type orig_tn whose
 * index is equal to the current depth.
 *
 * @param n the term to traverse
 * @param orig the term that replaces the matching bound variable
 * @param orig_tn the type a bound variable must have to be considered
 * @param depth the depth of n below the root of the traversal; it grows by
 * one for every level of children that is entered
 * @return n itself if nothing was replaced, otherwise the rebuilt term
 */
Node DatatypesRewriter::replaceDebruijn(Node n,
                                        Node orig,
                                        TypeNode orig_tn,
                                        unsigned depth)
{
  if (n.getKind() == kind::CODATATYPE_BOUND_VARIABLE && n.getType() == orig_tn)
  {
    // a bound variable of the right type: it is replaced only when its index
    // matches the current depth
    const unsigned varIndex =
        n.getConst<CodatatypeBoundVariable>().getIndex().toUnsignedInt();
    if (varIndex == depth)
    {
      return orig;
    }
    return n;
  }
  const unsigned arity = n.getNumChildren();
  if (arity == 0)
  {
    return n;
  }
  std::vector<Node> newChildren;
  bool anyChildDiffers = false;
  for (unsigned pos = 0; pos < arity; ++pos)
  {
    Node replaced = replaceDebruijn(n[pos], orig, orig_tn, depth + 1);
    newChildren.push_back(replaced);
    anyChildDiffers = anyChildDiffers || replaced != n[pos];
  }
  if (!anyChildDiffers)
  {
    return n;
  }
  // rebuild the term, putting the operator first when there is one
  if (n.hasOperator())
  {
    newChildren.insert(newChildren.begin(), n.getOperator());
  }
  return NodeManager::currentNM()->mkNode(n.getKind(), newChildren);
}
831
832
4495
/**
 * Expand an APPLY_SELECTOR term into one built from APPLY_SELECTOR_TOTAL.
 *
 * When options::dtRewriteErrorSel() is set, the total selector application
 * is returned directly. Otherwise the result is
 *   (ite (tester n[0]) (total-selector n[0]) (f n[0]))
 * where f is the SELECTOR_WRONG skolem function for the selector of n.
 *
 * @param n a term of kind APPLY_SELECTOR
 * @return the expanded form of n
 */
Node DatatypesRewriter::expandApplySelector(Node n)
{
  Assert(n.getKind() == APPLY_SELECTOR);
  NodeManager* nm = NodeManager::currentNM();
  Node extSel = n.getOperator();
  Node arg = n[0];
  TypeNode argType = arg.getType();
  // APPLY_SELECTOR always applies to an external selector, cindexOf is
  // legal here
  const DType& dt = utils::datatypeOf(extSel);
  const DTypeConstructor& cons = dt[utils::cindexOf(extSel)];
  // the selector to apply: either the external selector itself, or the
  // internal selector of the constructor when selectors are shared
  Node usedSel = extSel;
  if (options::dtSharedSelectors())
  {
    size_t selectorIndex = utils::indexOf(extSel);
    Trace("dt-expand") << "...selector index = " << selectorIndex << std::endl;
    Assert(selectorIndex < cons.getNumArgs());
    usedSel = cons.getSelectorInternal(argType, selectorIndex);
  }
  Node totalApp = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, usedSel, arg);
  if (options::dtRewriteErrorSel())
  {
    return totalApp;
  }
  // guard the total selector by the tester of its constructor, and fall back
  // to the skolem function applied to the argument otherwise
  Node guard = nm->mkNode(APPLY_TESTER, cons.getTester(), arg);
  TypeNode wrongFunType = nm->mkFunctionType(argType, n.getType());
  Node wrongFun = nm->getSkolemManager()->mkSkolemFunction(
      SkolemFunId::SELECTOR_WRONG, wrongFunType, extSel);
  Node wrongApp = nm->mkNode(kind::APPLY_UF, wrongFun, arg);
  Node ret = nm->mkNode(kind::ITE, guard, totalApp, wrongApp);
  Trace("dt-expand") << "Expand def : " << n << " to " << ret << std::endl;
  return ret;
}
870
871
93871
/**
 * Expand the definition of n when it is an APPLY_SELECTOR or an
 * APPLY_UPDATER term.
 *
 * Selector applications are handled by expandApplySelector. An updater
 * application is turned into an application of the updated constructor whose
 * arguments are total selector applications on n[0], except for the updated
 * position which takes n[1]; when the datatype has more than one
 * constructor, that term is guarded by the tester of the constructor and n[0]
 * is used otherwise.
 *
 * @param n the term to expand
 * @return a trust node for the rewrite of n into its expansion (created
 * without a proof generator), or the null trust node if n is of another kind
 * or the expansion is n itself
 */
TrustNode DatatypesRewriter::expandDefinition(Node n)
{
  NodeManager* nm = NodeManager::currentNM();
  TypeNode tn = n.getType();
  Node expanded;
  const Kind k = n.getKind();
  if (k == kind::APPLY_SELECTOR)
  {
    expanded = expandApplySelector(n);
  }
  else if (k == APPLY_UPDATER)
  {
    Assert(tn.isDatatype());
    const DType& dt = tn.getDType();
    Node updater = n.getOperator();
    size_t updateIndex = utils::indexOf(updater);
    const DTypeConstructor& dc = dt[utils::cindexOf(updater)];
    NodeBuilder b(APPLY_CONSTRUCTOR);
    b << dc.getConstructor();
    Trace("dt-expand") << "Expand updater " << n << std::endl;
    Trace("dt-expand") << "expr is " << n << std::endl;
    Trace("dt-expand") << "updateIndex is " << updateIndex << std::endl;
    Trace("dt-expand") << "t is " << tn << std::endl;
    const size_t numArgs = dc.getNumArgs();
    for (size_t pos = 0; pos < numArgs; ++pos)
    {
      if (pos == updateIndex)
      {
        // the updated position takes the new value
        b << n[1];
        continue;
      }
      // every other position is read from the term that is updated
      Node field = nm->mkNode(
          APPLY_SELECTOR_TOTAL, dc.getSelectorInternal(tn, pos), n[0]);
      b << field;
    }
    expanded = b;
    if (dt.getNumConstructors() > 1)
    {
      // must be the right constructor to update
      Node tester = nm->mkNode(APPLY_TESTER, dc.getTester(), n[0]);
      expanded = nm->mkNode(ITE, tester, expanded, n[0]);
    }
    Trace("dt-expand") << "return " << expanded << std::endl;
  }
  // all other kinds have no expansion
  if (expanded.isNull() || n == expanded)
  {
    return TrustNode::null();
  }
  return TrustNode::mkTrustRewrite(n, expanded, nullptr);
}
927
928
89584
/**
 * Convert the sygus datatype term n to a builtin term and apply it to args.
 *
 * The term is traversed iteratively with an explicit stack:
 * - a term whose type is not a sygus datatype is mapped to itself;
 * - a constant of sygus datatype type is converted with
 *   utils::sygusToBuiltin, and then either evaluated with the sygus evaluator
 *   (only when every element of args is constant) or, when no evaluation
 *   result is obtained, has the sygus variables substituted by args;
 * - a constructor application is processed after its children, combining
 *   their results with utils::mkSygusTerm and utils::applySygusArgs;
 * - any other term t is mapped to a DT_SYGUS_EVAL application of t to args.
 *
 * @param n the term to convert
 * @param args the arguments; one per variable of the sygus variable list of
 * the first constant of sygus type that is encountered
 * @return the converted term
 */
Node DatatypesRewriter::sygusToBuiltinEval(Node n,
                                           const std::vector<Node>& args)
{
  Assert(d_sygusEval != nullptr);
  NodeManager* nm = NodeManager::currentNM();
  // constant arguments?
  bool allConst = true;
  for (size_t i = 0, nargs = args.size(); allConst && i < nargs; ++i)
  {
    allConst = args[i].isConst();
  }
  // children of the DT_SYGUS_EVAL terms: slot 0 is the evaluated term and is
  // overwritten on reuse, the remaining slots are args
  std::vector<Node> evalChildren;
  // the sygus variables, collected when the first sygus constant is seen
  std::vector<Node> sygusVars;
  bool sygusVarsReady = false;
  // maps each processed term to its result; a null entry marks a constructor
  // application whose children are still being processed
  std::unordered_map<TNode, Node> visited;
  std::vector<TNode> pending;
  pending.push_back(n);
  while (!pending.empty())
  {
    TNode cur = pending.back();
    pending.pop_back();
    std::unordered_map<TNode, Node>::iterator it = visited.find(cur);
    if (it != visited.end())
    {
      if (!it->second.isNull())
      {
        // already has a result
        continue;
      }
      // second visit of a constructor application, its children are done
      Node built = cur;
      Assert(cur.getKind() == APPLY_CONSTRUCTOR);
      const DType& dt = cur.getType().getDType();
      // non sygus-datatype terms are also themselves
      if (dt.isSygus())
      {
        std::vector<Node> childResults;
        for (const Node& cn : cur)
        {
          std::unordered_map<TNode, Node>::iterator itc = visited.find(cn);
          Assert(itc != visited.end());
          Assert(!itc->second.isNull());
          childResults.push_back(itc->second);
        }
        unsigned consIndex = utils::indexOf(cur.getOperator());
        // apply to children, which constructs the builtin term
        built = utils::mkSygusTerm(dt, consIndex, childResults);
        // now apply it to arguments in args
        built =
            utils::applySygusArgs(dt, dt[consIndex].getSygusOp(), built, args);
      }
      visited[cur] = built;
      continue;
    }
    // first visit of cur
    TypeNode tn = cur.getType();
    if (!tn.isDatatype() || !tn.getDType().isSygus())
    {
      visited[cur] = cur;
      continue;
    }
    if (cur.isConst())
    {
      // convert to builtin term
      Node bt = utils::sygusToBuiltin(cur);
      if (!sygusVarsReady)
      {
        sygusVarsReady = true;
        Node varList = tn.getDType().getSygusVarList();
        for (const Node& v : varList)
        {
          sygusVars.push_back(v);
        }
      }
      Assert(args.size() == sygusVars.size());
      // try evaluation if we have constant arguments
      Node value;
      if (allConst)
      {
        value = d_sygusEval->eval(bt, sygusVars, args);
      }
      if (value.isNull())
      {
        // if evaluation was not available, use a substitution
        value = bt.substitute(
            sygusVars.begin(), sygusVars.end(), args.begin(), args.end());
      }
      visited[cur] = value;
      continue;
    }
    if (cur.getKind() == APPLY_CONSTRUCTOR)
    {
      // mark as in progress and revisit once the children have results
      visited[cur] = Node::null();
      pending.push_back(cur);
      for (const Node& cn : cur)
      {
        pending.push_back(cn);
      }
      continue;
    }
    // it is the evaluation of this term on the arguments
    if (evalChildren.empty())
    {
      evalChildren.push_back(cur);
      evalChildren.insert(evalChildren.end(), args.begin(), args.end());
    }
    else
    {
      evalChildren[0] = cur;
    }
    visited[cur] = nm->mkNode(DT_SYGUS_EVAL, evalChildren);
  }
  Assert(visited.find(n) != visited.end());
  Assert(!visited.find(n)->second.isNull());
  return visited[n];
}
1047
1048
}  // namespace datatypes
1049
}  // namespace theory
1050
31125
}  // namespace cvc5