GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/builtin/theory_builtin_rewriter.cpp Lines: 225 250 90.0 %
Date: 2021-09-17 Branches: 655 1539 42.6 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Haniel Barbosa, Morgan Deters
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Rewriter for the builtin theory.
14
 *
15
 * Rewrites DISTINCT and WITNESS terms, and canonicalizes constant lambdas by converting them to array constants and back.
16
 * \todo document this file
17
 */
18
19
#include "theory/builtin/theory_builtin_rewriter.h"
20
21
#include "expr/array_store_all.h"
22
#include "expr/attribute.h"
23
#include "expr/node_algorithm.h"
24
#include "theory/rewriter.h"
25
26
using namespace std;
27
28
namespace cvc5 {
29
namespace theory {
30
namespace builtin {
31
32
10514
Node TheoryBuiltinRewriter::blastDistinct(TNode in) {
33
10514
  Assert(in.getKind() == kind::DISTINCT);
34
35
10514
  if(in.getNumChildren() == 2) {
36
    // if this is the case exactly 1 != pair will be generated so the
37
    // AND is not required
38
20018
    Node eq = NodeManager::currentNM()->mkNode(kind::EQUAL, in[0], in[1]);
39
20018
    Node neq = NodeManager::currentNM()->mkNode(kind::NOT, eq);
40
10009
    return neq;
41
  }
42
43
  // assume that in.getNumChildren() > 2 => diseqs.size() > 1
44
1010
  vector<Node> diseqs;
45
3631
  for(TNode::iterator i = in.begin(); i != in.end(); ++i) {
46
3126
    TNode::iterator j = i;
47
71334
    while(++j != in.end()) {
48
68208
      Node eq = NodeManager::currentNM()->mkNode(kind::EQUAL, *i, *j);
49
68208
      Node neq = NodeManager::currentNM()->mkNode(kind::NOT, eq);
50
34104
      diseqs.push_back(neq);
51
    }
52
  }
53
1010
  Node out = NodeManager::currentNM()->mkNode(kind::AND, diseqs);
54
505
  return out;
55
}
56
57
818686
RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node) {
58
818686
  if( node.getKind()==kind::LAMBDA ){
59
    // The following code ensures that if node is equivalent to a constant
60
    // lambda, then we return the canonical representation for the lambda, which
61
    // in turn ensures that two constant lambdas are equivalent if and only
62
    // if they are the same node.
63
    // We canonicalize lambdas by turning them into array constants, applying
64
    // normalization on array constants, and then converting the array constant
65
    // back to a lambda.
66
5892
    Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl;
67
11784
    Node anode = getArrayRepresentationForLambda( node );
68
    // Only rewrite constant array nodes, since these are the only cases
69
    // where we require canonicalization of lambdas. Moreover, applying the
70
    // below code is not correct if the arguments to the lambda occur
71
    // in return values. For example, lambda x. ite( x=1, f(x), c ) would
72
    // be converted to (store (storeall ... c) 1 f(x)), and then converted
73
    // to lambda y. ite( y=1, f(x), c), losing the relation between x and y.
74
5892
    if (!anode.isNull() && anode.isConst())
75
    {
76
2704
      Assert(anode.getType().isArray());
77
      //must get the standard bound variable list
78
4120
      Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType( node.getType() );
79
4120
      Node retNode = getLambdaForArrayRepresentation( anode, varList );
80
2704
      if( !retNode.isNull() && retNode!=node ){
81
1288
        Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl;
82
1288
        Trace("builtin-rewrite") << "     input  : " << node << std::endl;
83
1288
        Trace("builtin-rewrite") << "     output : " << retNode << ", constant = " << retNode.isConst() << std::endl;
84
1288
        Trace("builtin-rewrite") << "  array rep : " << anode << ", constant = " << anode.isConst() << std::endl;
85
1288
        Assert(anode.isConst() == retNode.isConst());
86
1288
        Assert(retNode.getType() == node.getType());
87
1288
        Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode));
88
1288
        return RewriteResponse(REWRITE_DONE, retNode);
89
      }
90
    }
91
    else
92
    {
93
3188
      Trace("builtin-rewrite-debug") << "...failed to get array representation." << std::endl;
94
    }
95
4604
    return RewriteResponse(REWRITE_DONE, node);
96
  }
97
  // otherwise, do the default call
98
812794
  return doRewrite(node);
99
}
100
101
1248966
RewriteResponse TheoryBuiltinRewriter::doRewrite(TNode node)
102
{
103
1248966
  switch (node.getKind())
104
  {
105
2475
    case kind::WITNESS:
106
    {
107
      // it is important to run this rewriting at prerewrite and postrewrite,
108
      // since e.g. arithmetic rewrites equalities in ways that may make an
109
      // equality not in solved form syntactically, e.g. (= x (+ 1 a)) rewrites
110
      // to (= a (- x 1)), where x no longer is in solved form.
111
4950
      Node rnode = rewriteWitness(node);
112
2475
      return RewriteResponse(REWRITE_DONE, rnode);
113
    }
114
10514
    case kind::DISTINCT:
115
10514
      return RewriteResponse(REWRITE_DONE, blastDistinct(node));
116
1235977
    default: return RewriteResponse(REWRITE_DONE, node);
117
  }
118
}
119
120
TypeNode TheoryBuiltinRewriter::getFunctionTypeForArrayType(TypeNode atn,
121
                                                            Node bvl)
122
{
123
  std::vector<TypeNode> children;
124
  for (unsigned i = 0; i < bvl.getNumChildren(); i++)
125
  {
126
    Assert(atn.isArray());
127
    Assert(bvl[i].getType() == atn.getArrayIndexType());
128
    children.push_back(atn.getArrayIndexType());
129
    atn = atn.getArrayConstituentType();
130
  }
131
  children.push_back(atn);
132
  return NodeManager::currentNM()->mkFunctionType(children);
133
}
134
135
TypeNode TheoryBuiltinRewriter::getArrayTypeForFunctionType(TypeNode ftn)
136
{
137
  Assert(ftn.isFunction());
138
  // construct the curried array type
139
  unsigned nchildren = ftn.getNumChildren();
140
  TypeNode ret = ftn[nchildren - 1];
141
  for (int i = (static_cast<int>(nchildren) - 2); i >= 0; i--)
142
  {
143
    ret = NodeManager::currentNM()->mkArrayType(ftn[i], ret);
144
  }
145
  return ret;
146
}
147
148
16982
Node TheoryBuiltinRewriter::getLambdaForArrayRepresentationRec(
149
    TNode a,
150
    TNode bvl,
151
    unsigned bvlIndex,
152
    std::unordered_map<TNode, Node>& visited)
153
{
154
16982
  std::unordered_map<TNode, Node>::iterator it = visited.find(a);
155
16982
  if( it==visited.end() ){
156
32248
    Node ret;
157
16124
    if( bvlIndex<bvl.getNumChildren() ){
158
8696
      Assert(a.getType().isArray());
159
8696
      if( a.getKind()==kind::STORE ){
160
        // convert the array recursively
161
11164
        Node body = getLambdaForArrayRepresentationRec( a[0], bvl, bvlIndex, visited );
162
5582
        if( !body.isNull() ){
163
          // convert the value recursively (bounded by the number of arguments in bvl)
164
11164
          Node val = getLambdaForArrayRepresentationRec( a[2], bvl, bvlIndex+1, visited );
165
5582
          if( !val.isNull() ){
166
5582
            Assert(!TypeNode::leastCommonTypeNode(a[1].getType(),
167
                                                  bvl[bvlIndex].getType())
168
                        .isNull());
169
5582
            Assert(!TypeNode::leastCommonTypeNode(val.getType(), body.getType())
170
                        .isNull());
171
11164
            Node cond = bvl[bvlIndex].eqNode( a[1] );
172
5582
            ret = NodeManager::currentNM()->mkNode( kind::ITE, cond, val, body );
173
          }
174
        }
175
3114
      }else if( a.getKind()==kind::STORE_ALL ){
176
6228
        ArrayStoreAll storeAll = a.getConst<ArrayStoreAll>();
177
6228
        Node sa = storeAll.getValue();
178
        // convert the default value recursively (bounded by the number of arguments in bvl)
179
3114
        ret = getLambdaForArrayRepresentationRec( sa, bvl, bvlIndex+1, visited );
180
      }
181
    }else{
182
7428
      ret = a;
183
    }
184
16124
    visited[a] = ret;
185
16124
    return ret;
186
  }else{
187
858
    return it->second;
188
  }
189
}
190
191
2704
Node TheoryBuiltinRewriter::getLambdaForArrayRepresentation( TNode a, TNode bvl ){
192
2704
  Assert(a.getType().isArray());
193
5408
  std::unordered_map<TNode, Node> visited;
194
2704
  Trace("builtin-rewrite-debug") << "Get lambda for : " << a << ", with variables " << bvl << std::endl;
195
5408
  Node body = getLambdaForArrayRepresentationRec( a, bvl, 0, visited );
196
2704
  if( !body.isNull() ){
197
2704
    body = Rewriter::rewrite( body );
198
2704
    Trace("builtin-rewrite-debug") << "...got lambda body " << body << std::endl;
199
2704
    return NodeManager::currentNM()->mkNode( kind::LAMBDA, bvl, body );
200
  }else{
201
    Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl;
202
    return Node::null();
203
  }
204
}
205
206
11777
Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec(TNode n,
207
                                                               TypeNode retType)
208
{
209
11777
  Assert(n.getKind() == kind::LAMBDA);
210
11777
  NodeManager* nm = NodeManager::currentNM();
211
11777
  Trace("builtin-rewrite-debug") << "Get array representation for : " << n << std::endl;
212
213
23554
  Node first_arg = n[0][0];
214
23554
  Node rec_bvl;
215
11777
  unsigned size = n[0].getNumChildren();
216
11777
  if (size > 1)
217
  {
218
6082
    std::vector< Node > args;
219
6599
    for (unsigned i = 1; i < size; i++)
220
    {
221
3558
      args.push_back( n[0][i] );
222
    }
223
3041
    rec_bvl = nm->mkNode(kind::BOUND_VAR_LIST, args);
224
  }
225
226
11777
  Trace("builtin-rewrite-debug2") << "  process body..." << std::endl;
227
23554
  std::vector< Node > conds;
228
23554
  std::vector< Node > vals;
229
23554
  Node curr = n[1];
230
11777
  Kind ck = curr.getKind();
231
19916
  while (ck == kind::ITE || ck == kind::OR || ck == kind::AND
232
31091
         || ck == kind::EQUAL || ck == kind::NOT || ck == kind::BOUND_VARIABLE)
233
  {
234
17700
    Node index_eq;
235
17700
    Node curr_val;
236
17700
    Node next;
237
    // Each iteration of this loop infers an entry in the function, e.g. it
238
    // has a value under some condition.
239
240
    // [1] We infer that the entry has value "curr_val" under condition
241
    // "index_eq". We set "next" to the node that is the remainder of the
242
    // function to process.
243
9705
    if (ck == kind::ITE)
244
    {
245
15702
      Trace("builtin-rewrite-debug2")
246
7851
          << "  process condition : " << curr[0] << std::endl;
247
7851
      index_eq = curr[0];
248
7851
      curr_val = curr[1];
249
7851
      next = curr[2];
250
    }
251
1854
    else if (ck == kind::OR || ck == kind::AND)
252
    {
253
1204
      Trace("builtin-rewrite-debug2")
254
602
          << "  process base : " << curr << std::endl;
255
      // curr = Rewriter::rewrite(curr);
256
      // Trace("builtin-rewrite-debug2")
257
      //     << "  rewriten base : " << curr << std::endl;
258
      // Complex Boolean return cases, in which
259
      //  (1) lambda x. (= x v1) v ... becomes
260
      //      lambda x. (ite (= x v1) true [...])
261
      //
262
      //  (2) lambda x. (not (= x v1)) ^ ... becomes
263
      //      lambda x. (ite (= x v1) false [...])
264
      //
265
      // Note the negated cases of the lhs of the OR/AND operators above are
266
      // handled by pushing the recursion to the then-branch, with the
267
      // else-branch being the constant value. For example, the negated (1)
268
      // would be
269
      //  (1') lambda x. (not (= x v1)) v ... becomes
270
      //       lambda x. (ite (= x v1) [...] true)
271
      // thus requiring the rest of the disjunction to be further processed in
272
      // the then-branch as the current value.
273
602
      bool pol = curr[0].getKind() != kind::NOT;
274
602
      bool inverted = (pol && ck == kind::AND) || (!pol && ck == kind::OR);
275
602
      index_eq = pol ? curr[0] : curr[0][0];
276
      // processed : the value that is determined by the first child of curr
277
      // remainder : the remaining children of curr
278
1035
      Node processed, remainder;
279
      // the value is the polarity of the first child or its inverse if we are
280
      // in the inverted case
281
602
      processed = nm->mkConst(!inverted? pol : !pol);
282
      // build an OR/AND with the remaining components
283
602
      if (curr.getNumChildren() == 2)
284
      {
285
586
        remainder = curr[1];
286
      }
287
      else
288
      {
289
32
        std::vector<Node> remainderNodes{curr.begin() + 1, curr.end()};
290
16
        remainder = nm->mkNode(ck, remainderNodes);
291
      }
292
602
      if (inverted)
293
      {
294
363
        curr_val = remainder;
295
363
        next = processed;
296
        // If the lambda contains more variables than the one being currently
297
        // processed, the current value can be non-constant, since it'll be
298
        // processed recursively below. Otherwise we fail.
299
363
        if (rec_bvl.isNull() && !curr_val.isConst())
300
        {
301
338
          Trace("builtin-rewrite-debug2")
302
169
              << "...non-const curr_val " << curr_val << "\n";
303
169
          return Node::null();
304
        }
305
      }
306
      else
307
      {
308
239
        curr_val = processed;
309
239
        next = remainder;
310
      }
311
433
      Trace("builtin-rewrite-debug2") << "  index_eq : " << index_eq << "\n";
312
433
      Trace("builtin-rewrite-debug2") << "  curr_val : " << curr_val << "\n";
313
866
      Trace("builtin-rewrite-debug2") << "  next : " << next << std::endl;
314
    }
315
    else
316
    {
317
2504
      Trace("builtin-rewrite-debug2")
318
1252
          << "  process base : " << curr << std::endl;
319
      // Simple Boolean return cases, in which
320
      //  (1) lambda x. (= x v) becomes lambda x. (ite (= x v) true false)
321
      //  (2) lambda x. v becomes lambda x. (ite (= x v) true false)
322
      // Note the negated cases of the bodies above are also handled.
323
1252
      bool pol = ck != kind::NOT;
324
1252
      index_eq = pol ? curr : curr[0];
325
1252
      curr_val = nm->mkConst(pol);
326
1252
      next = nm->mkConst(!pol);
327
    }
328
329
    // [2] We ensure that "index_eq" is an equality, if possible.
330
9536
    if (index_eq.getKind() != kind::EQUAL)
331
    {
332
1643
      bool pol = index_eq.getKind() != kind::NOT;
333
2380
      Node indexEqAtom = pol ? index_eq : index_eq[0];
334
1643
      if (indexEqAtom.getKind() == kind::BOUND_VARIABLE)
335
      {
336
1079
        if (!indexEqAtom.getType().isBoolean())
337
        {
338
          // Catches default case of non-Boolean variable, e.g.
339
          // lambda x : Int. x. In this case, it is not canonical and we fail.
340
684
          Trace("builtin-rewrite-debug2")
341
342
              << "  ...non-Boolean variable." << std::endl;
342
342
          return Node::null();
343
        }
344
        // Boolean argument case, e.g. lambda x. ite( x, t, s ) is processed as
345
        // lambda x. (ite (= x true) t s)
346
737
        index_eq = indexEqAtom.eqNode(nm->mkConst(pol));
347
      }
348
      else
349
      {
350
        // non-equality condition
351
1128
        Trace("builtin-rewrite-debug2")
352
564
            << "  ...non-equality condition." << std::endl;
353
564
        return Node::null();
354
      }
355
    }
356
7893
    else if (Rewriter::rewrite(index_eq) != index_eq)
357
    {
358
      // equality must be oriented correctly based on rewriter
359
9
      Trace("builtin-rewrite-debug2") << "  ...equality not oriented properly." << std::endl;
360
9
      return Node::null();
361
    }
362
363
    // [3] We ensure that "index_eq" is an equality that is equivalent to
364
    // "first_arg" = "curr_index", where curr_index is a constant, and
365
    // "first_arg" is the current argument we are processing, if possible.
366
16616
    Node curr_index;
367
10467
    for( unsigned r=0; r<2; r++ ){
368
12107
      Node arg = index_eq[r];
369
12107
      Node val = index_eq[1-r];
370
10261
      if( arg==first_arg ){
371
8415
        if (!val.isConst())
372
        {
373
          // non-constant value
374
442
          Trace("builtin-rewrite-debug2")
375
221
              << "  ...non-constant value for argument\n.";
376
221
          return Node::null();
377
        }else{
378
8194
          curr_index = val;
379
16388
          Trace("builtin-rewrite-debug2")
380
8194
              << "  arg " << arg << " -> " << val << std::endl;
381
8194
          break;
382
        }
383
      }
384
    }
385
8400
    if (curr_index.isNull())
386
    {
387
412
      Trace("builtin-rewrite-debug2")
388
206
          << "  ...could not infer index value." << std::endl;
389
206
      return Node::null();
390
    }
391
392
    // [4] Recurse to ensure that "curr_val" has been normalized w.r.t. the
393
    // remaining arguments (rec_bvl).
394
8194
    if (!rec_bvl.isNull())
395
    {
396
1440
      curr_val = nm->mkNode(kind::LAMBDA, rec_bvl, curr_val);
397
1440
      Trace("builtin-rewrite-debug") << push;
398
1440
      Trace("builtin-rewrite-debug2") << push;
399
1440
      curr_val = getArrayRepresentationForLambdaRec(curr_val, retType);
400
1440
      Trace("builtin-rewrite-debug") << pop;
401
1440
      Trace("builtin-rewrite-debug2") << pop;
402
1440
      if (curr_val.isNull())
403
      {
404
398
        Trace("builtin-rewrite-debug2")
405
199
            << "  ...failed to recursively find value." << std::endl;
406
199
        return Node::null();
407
      }
408
    }
409
15990
    Trace("builtin-rewrite-debug2")
410
7995
        << "  ...condition is index " << curr_val << std::endl;
411
412
    // [5] Add the entry
413
7995
    conds.push_back( curr_index );
414
7995
    vals.push_back( curr_val );
415
416
    // we will now process the remainder
417
7995
    curr = next;
418
7995
    ck = curr.getKind();
419
15990
    Trace("builtin-rewrite-debug2")
420
7995
        << "  process remainder : " << curr << std::endl;
421
  }
422
10067
  if( !rec_bvl.isNull() ){
423
2247
    curr = nm->mkNode(kind::LAMBDA, rec_bvl, curr);
424
2247
    Trace("builtin-rewrite-debug") << push;
425
2247
    Trace("builtin-rewrite-debug2") << push;
426
2247
    curr = getArrayRepresentationForLambdaRec(curr, retType);
427
2247
    Trace("builtin-rewrite-debug") << pop;
428
2247
    Trace("builtin-rewrite-debug2") << pop;
429
  }
430
10067
  if( !curr.isNull() && curr.isConst() ){
431
    // compute the return type
432
10390
    TypeNode array_type = retType;
433
11199
    for (unsigned i = 0; i < size; i++)
434
    {
435
6004
      unsigned index = (size - 1) - i;
436
6004
      array_type = nm->mkArrayType(n[0][index].getType(), array_type);
437
    }
438
5195
    Trace("builtin-rewrite-debug2") << "  make array store all " << curr.getType() << " annotated : " << array_type << std::endl;
439
5195
    Assert(curr.getType().isSubtypeOf(array_type.getArrayConstituentType()));
440
5195
    curr = nm->mkConst(ArrayStoreAll(array_type, curr));
441
5195
    Trace("builtin-rewrite-debug2") << "  build array..." << std::endl;
442
    // can only build if default value is constant (since array store all must be constant)
443
5195
    Trace("builtin-rewrite-debug2") << "  got constant base " << curr << std::endl;
444
5195
    Trace("builtin-rewrite-debug2") << "  conditions " << conds << std::endl;
445
5195
    Trace("builtin-rewrite-debug2") << "  values " << vals << std::endl;
446
    // construct store chain
447
12899
    for (int i = static_cast<int>(conds.size()) - 1; i >= 0; i--)
448
    {
449
7704
      Assert(conds[i].getType().isSubtypeOf(first_arg.getType()));
450
7704
      curr = nm->mkNode(kind::STORE, curr, conds[i], vals[i]);
451
    }
452
5195
    Trace("builtin-rewrite-debug") << "...got array " << curr << " for " << n << std::endl;
453
5195
    return curr;
454
  }else{
455
4872
    Trace("builtin-rewrite-debug") << "...failed to get array (cannot get constant default value)" << std::endl;
456
4872
    return Node::null();
457
  }
458
}
459
460
2475
Node TheoryBuiltinRewriter::rewriteWitness(TNode node)
461
{
462
2475
  Assert(node.getKind() == kind::WITNESS);
463
2475
  if (node[1].getKind() == kind::EQUAL)
464
  {
465
72
    for (size_t i = 0; i < 2; i++)
466
    {
467
      // (witness ((x T)) (= x t)) ---> t
468
48
      if (node[1][i] == node[0][0])
469
      {
470
        Trace("builtin-rewrite") << "Witness rewrite: " << node << " --> "
471
                                 << node[1][1 - i] << std::endl;
472
        // also must be a legal elimination: the other side of the equality
473
        // cannot contain the variable, and it must be a subtype of the
474
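        // variable. NOTE(review): the check below tests node[1][i] (the variable itself), not node[1][1 - i] -- confirm intent.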
475
        if (!expr::hasSubterm(node[1][1 - i], node[0][0])
476
            && node[1][i].getType().isSubtypeOf(node[0][0].getType()))
477
        {
478
          return node[1][1 - i];
479
        }
480
      }
481
    }
482
  }
483
2451
  else if (node[1] == node[0][0])
484
  {
485
    // (witness ((x Bool)) x) ---> true
486
    return NodeManager::currentNM()->mkConst(true);
487
  }
488
2451
  else if (node[1].getKind() == kind::NOT && node[1][0] == node[0][0])
489
  {
490
    // (witness ((x Bool)) (not x)) ---> false
491
    return NodeManager::currentNM()->mkConst(false);
492
  }
493
2475
  return node;
494
}
495
496
8090
Node TheoryBuiltinRewriter::getArrayRepresentationForLambda(TNode n)
497
{
498
8090
  Assert(n.getKind() == kind::LAMBDA);
499
  // must carry the overall return type to deal with cases like (lambda ((x Int)
500
  // (y Int)) (ite (= x _) 0.5 0.0)), where the inner construction for the else
501
  // case above should be (arraystoreall (Array Int Real) 0.0)
502
16180
  Node anode = getArrayRepresentationForLambdaRec(n, n[1].getType());
503
8090
  if (anode.isNull())
504
  {
505
4860
    return anode;
506
  }
507
  // must rewrite it to make canonical
508
3230
  return Rewriter::rewrite(anode);
509
}
510
511
}  // namespace builtin
512
}  // namespace theory
513
29577
}  // namespace cvc5