GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/builtin/theory_builtin_rewriter.cpp Lines: 225 250 90.0 %
Date: 2021-05-22 Branches: 655 1541 42.5 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Haniel Barbosa, Morgan Deters
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Rewriter for the builtin theory.
14
 *
15
 * Rewrites DISTINCT, WITNESS and LAMBDA terms; constant lambdas are made canonical by converting them to array constants and back.
16
 * \todo document this file
17
 */
18
19
#include "theory/builtin/theory_builtin_rewriter.h"
20
21
#include "expr/attribute.h"
22
#include "expr/node_algorithm.h"
23
#include "theory/rewriter.h"
24
25
using namespace std;
26
27
namespace cvc5 {
28
namespace theory {
29
namespace builtin {
30
31
10798
/**
 * Expand a DISTINCT application into explicit disequalities.
 *
 * For exactly two children the result is the single node (not (= a b)).
 * Otherwise the result is the conjunction (AND) of (not (= ci cj)) for every
 * pair of children with i < j, in the order the pairs are enumerated.
 *
 * @param in a node of kind DISTINCT
 * @return the equivalent disequality, or conjunction of disequalities
 */
Node TheoryBuiltinRewriter::blastDistinct(TNode in)
{
  Assert(in.getKind() == kind::DISTINCT);

  NodeManager* nm = NodeManager::currentNM();
  const unsigned numArgs = in.getNumChildren();

  if (numArgs == 2)
  {
    // exactly one pair is generated, so no enclosing AND is needed
    Node same = nm->mkNode(kind::EQUAL, in[0], in[1]);
    return nm->mkNode(kind::NOT, same);
  }

  // more than two children yields more than one disequality, hence the AND
  std::vector<Node> pairwise;
  for (unsigned lhs = 0; lhs < numArgs; ++lhs)
  {
    for (unsigned rhs = lhs + 1; rhs < numArgs; ++rhs)
    {
      Node same = nm->mkNode(kind::EQUAL, in[lhs], in[rhs]);
      pairwise.push_back(nm->mkNode(kind::NOT, same));
    }
  }
  return nm->mkNode(kind::AND, pairwise);
}
55
56
742880
/**
 * Post-rewrite a node of the builtin theory.
 *
 * LAMBDA nodes are canonicalized when they are equivalent to a constant
 * lambda: the lambda is turned into an array constant, the array constant is
 * normalized, and the result is converted back into a lambda. This ensures
 * that two constant lambdas are equivalent if and only if they are the same
 * node. Every other kind is forwarded to doRewrite.
 *
 * @param node the node to rewrite
 * @return a REWRITE_DONE response carrying the (possibly unchanged) node
 */
RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node)
{
  if (node.getKind() != kind::LAMBDA)
  {
    // not a lambda: do the default call
    return doRewrite(node);
  }

  Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl;
  Node anode = getArrayRepresentationForLambda(node);

  // Only constant array nodes are converted back, since these are the only
  // cases where canonicalization of lambdas is required. Moreover, the
  // conversion is not correct if the arguments of the lambda occur in return
  // values. For example, lambda x. ite( x=1, f(x), c ) would be converted to
  // (store (storeall ... c) 1 f(x)), and then converted to
  // lambda y. ite( y=1, f(x), c), losing the relation between x and y.
  if (anode.isNull() || !anode.isConst())
  {
    Trace("builtin-rewrite-debug") << "...failed to get array representation." << std::endl;
    return RewriteResponse(REWRITE_DONE, node);
  }

  Assert(anode.getType().isArray());
  // must get the standard bound variable list for the function type
  Node varList =
      NodeManager::currentNM()->getBoundVarListForFunctionType(node.getType());
  Node retNode = getLambdaForArrayRepresentation(anode, varList);
  if (retNode.isNull() || retNode == node)
  {
    // nothing new was obtained; keep the input
    return RewriteResponse(REWRITE_DONE, node);
  }

  Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl;
  Trace("builtin-rewrite") << "     input  : " << node << std::endl;
  Trace("builtin-rewrite") << "     output : " << retNode << ", constant = " << retNode.isConst() << std::endl;
  Trace("builtin-rewrite") << "  array rep : " << anode << ", constant = " << anode.isConst() << std::endl;
  Assert(anode.isConst() == retNode.isConst());
  Assert(retNode.getType() == node.getType());
  Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode));
  return RewriteResponse(REWRITE_DONE, retNode);
}
99
100
1135299
/**
 * Default rewrite shared by the rewriting entry points of this class.
 *
 * WITNESS nodes go through rewriteWitness, DISTINCT nodes are expanded by
 * blastDistinct, and any other node is returned unchanged.
 *
 * @param node the node to rewrite
 * @return a REWRITE_DONE response carrying the rewritten node
 */
RewriteResponse TheoryBuiltinRewriter::doRewrite(TNode node)
{
  const Kind k = node.getKind();
  if (k == kind::WITNESS)
  {
    // It is important to run this rewriting at prerewrite and postrewrite,
    // since e.g. arithmetic rewrites equalities in ways that may make an
    // equality not in solved form syntactically, e.g. (= x (+ 1 a)) rewrites
    // to (= a (- x 1)), where x no longer is in solved form.
    Node rnode = rewriteWitness(node);
    return RewriteResponse(REWRITE_DONE, rnode);
  }
  if (k == kind::DISTINCT)
  {
    return RewriteResponse(REWRITE_DONE, blastDistinct(node));
  }
  return RewriteResponse(REWRITE_DONE, node);
}
118
119
/**
 * Compute the function type corresponding to a (curried) array type.
 *
 * One argument type is peeled off the array type for each variable in bvl:
 * the index type becomes an argument type and the constituent type is
 * processed next. Whatever type remains after all variables are consumed is
 * the return type of the function.
 *
 * @param atn the array type to convert
 * @param bvl the bound variable list, one variable per function argument
 * @return the function type built from the collected argument/return types
 */
TypeNode TheoryBuiltinRewriter::getFunctionTypeForArrayType(TypeNode atn,
                                                            Node bvl)
{
  std::vector<TypeNode> signature;
  const unsigned numVars = bvl.getNumChildren();
  for (unsigned pos = 0; pos < numVars; ++pos)
  {
    Assert(atn.isArray());
    Assert(bvl[pos].getType() == atn.getArrayIndexType());
    signature.push_back(atn.getArrayIndexType());
    // descend into the element type for the next argument
    atn = atn.getArrayConstituentType();
  }
  // the remaining type is the range of the function
  signature.push_back(atn);
  return NodeManager::currentNM()->mkFunctionType(signature);
}
133
134
/**
 * Compute the curried array type corresponding to a function type.
 *
 * The last child of the function type is the innermost element type; each
 * preceding child, taken from last to first, wraps the result in an array
 * type indexed by that child.
 *
 * @param ftn a function type
 * @return the curried array type
 */
TypeNode TheoryBuiltinRewriter::getArrayTypeForFunctionType(TypeNode ftn)
{
  Assert(ftn.isFunction());
  NodeManager* nm = NodeManager::currentNM();
  const unsigned nchildren = ftn.getNumChildren();
  // start from the last child and wrap outwards
  TypeNode curried = ftn[nchildren - 1];
  int pos = static_cast<int>(nchildren) - 1;
  while (pos-- > 0)
  {
    curried = nm->mkArrayType(ftn[pos], curried);
  }
  return curried;
}
146
147
17382
/**
 * Recursive helper that converts an array term into the body of a lambda.
 *
 * While bvlIndex is within bvl, a STORE node (store a0 i v) becomes
 * (ite (= bvl[bvlIndex] i) v' a0'), where a0' is the conversion of a0 at the
 * same index and v' the conversion of v at the next index; a STORE_ALL
 * constant is replaced by the conversion of its default value at the next
 * index. Any other kind yields the null node. Once bvlIndex reaches the
 * number of variables, the term is returned as is.
 *
 * Results are memoized in visited, keyed by the array term only.
 * NOTE(review): the cache key does not include bvlIndex; this presumably
 * relies on a term never being reached at two different depths - confirm.
 *
 * @param a the array term (or element value) to convert
 * @param bvl the bound variable list of the lambda being built
 * @param bvlIndex index in bvl of the variable matching the current dimension
 * @param visited cache of already converted terms
 * @return the converted term, or the null node if conversion failed
 */
Node TheoryBuiltinRewriter::getLambdaForArrayRepresentationRec(
    TNode a,
    TNode bvl,
    unsigned bvlIndex,
    std::unordered_map<TNode, Node>& visited)
{
  std::unordered_map<TNode, Node>::iterator cached = visited.find(a);
  if (cached != visited.end())
  {
    return cached->second;
  }

  Node result;
  if (bvlIndex >= bvl.getNumChildren())
  {
    // all arguments consumed: the term itself is the value
    result = a;
  }
  else
  {
    Assert(a.getType().isArray());
    const Kind ak = a.getKind();
    if (ak == kind::STORE)
    {
      // convert the underlying array recursively
      Node elseBranch =
          getLambdaForArrayRepresentationRec(a[0], bvl, bvlIndex, visited);
      if (!elseBranch.isNull())
      {
        // convert the stored value recursively (bounded by the number of
        // arguments in bvl)
        Node thenBranch = getLambdaForArrayRepresentationRec(
            a[2], bvl, bvlIndex + 1, visited);
        if (!thenBranch.isNull())
        {
          Assert(!TypeNode::leastCommonTypeNode(a[1].getType(),
                                                bvl[bvlIndex].getType())
                      .isNull());
          Assert(!TypeNode::leastCommonTypeNode(thenBranch.getType(),
                                                elseBranch.getType())
                      .isNull());
          Node isIndex = bvl[bvlIndex].eqNode(a[1]);
          result = NodeManager::currentNM()->mkNode(
              kind::ITE, isIndex, thenBranch, elseBranch);
        }
      }
    }
    else if (ak == kind::STORE_ALL)
    {
      ArrayStoreAll constArray = a.getConst<ArrayStoreAll>();
      Node dflt = constArray.getValue();
      // convert the default value recursively (bounded by the number of
      // arguments in bvl)
      result =
          getLambdaForArrayRepresentationRec(dflt, bvl, bvlIndex + 1, visited);
    }
    // any other kind leaves result null, signalling failure
  }

  visited[a] = result;
  return result;
}
189
190
2812
/**
 * Convert an array term into a lambda over the given bound variable list.
 *
 * The body is computed by getLambdaForArrayRepresentationRec and then passed
 * through Rewriter::rewrite before the LAMBDA node is built.
 *
 * @param a a term of array type
 * @param bvl the bound variable list to use for the lambda
 * @return (lambda bvl body), or the null node if no body could be computed
 */
Node TheoryBuiltinRewriter::getLambdaForArrayRepresentation(TNode a, TNode bvl)
{
  Assert(a.getType().isArray());
  Trace("builtin-rewrite-debug") << "Get lambda for : " << a << ", with variables " << bvl << std::endl;

  std::unordered_map<TNode, Node> cache;
  Node lambdaBody = getLambdaForArrayRepresentationRec(a, bvl, 0, cache);
  if (lambdaBody.isNull())
  {
    Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl;
    return Node::null();
  }

  lambdaBody = Rewriter::rewrite(lambdaBody);
  Trace("builtin-rewrite-debug") << "...got lambda body " << lambdaBody << std::endl;
  return NodeManager::currentNM()->mkNode(kind::LAMBDA, bvl, lambdaBody);
}
204
205
12622
/**
 * Recursive helper that converts a lambda into an array term.
 *
 * The body of the lambda is read as a chain of entries of the form
 * "if first_arg = c then v else <rest>", where first_arg is the first bound
 * variable of n and c is a constant. Each entry becomes a STORE at index c
 * on top of a STORE_ALL constant built from the final (default) value. When
 * the lambda has more than one argument, values and the default are
 * themselves lambdas over the remaining arguments and are converted by
 * recursive calls.
 *
 * Besides ITE bodies, Boolean bodies made of OR, AND, EQUAL, NOT and bound
 * variables are also interpreted as such chains (see the comments below).
 *
 * @param n a node of kind LAMBDA
 * @param retType the overall return type, used as the innermost element type
 * of the array type constructed for the STORE_ALL constant
 * @return the array term, or the null node if the body does not have the
 * required shape or its default value is not constant
 */
Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec(TNode n,
                                                               TypeNode retType)
{
  Assert(n.getKind() == kind::LAMBDA);
  NodeManager* nm = NodeManager::currentNM();
  Trace("builtin-rewrite-debug") << "Get array representation for : " << n << std::endl;

  // the argument handled at this level, and the list of remaining arguments
  // (null if there is only one argument)
  Node first_arg = n[0][0];
  Node rec_bvl;
  unsigned size = n[0].getNumChildren();
  if (size > 1)
  {
    std::vector<Node> args;
    for (unsigned i = 1; i < size; i++)
    {
      args.push_back(n[0][i]);
    }
    rec_bvl = nm->mkNode(kind::BOUND_VAR_LIST, args);
  }

  Trace("builtin-rewrite-debug2") << "  process body..." << std::endl;
  // conds[i] is the constant index of the i-th entry, vals[i] its value
  std::vector<Node> conds;
  std::vector<Node> vals;
  Node curr = n[1];
  Kind ck = curr.getKind();
  while (ck == kind::ITE || ck == kind::OR || ck == kind::AND
         || ck == kind::EQUAL || ck == kind::NOT || ck == kind::BOUND_VARIABLE)
  {
    Node index_eq;
    Node curr_val;
    Node next;
    // Each iteration of this loop infers an entry in the function, e.g. it
    // has a value under some condition.

    // [1] We infer that the entry has value "curr_val" under condition
    // "index_eq". We set "next" to the node that is the remainder of the
    // function to process.
    if (ck == kind::ITE)
    {
      Trace("builtin-rewrite-debug2")
          << "  process condition : " << curr[0] << std::endl;
      index_eq = curr[0];
      curr_val = curr[1];
      next = curr[2];
    }
    else if (ck == kind::OR || ck == kind::AND)
    {
      Trace("builtin-rewrite-debug2")
          << "  process base : " << curr << std::endl;
      // Complex Boolean return cases, in which
      //  (1) lambda x. (= x v1) v ... becomes
      //      lambda x. (ite (= x v1) true [...])
      //
      //  (2) lambda x. (not (= x v1)) ^ ... becomes
      //      lambda x. (ite (= x v1) false [...])
      //
      // Note the negated cases of the lhs of the OR/AND operators above are
      // handled by pushing the recursion to the then-branch, with the
      // else-branch being the constant value. For example, the negated (1)
      // would be
      //  (1') lambda x. (not (= x v1)) v ... becomes
      //       lambda x. (ite (= x v1) [...] true)
      // thus requiring the rest of the disjunction to be further processed in
      // the then-branch as the current value.
      bool pol = curr[0].getKind() != kind::NOT;
      bool inverted = (pol && ck == kind::AND) || (!pol && ck == kind::OR);
      index_eq = pol ? curr[0] : curr[0][0];
      // processed : the value that is determined by the first child of curr
      // remainder : the remaining children of curr
      Node processed, remainder;
      // the value is the polarity of the first child or its inverse if we are
      // in the inverted case
      processed = nm->mkConst(!inverted ? pol : !pol);
      // build an OR/AND with the remaining components
      if (curr.getNumChildren() == 2)
      {
        remainder = curr[1];
      }
      else
      {
        std::vector<Node> remainderNodes{curr.begin() + 1, curr.end()};
        remainder = nm->mkNode(ck, remainderNodes);
      }
      if (inverted)
      {
        curr_val = remainder;
        next = processed;
        // If the lambda contains more variables than the one being currently
        // processed, the current value can be non-constant, since it'll be
        // processed recursively below. Otherwise we fail.
        if (rec_bvl.isNull() && !curr_val.isConst())
        {
          Trace("builtin-rewrite-debug2")
              << "...non-const curr_val " << curr_val << "\n";
          return Node::null();
        }
      }
      else
      {
        curr_val = processed;
        next = remainder;
      }
      Trace("builtin-rewrite-debug2") << "  index_eq : " << index_eq << "\n";
      Trace("builtin-rewrite-debug2") << "  curr_val : " << curr_val << "\n";
      Trace("builtin-rewrite-debug2") << "  next : " << next << std::endl;
    }
    else
    {
      Trace("builtin-rewrite-debug2")
          << "  process base : " << curr << std::endl;
      // Simple Boolean return cases, in which
      //  (1) lambda x. (= x v) becomes lambda x. (ite (= x v) true false)
      //  (2) lambda x. v becomes lambda x. (ite (= x v) true false)
      // Note the negated cases of the bodies above are also handled.
      bool pol = ck != kind::NOT;
      index_eq = pol ? curr : curr[0];
      curr_val = nm->mkConst(pol);
      next = nm->mkConst(!pol);
    }

    // [2] We ensure that "index_eq" is an equality, if possible.
    if (index_eq.getKind() != kind::EQUAL)
    {
      bool pol = index_eq.getKind() != kind::NOT;
      Node indexEqAtom = pol ? index_eq : index_eq[0];
      if (indexEqAtom.getKind() == kind::BOUND_VARIABLE)
      {
        if (!indexEqAtom.getType().isBoolean())
        {
          // Catches default case of non-Boolean variable, e.g.
          // lambda x : Int. x. In this case, it is not canonical and we fail.
          Trace("builtin-rewrite-debug2")
              << "  ...non-Boolean variable." << std::endl;
          return Node::null();
        }
        // Boolean argument case, e.g. lambda x. ite( x, t, s ) is processed as
        // lambda x. (ite (= x true) t s)
        index_eq = indexEqAtom.eqNode(nm->mkConst(pol));
      }
      else
      {
        // non-equality condition
        Trace("builtin-rewrite-debug2")
            << "  ...non-equality condition." << std::endl;
        return Node::null();
      }
    }
    else if (Rewriter::rewrite(index_eq) != index_eq)
    {
      // equality must be oriented correctly based on rewriter
      Trace("builtin-rewrite-debug2") << "  ...equality not oriented properly." << std::endl;
      return Node::null();
    }

    // [3] We ensure that "index_eq" is an equality that is equivalent to
    // "first_arg" = "curr_index", where curr_index is a constant, and
    // "first_arg" is the current argument we are processing, if possible.
    Node curr_index;
    for (unsigned r = 0; r < 2; r++)
    {
      // try both orientations of the equality
      Node arg = index_eq[r];
      Node val = index_eq[1 - r];
      if (arg == first_arg)
      {
        if (!val.isConst())
        {
          // non-constant value
          Trace("builtin-rewrite-debug2")
              << "  ...non-constant value for argument." << std::endl;
          return Node::null();
        }
        else
        {
          curr_index = val;
          Trace("builtin-rewrite-debug2")
              << "  arg " << arg << " -> " << val << std::endl;
          break;
        }
      }
    }
    if (curr_index.isNull())
    {
      // neither side of the equality is the argument being processed
      Trace("builtin-rewrite-debug2")
          << "  ...could not infer index value." << std::endl;
      return Node::null();
    }

    // [4] Recurse to ensure that "curr_val" has been normalized w.r.t. the
    // remaining arguments (rec_bvl).
    if (!rec_bvl.isNull())
    {
      curr_val = nm->mkNode(kind::LAMBDA, rec_bvl, curr_val);
      Trace("builtin-rewrite-debug") << push;
      Trace("builtin-rewrite-debug2") << push;
      curr_val = getArrayRepresentationForLambdaRec(curr_val, retType);
      Trace("builtin-rewrite-debug") << pop;
      Trace("builtin-rewrite-debug2") << pop;
      if (curr_val.isNull())
      {
        Trace("builtin-rewrite-debug2")
            << "  ...failed to recursively find value." << std::endl;
        return Node::null();
      }
    }
    // NOTE(review): the message says "index" but the value printed is
    // curr_val; left unchanged, confirm which one was intended.
    Trace("builtin-rewrite-debug2")
        << "  ...condition is index " << curr_val << std::endl;

    // [5] Add the entry
    conds.push_back(curr_index);
    vals.push_back(curr_val);

    // we will now process the remainder
    curr = next;
    ck = curr.getKind();
    Trace("builtin-rewrite-debug2")
        << "  process remainder : " << curr << std::endl;
  }

  // curr is now the default value; with several arguments it must itself be
  // converted as a lambda over the remaining arguments (this may yield null)
  if (!rec_bvl.isNull())
  {
    curr = nm->mkNode(kind::LAMBDA, rec_bvl, curr);
    Trace("builtin-rewrite-debug") << push;
    Trace("builtin-rewrite-debug2") << push;
    curr = getArrayRepresentationForLambdaRec(curr, retType);
    Trace("builtin-rewrite-debug") << pop;
    Trace("builtin-rewrite-debug2") << pop;
  }
  // can only build if default value is constant (since array store all must
  // be constant)
  if (!curr.isNull() && curr.isConst())
  {
    // compute the return type: wrap retType in one array type per argument,
    // from the last argument to the first
    TypeNode array_type = retType;
    for (unsigned i = 0; i < size; i++)
    {
      unsigned index = (size - 1) - i;
      array_type = nm->mkArrayType(n[0][index].getType(), array_type);
    }
    Trace("builtin-rewrite-debug2") << "  make array store all " << curr.getType() << " annotated : " << array_type << std::endl;
    Assert(curr.getType().isSubtypeOf(array_type.getArrayConstituentType()));
    curr = nm->mkConst(ArrayStoreAll(array_type, curr));
    Trace("builtin-rewrite-debug2") << "  build array..." << std::endl;
    Trace("builtin-rewrite-debug2") << "  got constant base " << curr << std::endl;
    Trace("builtin-rewrite-debug2") << "  conditions " << conds << std::endl;
    Trace("builtin-rewrite-debug2") << "  values " << vals << std::endl;
    // construct store chain; entries are applied last-to-first so that the
    // first entry of the body ends up as the outermost store
    for (int i = static_cast<int>(conds.size()) - 1; i >= 0; i--)
    {
      Assert(conds[i].getType().isSubtypeOf(first_arg.getType()));
      curr = nm->mkNode(kind::STORE, curr, conds[i], vals[i]);
    }
    Trace("builtin-rewrite-debug") << "...got array " << curr << " for " << n << std::endl;
    return curr;
  }
  else
  {
    Trace("builtin-rewrite-debug") << "...failed to get array (cannot get constant default value)" << std::endl;
    return Node::null();
  }
}
458
459
2868
/**
 * Rewrite a WITNESS term whose body trivially determines the witness.
 *
 * The following eliminations are performed, where x is the first bound
 * variable of the witness:
 *   (witness ((x T)) (= x t))   ---> t      (also for (= t x))
 *   (witness ((x Bool)) x)       ---> true
 *   (witness ((x Bool)) (not x)) ---> false
 * The first case only applies when t does not contain x and the type of t is
 * a subtype of the type of x.
 *
 * @param node a node of kind WITNESS
 * @return the eliminated term, or node itself if no rule applies
 */
Node TheoryBuiltinRewriter::rewriteWitness(TNode node)
{
  Assert(node.getKind() == kind::WITNESS);
  if (node[1].getKind() == kind::EQUAL)
  {
    for (size_t i = 0; i < 2; i++)
    {
      // (witness ((x T)) (= x t)) ---> t
      if (node[1][i] == node[0][0])
      {
        Node other = node[1][1 - i];
        // Must be a legal elimination: the other side of the equality
        // cannot contain the variable, and it must be a subtype of the
        // variable. The type check is on the other side; checking the
        // variable's side against the variable's own type is vacuous.
        if (!expr::hasSubterm(other, node[0][0])
            && other.getType().isSubtypeOf(node[0][0].getType()))
        {
          // only report the rewrite once it is known to be applied
          Trace("builtin-rewrite") << "Witness rewrite: " << node << " --> "
                                   << other << std::endl;
          return other;
        }
      }
    }
  }
  else if (node[1] == node[0][0])
  {
    // (witness ((x Bool)) x) ---> true
    return NodeManager::currentNM()->mkConst(true);
  }
  else if (node[1].getKind() == kind::NOT && node[1][0] == node[0][0])
  {
    // (witness ((x Bool)) (not x)) ---> false
    return NodeManager::currentNM()->mkConst(false);
  }
  return node;
}
494
495
8585
/**
 * Convert a lambda into its array representation, if one exists.
 *
 * The conversion is done by getArrayRepresentationForLambdaRec; a non-null
 * result is passed through Rewriter::rewrite to make it canonical.
 *
 * @param n a node of kind LAMBDA
 * @return the rewritten array term, or the null node if conversion failed
 */
Node TheoryBuiltinRewriter::getArrayRepresentationForLambda(TNode n)
{
  Assert(n.getKind() == kind::LAMBDA);
  // The overall return type must be carried along to deal with cases like
  // (lambda ((x Int) (y Int)) (ite (= x _) 0.5 0.0)), where the inner
  // construction for the else case should be
  // (arraystoreall (Array Int Real) 0.0).
  const TypeNode bodyType = n[1].getType();
  Node arrayRep = getArrayRepresentationForLambdaRec(n, bodyType);
  // a null result is propagated; otherwise rewrite to make it canonical
  return arrayRep.isNull() ? arrayRep : Rewriter::rewrite(arrayRep);
}
509
510
}  // namespace builtin
511
}  // namespace theory
512
28191
}  // namespace cvc5