GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/builtin/theory_builtin_rewriter.cpp Lines: 225 250 90.0 %
Date: 2021-03-23 Branches: 655 1541 42.5 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file theory_builtin_rewriter.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Haniel Barbosa, Morgan Deters
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief [[ Add one-line brief description here ]]
13
 **
14
 ** [[ Add lengthier description here ]]
15
 ** \todo document this file
16
 **/
17
18
#include "theory/builtin/theory_builtin_rewriter.h"
19
20
#include "expr/attribute.h"
21
#include "expr/node_algorithm.h"
22
#include "theory/rewriter.h"
23
24
using namespace std;
25
26
namespace CVC4 {
27
namespace theory {
28
namespace builtin {
29
30
10531
Node TheoryBuiltinRewriter::blastDistinct(TNode in) {
31
10531
  Assert(in.getKind() == kind::DISTINCT);
32
33
10531
  if(in.getNumChildren() == 2) {
34
    // if this is the case exactly 1 != pair will be generated so the
35
    // AND is not required
36
20188
    Node eq = NodeManager::currentNM()->mkNode(kind::EQUAL, in[0], in[1]);
37
20188
    Node neq = NodeManager::currentNM()->mkNode(kind::NOT, eq);
38
10094
    return neq;
39
  }
40
41
  // assume that in.getNumChildren() > 2 => diseqs.size() > 1
42
874
  vector<Node> diseqs;
43
3273
  for(TNode::iterator i = in.begin(); i != in.end(); ++i) {
44
2836
    TNode::iterator j = i;
45
69930
    while(++j != in.end()) {
46
67094
      Node eq = NodeManager::currentNM()->mkNode(kind::EQUAL, *i, *j);
47
67094
      Node neq = NodeManager::currentNM()->mkNode(kind::NOT, eq);
48
33547
      diseqs.push_back(neq);
49
    }
50
  }
51
874
  Node out = NodeManager::currentNM()->mkNode(kind::AND, diseqs);
52
437
  return out;
53
}
54
55
729516
RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node) {
56
729516
  if( node.getKind()==kind::LAMBDA ){
57
    // The following code ensures that if node is equivalent to a constant
58
    // lambda, then we return the canonical representation for the lambda, which
59
    // in turn ensures that two constant lambdas are equivalent if and only
60
    // if they are the same node.
61
    // We canonicalize lambdas by turning them into array constants, applying
62
    // normalization on array constants, and then converting the array constant
63
    // back to a lambda.
64
8186
    Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl;
65
16372
    Node anode = getArrayRepresentationForLambda( node );
66
    // Only rewrite constant array nodes, since these are the only cases
67
    // where we require canonicalization of lambdas. Moreover, applying the
68
    // below code is not correct if the arguments to the lambda occur
69
    // in return values. For example, lambda x. ite( x=1, f(x), c ) would
70
    // be converted to (store (storeall ... c) 1 f(x)), and then converted
71
    // to lambda y. ite( y=1, f(x), c), losing the relation between x and y.
72
8186
    if (!anode.isNull() && anode.isConst())
73
    {
74
2930
      Assert(anode.getType().isArray());
75
      //must get the standard bound variable list
76
4475
      Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType( node.getType() );
77
4475
      Node retNode = getLambdaForArrayRepresentation( anode, varList );
78
2930
      if( !retNode.isNull() && retNode!=node ){
79
1385
        Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl;
80
1385
        Trace("builtin-rewrite") << "     input  : " << node << std::endl;
81
1385
        Trace("builtin-rewrite") << "     output : " << retNode << ", constant = " << retNode.isConst() << std::endl;
82
1385
        Trace("builtin-rewrite") << "  array rep : " << anode << ", constant = " << anode.isConst() << std::endl;
83
1385
        Assert(anode.isConst() == retNode.isConst());
84
1385
        Assert(retNode.getType() == node.getType());
85
1385
        Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode));
86
1385
        return RewriteResponse(REWRITE_DONE, retNode);
87
      }
88
    }
89
    else
90
    {
91
5256
      Trace("builtin-rewrite-debug") << "...failed to get array representation." << std::endl;
92
    }
93
6801
    return RewriteResponse(REWRITE_DONE, node);
94
  }
95
  // otherwise, do the default call
96
721330
  return doRewrite(node);
97
}
98
99
1110822
RewriteResponse TheoryBuiltinRewriter::doRewrite(TNode node)
100
{
101
1110822
  switch (node.getKind())
102
  {
103
2811
    case kind::WITNESS:
104
    {
105
      // it is important to run this rewriting at prerewrite and postrewrite,
106
      // since e.g. arithmetic rewrites equalities in ways that may make an
107
      // equality not in solved form syntactically, e.g. (= x (+ 1 a)) rewrites
108
      // to (= a (- x 1)), where x no longer is in solved form.
109
5622
      Node rnode = rewriteWitness(node);
110
2811
      return RewriteResponse(REWRITE_DONE, rnode);
111
    }
112
10531
    case kind::DISTINCT:
113
10531
      return RewriteResponse(REWRITE_DONE, blastDistinct(node));
114
1097480
    default: return RewriteResponse(REWRITE_DONE, node);
115
  }
116
}
117
118
TypeNode TheoryBuiltinRewriter::getFunctionTypeForArrayType(TypeNode atn,
119
                                                            Node bvl)
120
{
121
  std::vector<TypeNode> children;
122
  for (unsigned i = 0; i < bvl.getNumChildren(); i++)
123
  {
124
    Assert(atn.isArray());
125
    Assert(bvl[i].getType() == atn.getArrayIndexType());
126
    children.push_back(atn.getArrayIndexType());
127
    atn = atn.getArrayConstituentType();
128
  }
129
  children.push_back(atn);
130
  return NodeManager::currentNM()->mkFunctionType(children);
131
}
132
133
TypeNode TheoryBuiltinRewriter::getArrayTypeForFunctionType(TypeNode ftn)
134
{
135
  Assert(ftn.isFunction());
136
  // construct the curried array type
137
  unsigned nchildren = ftn.getNumChildren();
138
  TypeNode ret = ftn[nchildren - 1];
139
  for (int i = (static_cast<int>(nchildren) - 2); i >= 0; i--)
140
  {
141
    ret = NodeManager::currentNM()->mkArrayType(ftn[i], ret);
142
  }
143
  return ret;
144
}
145
146
17936
Node TheoryBuiltinRewriter::getLambdaForArrayRepresentationRec( TNode a, TNode bvl, unsigned bvlIndex,
147
                                                                std::unordered_map< TNode, Node, TNodeHashFunction >& visited ){
148
17936
  std::unordered_map< TNode, Node, TNodeHashFunction >::iterator it = visited.find( a );
149
17936
  if( it==visited.end() ){
150
34008
    Node ret;
151
17004
    if( bvlIndex<bvl.getNumChildren() ){
152
9206
      Assert(a.getType().isArray());
153
9206
      if( a.getKind()==kind::STORE ){
154
        // convert the array recursively
155
11600
        Node body = getLambdaForArrayRepresentationRec( a[0], bvl, bvlIndex, visited );
156
5800
        if( !body.isNull() ){
157
          // convert the value recursively (bounded by the number of arguments in bvl)
158
11600
          Node val = getLambdaForArrayRepresentationRec( a[2], bvl, bvlIndex+1, visited );
159
5800
          if( !val.isNull() ){
160
5800
            Assert(!TypeNode::leastCommonTypeNode(a[1].getType(),
161
                                                  bvl[bvlIndex].getType())
162
                        .isNull());
163
5800
            Assert(!TypeNode::leastCommonTypeNode(val.getType(), body.getType())
164
                        .isNull());
165
11600
            Node cond = bvl[bvlIndex].eqNode( a[1] );
166
5800
            ret = NodeManager::currentNM()->mkNode( kind::ITE, cond, val, body );
167
          }
168
        }
169
3406
      }else if( a.getKind()==kind::STORE_ALL ){
170
6812
        ArrayStoreAll storeAll = a.getConst<ArrayStoreAll>();
171
6812
        Node sa = storeAll.getValue();
172
        // convert the default value recursively (bounded by the number of arguments in bvl)
173
3406
        ret = getLambdaForArrayRepresentationRec( sa, bvl, bvlIndex+1, visited );
174
      }
175
    }else{
176
7798
      ret = a;
177
    }
178
17004
    visited[a] = ret;
179
17004
    return ret;
180
  }else{
181
932
    return it->second;
182
  }
183
}
184
185
2930
Node TheoryBuiltinRewriter::getLambdaForArrayRepresentation( TNode a, TNode bvl ){
186
2930
  Assert(a.getType().isArray());
187
5860
  std::unordered_map< TNode, Node, TNodeHashFunction > visited;
188
2930
  Trace("builtin-rewrite-debug") << "Get lambda for : " << a << ", with variables " << bvl << std::endl;
189
5860
  Node body = getLambdaForArrayRepresentationRec( a, bvl, 0, visited );
190
2930
  if( !body.isNull() ){
191
2930
    body = Rewriter::rewrite( body );
192
2930
    Trace("builtin-rewrite-debug") << "...got lambda body " << body << std::endl;
193
2930
    return NodeManager::currentNM()->mkNode( kind::LAMBDA, bvl, body );
194
  }else{
195
    Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl;
196
    return Node::null();
197
  }
198
}
199
200
16117
Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec(TNode n,
201
                                                               TypeNode retType)
202
{
203
16117
  Assert(n.getKind() == kind::LAMBDA);
204
16117
  NodeManager* nm = NodeManager::currentNM();
205
16117
  Trace("builtin-rewrite-debug") << "Get array representation for : " << n << std::endl;
206
207
32234
  Node first_arg = n[0][0];
208
32234
  Node rec_bvl;
209
16117
  unsigned size = n[0].getNumChildren();
210
16117
  if (size > 1)
211
  {
212
10826
    std::vector< Node > args;
213
11799
    for (unsigned i = 1; i < size; i++)
214
    {
215
6386
      args.push_back( n[0][i] );
216
    }
217
5413
    rec_bvl = nm->mkNode(kind::BOUND_VAR_LIST, args);
218
  }
219
220
16117
  Trace("builtin-rewrite-debug2") << "  process body..." << std::endl;
221
32234
  std::vector< Node > conds;
222
32234
  std::vector< Node > vals;
223
32234
  Node curr = n[1];
224
16117
  Kind ck = curr.getKind();
225
24818
  while (ck == kind::ITE || ck == kind::OR || ck == kind::AND
226
39968
         || ck == kind::EQUAL || ck == kind::NOT || ck == kind::BOUND_VARIABLE)
227
  {
228
19348
    Node index_eq;
229
19348
    Node curr_val;
230
19348
    Node next;
231
    // Each iteration of this loop infers an entry in the function, e.g. it
232
    // has a value under some condition.
233
234
    // [1] We infer that the entry has value "curr_val" under condition
235
    // "index_eq". We set "next" to the node that is the remainder of the
236
    // function to process.
237
11035
    if (ck == kind::ITE)
238
    {
239
15850
      Trace("builtin-rewrite-debug2")
240
7925
          << "  process condition : " << curr[0] << std::endl;
241
7925
      index_eq = curr[0];
242
7925
      curr_val = curr[1];
243
7925
      next = curr[2];
244
    }
245
3110
    else if (ck == kind::OR || ck == kind::AND)
246
    {
247
1934
      Trace("builtin-rewrite-debug2")
248
967
          << "  process base : " << curr << std::endl;
249
      // curr = Rewriter::rewrite(curr);
250
      // Trace("builtin-rewrite-debug2")
251
      //     << "  rewriten base : " << curr << std::endl;
252
      // Complex Boolean return cases, in which
253
      //  (1) lambda x. (= x v1) v ... becomes
254
      //      lambda x. (ite (= x v1) true [...])
255
      //
256
      //  (2) lambda x. (not (= x v1)) ^ ... becomes
257
      //      lambda x. (ite (= x v1) false [...])
258
      //
259
      // Note the negated cases of the lhs of the OR/AND operators above are
260
      // handled by pushing the recursion to the then-branch, with the
261
      // else-branch being the constant value. For example, the negated (1)
262
      // would be
263
      //  (1') lambda x. (not (= x v1)) v ... becomes
264
      //       lambda x. (ite (= x v1) [...] true)
265
      // thus requiring the rest of the disjunction to be further processed in
266
      // the then-branch as the current value.
267
967
      bool pol = curr[0].getKind() != kind::NOT;
268
967
      bool inverted = (pol && ck == kind::AND) || (!pol && ck == kind::OR);
269
967
      index_eq = pol ? curr[0] : curr[0][0];
270
      // processed : the value that is determined by the first child of curr
271
      // remainder : the remaining children of curr
272
1738
      Node processed, remainder;
273
      // the value is the polarity of the first child or its inverse if we are
274
      // in the inverted case
275
967
      processed = nm->mkConst(!inverted? pol : !pol);
276
      // build an OR/AND with the remaining components
277
967
      if (curr.getNumChildren() == 2)
278
      {
279
931
        remainder = curr[1];
280
      }
281
      else
282
      {
283
72
        std::vector<Node> remainderNodes{curr.begin() + 1, curr.end()};
284
36
        remainder = nm->mkNode(ck, remainderNodes);
285
      }
286
967
      if (inverted)
287
      {
288
567
        curr_val = remainder;
289
567
        next = processed;
290
        // If the lambda contains more variables than the one being currently
291
        // processed, the current value can be non-constant, since it'll be
292
        // processed recursively below. Otherwise we fail.
293
567
        if (rec_bvl.isNull() && !curr_val.isConst())
294
        {
295
392
          Trace("builtin-rewrite-debug2")
296
196
              << "...non-const curr_val " << curr_val << "\n";
297
196
          return Node::null();
298
        }
299
      }
300
      else
301
      {
302
400
        curr_val = processed;
303
400
        next = remainder;
304
      }
305
771
      Trace("builtin-rewrite-debug2") << "  index_eq : " << index_eq << "\n";
306
771
      Trace("builtin-rewrite-debug2") << "  curr_val : " << curr_val << "\n";
307
1542
      Trace("builtin-rewrite-debug2") << "  next : " << next << std::endl;
308
    }
309
    else
310
    {
311
4286
      Trace("builtin-rewrite-debug2")
312
2143
          << "  process base : " << curr << std::endl;
313
      // Simple Boolean return cases, in which
314
      //  (1) lambda x. (= x v) becomes lambda x. (ite (= x v) true false)
315
      //  (2) lambda x. v becomes lambda x. (ite (= x v) true false)
316
      // Note the negateg cases of the bodies above are also handled.
317
2143
      bool pol = ck != kind::NOT;
318
2143
      index_eq = pol ? curr : curr[0];
319
2143
      curr_val = nm->mkConst(pol);
320
2143
      next = nm->mkConst(!pol);
321
    }
322
323
    // [2] We ensure that "index_eq" is an equality, if possible.
324
10839
    if (index_eq.getKind() != kind::EQUAL)
325
    {
326
2885
      bool pol = index_eq.getKind() != kind::NOT;
327
4131
      Node indexEqAtom = pol ? index_eq : index_eq[0];
328
2885
      if (indexEqAtom.getKind() == kind::BOUND_VARIABLE)
329
      {
330
1812
        if (!indexEqAtom.getType().isBoolean())
331
        {
332
          // Catches default case of non-Boolean variable, e.g.
333
          // lambda x : Int. x. In this case, it is not canonical and we fail.
334
1132
          Trace("builtin-rewrite-debug2")
335
566
              << "  ...non-Boolean variable." << std::endl;
336
566
          return Node::null();
337
        }
338
        // Boolean argument case, e.g. lambda x. ite( x, t, s ) is processed as
339
        // lambda x. (ite (= x true) t s)
340
1246
        index_eq = indexEqAtom.eqNode(nm->mkConst(pol));
341
      }
342
      else
343
      {
344
        // non-equality condition
345
2146
        Trace("builtin-rewrite-debug2")
346
1073
            << "  ...non-equality condition." << std::endl;
347
1073
        return Node::null();
348
      }
349
    }
350
7954
    else if (Rewriter::rewrite(index_eq) != index_eq)
351
    {
352
      // equality must be oriented correctly based on rewriter
353
15
      Trace("builtin-rewrite-debug2") << "  ...equality not oriented properly." << std::endl;
354
15
      return Node::null();
355
    }
356
357
    // [3] We ensure that "index_eq" is an equality that is equivalent to
358
    // "first_arg" = "curr_index", where curr_index is a constant, and
359
    // "first_arg" is the current argument we are processing, if possible.
360
17498
    Node curr_index;
361
11175
    for( unsigned r=0; r<2; r++ ){
362
12940
      Node arg = index_eq[r];
363
12940
      Node val = index_eq[1-r];
364
10950
      if( arg==first_arg ){
365
8960
        if (!val.isConst())
366
        {
367
          // non-constant value
368
534
          Trace("builtin-rewrite-debug2")
369
267
              << "  ...non-constant value for argument\n.";
370
267
          return Node::null();
371
        }else{
372
8693
          curr_index = val;
373
17386
          Trace("builtin-rewrite-debug2")
374
8693
              << "  arg " << arg << " -> " << val << std::endl;
375
8693
          break;
376
        }
377
      }
378
    }
379
8918
    if (curr_index.isNull())
380
    {
381
450
      Trace("builtin-rewrite-debug2")
382
225
          << "  ...could not infer index value." << std::endl;
383
225
      return Node::null();
384
    }
385
386
    // [4] Recurse to ensure that "curr_val" has been normalized w.r.t. the
387
    // remaining arguments (rec_bvl).
388
8693
    if (!rec_bvl.isNull())
389
    {
390
1755
      curr_val = nm->mkNode(kind::LAMBDA, rec_bvl, curr_val);
391
1755
      Trace("builtin-rewrite-debug") << push;
392
1755
      Trace("builtin-rewrite-debug2") << push;
393
1755
      curr_val = getArrayRepresentationForLambdaRec(curr_val, retType);
394
1755
      Trace("builtin-rewrite-debug") << pop;
395
1755
      Trace("builtin-rewrite-debug2") << pop;
396
1755
      if (curr_val.isNull())
397
      {
398
760
        Trace("builtin-rewrite-debug2")
399
380
            << "  ...failed to recursively find value." << std::endl;
400
380
        return Node::null();
401
      }
402
    }
403
16626
    Trace("builtin-rewrite-debug2")
404
8313
        << "  ...condition is index " << curr_val << std::endl;
405
406
    // [5] Add the entry
407
8313
    conds.push_back( curr_index );
408
8313
    vals.push_back( curr_val );
409
410
    // we will now process the remainder
411
8313
    curr = next;
412
8313
    ck = curr.getKind();
413
16626
    Trace("builtin-rewrite-debug2")
414
8313
        << "  process remainder : " << curr << std::endl;
415
  }
416
13395
  if( !rec_bvl.isNull() ){
417
3759
    curr = nm->mkNode(kind::LAMBDA, rec_bvl, curr);
418
3759
    Trace("builtin-rewrite-debug") << push;
419
3759
    Trace("builtin-rewrite-debug2") << push;
420
3759
    curr = getArrayRepresentationForLambdaRec(curr, retType);
421
3759
    Trace("builtin-rewrite-debug") << pop;
422
3759
    Trace("builtin-rewrite-debug2") << pop;
423
  }
424
13395
  if( !curr.isNull() && curr.isConst() ){
425
    // compute the return type
426
11372
    TypeNode array_type = retType;
427
12235
    for (unsigned i = 0; i < size; i++)
428
    {
429
6549
      unsigned index = (size - 1) - i;
430
6549
      array_type = nm->mkArrayType(n[0][index].getType(), array_type);
431
    }
432
5686
    Trace("builtin-rewrite-debug2") << "  make array store all " << curr.getType() << " annotated : " << array_type << std::endl;
433
5686
    Assert(curr.getType().isSubtypeOf(array_type.getArrayConstituentType()));
434
5686
    curr = nm->mkConst(ArrayStoreAll(array_type, curr));
435
5686
    Trace("builtin-rewrite-debug2") << "  build array..." << std::endl;
436
    // can only build if default value is constant (since array store all must be constant)
437
5686
    Trace("builtin-rewrite-debug2") << "  got constant base " << curr << std::endl;
438
5686
    Trace("builtin-rewrite-debug2") << "  conditions " << conds << std::endl;
439
5686
    Trace("builtin-rewrite-debug2") << "  values " << vals << std::endl;
440
    // construct store chain
441
13729
    for (int i = static_cast<int>(conds.size()) - 1; i >= 0; i--)
442
    {
443
8043
      Assert(conds[i].getType().isSubtypeOf(first_arg.getType()));
444
8043
      curr = nm->mkNode(kind::STORE, curr, conds[i], vals[i]);
445
    }
446
5686
    Trace("builtin-rewrite-debug") << "...got array " << curr << " for " << n << std::endl;
447
5686
    return curr;
448
  }else{
449
7709
    Trace("builtin-rewrite-debug") << "...failed to get array (cannot get constant default value)" << std::endl;
450
7709
    return Node::null();
451
  }
452
}
453
454
2811
Node TheoryBuiltinRewriter::rewriteWitness(TNode node)
455
{
456
2811
  Assert(node.getKind() == kind::WITNESS);
457
2811
  if (node[1].getKind() == kind::EQUAL)
458
  {
459
291
    for (size_t i = 0; i < 2; i++)
460
    {
461
      // (witness ((x T)) (= x t)) ---> t
462
194
      if (node[1][i] == node[0][0])
463
      {
464
        Trace("builtin-rewrite") << "Witness rewrite: " << node << " --> "
465
                                 << node[1][1 - i] << std::endl;
466
        // also must be a legal elimination: the other side of the equality
467
        // cannot contain the variable, and it must be a subtype of the
468
        // variable
469
        if (!expr::hasSubterm(node[1][1 - i], node[0][0])
470
            && node[1][i].getType().isSubtypeOf(node[0][0].getType()))
471
        {
472
          return node[1][1 - i];
473
        }
474
      }
475
    }
476
  }
477
2714
  else if (node[1] == node[0][0])
478
  {
479
    // (witness ((x Bool)) x) ---> true
480
    return NodeManager::currentNM()->mkConst(true);
481
  }
482
2714
  else if (node[1].getKind() == kind::NOT && node[1][0] == node[0][0])
483
  {
484
    // (witness ((x Bool)) (not x)) ---> false
485
    return NodeManager::currentNM()->mkConst(false);
486
  }
487
2811
  return node;
488
}
489
490
10603
Node TheoryBuiltinRewriter::getArrayRepresentationForLambda(TNode n)
491
{
492
10603
  Assert(n.getKind() == kind::LAMBDA);
493
  // must carry the overall return type to deal with cases like (lambda ((x Int)
494
  // (y Int)) (ite (= x _) 0.5 0.0)), where the inner construction for the else
495
  // case above should be (arraystoreall (Array Int Real) 0.0)
496
21206
  Node anode = getArrayRepresentationForLambdaRec(n, n[1].getType());
497
10603
  if (anode.isNull())
498
  {
499
7070
    return anode;
500
  }
501
  // must rewrite it to make canonical
502
3533
  return Rewriter::rewrite(anode);
503
}
504
505
}/* CVC4::theory::builtin namespace */
506
}/* CVC4::theory namespace */
507
26685
}/* CVC4 namespace */