GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/uf/function_const.cpp Lines: 185 204 90.7 %
Date: 2021-11-07 Branches: 490 1070 45.8 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Utilities for function constants
14
 */
15
16
#include "theory/uf/function_const.h"
17
18
#include "expr/array_store_all.h"
19
#include "theory/rewriter.h"
20
21
namespace cvc5 {
22
namespace theory {
23
namespace uf {
24
25
/**
 * Builds the function type that corresponds to the curried array type atn.
 * One array level is peeled off per variable in bvl: each peeled index type
 * becomes an argument type, and the constituent type left after peeling
 * becomes the range of the function type.
 *
 * @param atn The (possibly nested) array type.
 * @param bvl The bound variable list; its i-th variable is expected to have
 * the index type of the i-th array level.
 * @return The function type (-> I_1 ... I_k T).
 */
TypeNode FunctionConst::getFunctionTypeForArrayType(TypeNode atn, Node bvl)
{
  std::vector<TypeNode> argTypes;
  const size_t numVars = bvl.getNumChildren();
  TypeNode current = atn;
  for (size_t k = 0; k < numVars; ++k)
  {
    Assert(current.isArray());
    TypeNode indexType = current.getArrayIndexType();
    Assert(bvl[k].getType() == indexType);
    argTypes.push_back(indexType);
    current = current.getArrayConstituentType();
  }
  // whatever remains after peeling all variables is the range type
  argTypes.push_back(current);
  return NodeManager::currentNM()->mkFunctionType(argTypes);
}
38
39
/**
 * Builds the curried array type corresponding to the function type ftn.
 * The last child of ftn is used as the innermost element type, and each
 * preceding child wraps it as an index type, so (-> T1 ... Tn R) becomes
 * (Array T1 (... (Array Tn R))).
 *
 * @param ftn A function type (asserted).
 * @return The curried array type.
 */
TypeNode FunctionConst::getArrayTypeForFunctionType(TypeNode ftn)
{
  Assert(ftn.isFunction());
  NodeManager* nm = NodeManager::currentNM();
  size_t idx = ftn.getNumChildren() - 1;
  TypeNode arrayType = ftn[idx];
  // wrap from the last argument type outwards to the first one
  while (idx > 0)
  {
    --idx;
    arrayType = nm->mkArrayType(ftn[idx], arrayType);
  }
  return arrayType;
}
52
53
9888
/**
 * Recursive helper for getLambdaForArrayRepresentation.
 *
 * Converts the array term a into a lambda body over the variables of bvl,
 * starting at variable bvlIndex:
 * - once every variable has been consumed, a itself is the body;
 * - (store b i v) becomes (ite (= x i) v' b'), where b' is the conversion of
 *   b at the same variable and v' the conversion of v at the next variable;
 * - an ArrayStoreAll constant becomes the conversion of its value at the next
 *   variable;
 * - any other array term yields the null node.
 *
 * Results (including null ones) are memoized in visited.
 */
Node FunctionConst::getLambdaForArrayRepresentationRec(
    TNode a,
    TNode bvl,
    unsigned bvlIndex,
    std::unordered_map<TNode, Node>& visited)
{
  auto cached = visited.find(a);
  if (cached != visited.end())
  {
    return cached->second;
  }
  Node result;
  if (bvlIndex >= bvl.getNumChildren())
  {
    // no variables left to process: the term is a leaf value
    result = a;
  }
  else
  {
    Assert(a.getType().isArray());
    Kind k = a.getKind();
    if (k == kind::STORE_ALL)
    {
      // the default value is converted w.r.t. the next variable
      Node defaultValue = a.getConst<ArrayStoreAll>().getValue();
      result = getLambdaForArrayRepresentationRec(
          defaultValue, bvl, bvlIndex + 1, visited);
    }
    else if (k == kind::STORE)
    {
      // the base array gives the else-branch, at the same variable
      Node elseBody =
          getLambdaForArrayRepresentationRec(a[0], bvl, bvlIndex, visited);
      Node thenBody;
      if (!elseBody.isNull())
      {
        // the stored value gives the then-branch, at the next variable
        thenBody = getLambdaForArrayRepresentationRec(
            a[2], bvl, bvlIndex + 1, visited);
      }
      if (!thenBody.isNull())
      {
        Assert(!TypeNode::leastCommonTypeNode(a[1].getType(),
                                              bvl[bvlIndex].getType())
                    .isNull());
        Assert(!TypeNode::leastCommonTypeNode(thenBody.getType(),
                                              elseBody.getType())
                    .isNull());
        Node guard = bvl[bvlIndex].eqNode(a[1]);
        result =
            NodeManager::currentNM()->mkNode(kind::ITE, guard, thenBody, elseBody);
      }
    }
  }
  visited[a] = result;
  return result;
}
107
108
1396
/**
 * Converts the array term a into a lambda over the bound variable list bvl.
 * The lambda body is rewritten before the lambda is built.
 *
 * @param a An array-typed term (asserted).
 * @param bvl The bound variable list of the resulting lambda.
 * @return The lambda, or the null node if a could not be converted.
 */
Node FunctionConst::getLambdaForArrayRepresentation(TNode a, TNode bvl)
{
  Assert(a.getType().isArray());
  std::unordered_map<TNode, Node> visited;
  Trace("builtin-rewrite-debug")
      << "Get lambda for : " << a << ", with variables " << bvl << std::endl;
  Node body = getLambdaForArrayRepresentationRec(a, bvl, 0, visited);
  if (body.isNull())
  {
    Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl;
    return Node::null();
  }
  body = Rewriter::rewrite(body);
  Trace("builtin-rewrite-debug") << "...got lambda body " << body << std::endl;
  return NodeManager::currentNM()->mkNode(kind::LAMBDA, bvl, body);
}
125
126
12399
/**
 * Recursive helper for getArrayRepresentationForLambda.
 *
 * Attempts to convert the lambda n into a chain of STORE terms over an
 * ArrayStoreAll constant. The body of n is walked as a chain of ITE / OR /
 * AND / EQUAL / NOT / BOUND_VARIABLE terms. Each step must be conditioned on
 * an equality between the first variable of n and a constant index, and each
 * step contributes one (index, value) entry. The term remaining after the
 * chain becomes the value of the ArrayStoreAll and must be constant. The first
 * entry of the chain becomes the outermost STORE. If n has more than one
 * variable, the values and the remainder are converted recursively as lambdas
 * over the remaining variables.
 *
 * @param n The lambda to convert (must have kind LAMBDA).
 * @param retType The return type of the outermost lambda. It is used as the
 * innermost constituent type of the constructed array type.
 * @return The array representation of n, or the null node if n does not have
 * the required shape.
 */
Node FunctionConst::getArrayRepresentationForLambdaRec(TNode n,
                                                       TypeNode retType)
{
  Assert(n.getKind() == kind::LAMBDA);
  NodeManager* nm = NodeManager::currentNM();
  Trace("builtin-rewrite-debug")
      << "Get array representation for : " << n << std::endl;

  // The first variable is the index of the outermost array. The remaining
  // variables (if any) are collected into rec_bvl for the recursive calls.
  Node first_arg = n[0][0];
  Node rec_bvl;
  size_t size = n[0].getNumChildren();
  if (size > 1)
  {
    std::vector<Node> args;
    for (size_t i = 1; i < size; i++)
    {
      args.push_back(n[0][i]);
    }
    rec_bvl = nm->mkNode(kind::BOUND_VAR_LIST, args);
  }

  Trace("builtin-rewrite-debug2") << "  process body..." << std::endl;
  // conds[i] / vals[i]: the function has value vals[i] when the first
  // argument is equal to the constant conds[i].
  std::vector<Node> conds;
  std::vector<Node> vals;
  Node curr = n[1];
  Kind ck = curr.getKind();
  while (ck == kind::ITE || ck == kind::OR || ck == kind::AND
         || ck == kind::EQUAL || ck == kind::NOT || ck == kind::BOUND_VARIABLE)
  {
    Node index_eq;
    Node curr_val;
    Node next;
    // Each iteration of this loop infers an entry in the function, e.g. it
    // has a value under some condition.

    // [1] We infer that the entry has value "curr_val" under condition
    // "index_eq". We set "next" to the node that is the remainder of the
    // function to process.
    if (ck == kind::ITE)
    {
      Trace("builtin-rewrite-debug2")
          << "  process condition : " << curr[0] << std::endl;
      index_eq = curr[0];
      curr_val = curr[1];
      next = curr[2];
    }
    else if (ck == kind::OR || ck == kind::AND)
    {
      Trace("builtin-rewrite-debug2")
          << "  process base : " << curr << std::endl;
      // Complex Boolean return cases, in which
      //  (1) lambda x. (= x v1) v ... becomes
      //      lambda x. (ite (= x v1) true [...])
      //
      //  (2) lambda x. (not (= x v1)) ^ ... becomes
      //      lambda x. (ite (= x v1) false [...])
      //
      // Note the negated cases of the lhs of the OR/AND operators above are
      // handled by pushing the recursion to the then-branch, with the
      // else-branch being the constant value. For example, the negated (1)
      // would be
      //  (1') lambda x. (not (= x v1)) v ... becomes
      //       lambda x. (ite (= x v1) [...] true)
      // thus requiring the rest of the disjunction to be further processed in
      // the then-branch as the current value.
      bool pol = curr[0].getKind() != kind::NOT;
      bool inverted = (pol && ck == kind::AND) || (!pol && ck == kind::OR);
      index_eq = pol ? curr[0] : curr[0][0];
      // processed : the value that is determined by the first child of curr.
      // It is the polarity of the first child, or its inverse if we are in
      // the inverted case.
      Node processed = nm->mkConst(!inverted ? pol : !pol);
      // remainder : the remaining children of curr, as an OR/AND if needed
      Node remainder;
      if (curr.getNumChildren() == 2)
      {
        remainder = curr[1];
      }
      else
      {
        std::vector<Node> remainderNodes{curr.begin() + 1, curr.end()};
        remainder = nm->mkNode(ck, remainderNodes);
      }
      if (inverted)
      {
        curr_val = remainder;
        next = processed;
        // If the lambda contains more variables than the one being currently
        // processed, the current value can be non-constant, since it'll be
        // processed recursively below. Otherwise we fail.
        if (rec_bvl.isNull() && !curr_val.isConst())
        {
          Trace("builtin-rewrite-debug2")
              << "...non-const curr_val " << curr_val << "\n";
          return Node::null();
        }
      }
      else
      {
        curr_val = processed;
        next = remainder;
      }
      Trace("builtin-rewrite-debug2") << "  index_eq : " << index_eq << "\n";
      Trace("builtin-rewrite-debug2") << "  curr_val : " << curr_val << "\n";
      Trace("builtin-rewrite-debug2") << "  next : " << next << std::endl;
    }
    else
    {
      Trace("builtin-rewrite-debug2")
          << "  process base : " << curr << std::endl;
      // Simple Boolean return cases, in which
      //  (1) lambda x. (= x v) becomes lambda x. (ite (= x v) true false)
      //  (2) lambda x. x, for a Boolean variable x, becomes
      //      lambda x. (ite (= x true) true false)
      // Note the negated cases of the bodies above are also handled.
      bool pol = ck != kind::NOT;
      index_eq = pol ? curr : curr[0];
      curr_val = nm->mkConst(pol);
      next = nm->mkConst(!pol);
    }

    // [2] We ensure that "index_eq" is an equality, if possible.
    if (index_eq.getKind() != kind::EQUAL)
    {
      bool pol = index_eq.getKind() != kind::NOT;
      Node indexEqAtom = pol ? index_eq : index_eq[0];
      if (indexEqAtom.getKind() == kind::BOUND_VARIABLE)
      {
        if (!indexEqAtom.getType().isBoolean())
        {
          // Catches default case of non-Boolean variable, e.g.
          // lambda x : Int. x. In this case, it is not canonical and we fail.
          Trace("builtin-rewrite-debug2")
              << "  ...non-Boolean variable." << std::endl;
          return Node::null();
        }
        // Boolean argument case, e.g. lambda x. ite( x, t, s ) is processed as
        // lambda x. (ite (= x true) t s)
        index_eq = indexEqAtom.eqNode(nm->mkConst(pol));
      }
      else
      {
        // non-equality condition
        Trace("builtin-rewrite-debug2")
            << "  ...non-equality condition." << std::endl;
        return Node::null();
      }
    }
    else if (Rewriter::rewrite(index_eq) != index_eq)
    {
      // equality must be oriented correctly based on rewriter
      Trace("builtin-rewrite-debug2")
          << "  ...equality not oriented properly." << std::endl;
      return Node::null();
    }

    // [3] We ensure that "index_eq" is an equality that is equivalent to
    // "first_arg" = "curr_index", where curr_index is a constant, and
    // "first_arg" is the current argument we are processing, if possible.
    Node curr_index;
    for (size_t r = 0; r < 2; r++)
    {
      Node arg = index_eq[r];
      Node val = index_eq[1 - r];
      if (arg == first_arg)
      {
        if (!val.isConst())
        {
          // non-constant value
          Trace("builtin-rewrite-debug2")
              << "  ...non-constant value for argument." << std::endl;
          return Node::null();
        }
        else
        {
          curr_index = val;
          Trace("builtin-rewrite-debug2")
              << "  arg " << arg << " -> " << val << std::endl;
          break;
        }
      }
    }
    if (curr_index.isNull())
    {
      Trace("builtin-rewrite-debug2")
          << "  ...could not infer index value." << std::endl;
      return Node::null();
    }

    // [4] Recurse to ensure that "curr_val" has been normalized w.r.t. the
    // remaining arguments (rec_bvl).
    if (!rec_bvl.isNull())
    {
      curr_val = nm->mkNode(kind::LAMBDA, rec_bvl, curr_val);
      Trace("builtin-rewrite-debug") << push;
      Trace("builtin-rewrite-debug2") << push;
      curr_val = getArrayRepresentationForLambdaRec(curr_val, retType);
      Trace("builtin-rewrite-debug") << pop;
      Trace("builtin-rewrite-debug2") << pop;
      if (curr_val.isNull())
      {
        Trace("builtin-rewrite-debug2")
            << "  ...failed to recursively find value." << std::endl;
        return Node::null();
      }
    }
    Trace("builtin-rewrite-debug2")
        << "  ...condition is index " << curr_val << std::endl;

    // [5] Add the entry
    conds.push_back(curr_index);
    vals.push_back(curr_val);

    // we will now process the remainder
    curr = next;
    ck = curr.getKind();
    Trace("builtin-rewrite-debug2")
        << "  process remainder : " << curr << std::endl;
  }
  // The remainder is the default value; normalize it w.r.t. the remaining
  // arguments as well.
  if (!rec_bvl.isNull())
  {
    curr = nm->mkNode(kind::LAMBDA, rec_bvl, curr);
    Trace("builtin-rewrite-debug") << push;
    Trace("builtin-rewrite-debug2") << push;
    curr = getArrayRepresentationForLambdaRec(curr, retType);
    Trace("builtin-rewrite-debug") << pop;
    Trace("builtin-rewrite-debug2") << pop;
  }
  // can only build if default value is constant (since array store all must
  // be constant)
  if (!curr.isNull() && curr.isConst())
  {
    // compute the return type
    TypeNode array_type = retType;
    for (size_t i = 0; i < size; i++)
    {
      size_t index = (size - 1) - i;
      array_type = nm->mkArrayType(n[0][index].getType(), array_type);
    }
    Trace("builtin-rewrite-debug2")
        << "  make array store all " << curr.getType()
        << " annotated : " << array_type << std::endl;
    Assert(curr.getType().isSubtypeOf(array_type.getArrayConstituentType()));
    curr = nm->mkConst(ArrayStoreAll(array_type, curr));
    Trace("builtin-rewrite-debug2") << "  build array..." << std::endl;
    Trace("builtin-rewrite-debug2")
        << "  got constant base " << curr << std::endl;
    Trace("builtin-rewrite-debug2") << "  conditions " << conds << std::endl;
    Trace("builtin-rewrite-debug2") << "  values " << vals << std::endl;
    // construct store chain, innermost store from the last entry, so that
    // the first inferred entry ends up outermost
    for (size_t i = 0, numCond = conds.size(); i < numCond; i++)
    {
      size_t ii = (numCond - 1) - i;
      Assert(conds[ii].getType().isSubtypeOf(first_arg.getType()));
      curr = nm->mkNode(kind::STORE, curr, conds[ii], vals[ii]);
    }
    Trace("builtin-rewrite-debug")
        << "...got array " << curr << " for " << n << std::endl;
    return curr;
  }
  Trace("builtin-rewrite-debug")
      << "...failed to get array (cannot get constant default value)"
      << std::endl;
  return Node::null();
}
394
395
6925
/**
 * Returns the array representation of the lambda n.
 *
 * A successful result is rewritten so that it is in canonical form. A
 * failure is returned unchanged, as the null node.
 *
 * @param n A lambda (asserted).
 * @return The rewritten array term, or the null node if n has no array
 * representation.
 */
Node FunctionConst::getArrayRepresentationForLambda(TNode n)
{
  Assert(n.getKind() == kind::LAMBDA);
  // The overall return type must be carried down to handle cases like
  // (lambda ((x Int) (y Int)) (ite (= x _) 0.5 0.0)), where the inner
  // construction for the else case should be (arraystoreall (Array Int Real)
  // 0.0).
  TypeNode bodyType = n[1].getType();
  Node anode = getArrayRepresentationForLambdaRec(n, bodyType);
  return anode.isNull() ? anode : Rewriter::rewrite(anode);
}
409
410
}  // namespace uf
411
}  // namespace theory
412
31137
}  // namespace cvc5