/******************************************************************************
 * Top contributors (to current version):
 *   Andrew Reynolds
 *
 * This file is part of the cvc5 project.
 *
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
 * in the top-level source directory and their institutional affiliations.
 * All rights reserved. See the file COPYING in the top-level source
 * directory for licensing information.
 * ****************************************************************************
 *
 * Utilities for function constants
 */
15 |
|
|
16 |
|
#include "theory/uf/function_const.h" |
17 |
|
|
18 |
|
#include "expr/array_store_all.h" |
19 |
|
#include "theory/rewriter.h" |
20 |
|
|
21 |
|
namespace cvc5 { |
22 |
|
namespace theory { |
23 |
|
namespace uf { |
24 |
|
|
25 |
|
/**
 * Computes the function type corresponding to the (curried) array type atn,
 * peeling one array level for each bound variable in bvl.
 *
 * @param atn an array type with at least bvl.getNumChildren() nested levels
 * @param bvl a bound variable list; the i-th variable must have the index
 *        type of the i-th array level
 * @return the function type whose argument types are the peeled index types
 *         and whose range is the type remaining after all levels are peeled
 */
TypeNode FunctionConst::getFunctionTypeForArrayType(TypeNode atn, Node bvl)
{
  std::vector<TypeNode> signature;
  TypeNode level = atn;
  for (const Node& var : bvl)
  {
    Assert(level.isArray());
    TypeNode indexType = level.getArrayIndexType();
    Assert(var.getType() == indexType);
    signature.push_back(indexType);
    level = level.getArrayConstituentType();
  }
  // whatever is left once every variable has consumed a level is the range
  signature.push_back(level);
  return NodeManager::currentNM()->mkFunctionType(signature);
}
38 |
|
|
39 |
|
/**
 * Computes the curried array type corresponding to function type ftn.
 * For (-> T1 ... Tk R) the result is (Array T1 (Array ... (Array Tk R))).
 *
 * @param ftn a function type
 * @return the corresponding nested array type
 */
TypeNode FunctionConst::getArrayTypeForFunctionType(TypeNode ftn)
{
  Assert(ftn.isFunction());
  NodeManager* nm = NodeManager::currentNM();
  // Build from the inside out: start from the range type (the last child)
  // and wrap one array level per argument type, last argument first.
  size_t argIndex = ftn.getNumChildren() - 1;
  TypeNode result = ftn[argIndex];
  while (argIndex > 0)
  {
    --argIndex;
    result = nm->mkArrayType(ftn[argIndex], result);
  }
  return result;
}
52 |
|
|
53 |
9888 |
/**
 * Recursive worker for getLambdaForArrayRepresentation.
 *
 * Converts array term a, viewed as a function of the bound variables
 * bvl[bvlIndex], bvl[bvlIndex + 1], ..., into an ITE term over those
 * variables. Once all variables are consumed, a is returned unchanged.
 *
 * @param a the (sub)term being converted
 * @param bvl the bound variable list of the lambda being built
 * @param bvlIndex position of the variable that indexes a
 * @param visited cache of already converted terms
 * @return the converted term, or the null node if a (while variables remain)
 *         is neither a STORE nor a STORE_ALL, or a subterm fails to convert
 */
Node FunctionConst::getLambdaForArrayRepresentationRec(
    TNode a,
    TNode bvl,
    unsigned bvlIndex,
    std::unordered_map<TNode, Node>& visited)
{
  auto cached = visited.find(a);
  if (cached != visited.end())
  {
    return cached->second;
  }
  Node result;
  if (bvlIndex >= bvl.getNumChildren())
  {
    // every bound variable has been consumed: a is a leaf value
    result = a;
  }
  else
  {
    Assert(a.getType().isArray());
    Kind k = a.getKind();
    if (k == kind::STORE)
    {
      // (store base idx val) becomes (ite (= v idx) val' base') where v is
      // bvl[bvlIndex]; base' is converted at the same level, val' at the next
      // level (bounded by the number of arguments in bvl)
      Node elseBranch =
          getLambdaForArrayRepresentationRec(a[0], bvl, bvlIndex, visited);
      Node thenBranch;
      if (!elseBranch.isNull())
      {
        thenBranch = getLambdaForArrayRepresentationRec(
            a[2], bvl, bvlIndex + 1, visited);
      }
      if (!thenBranch.isNull())
      {
        Assert(!TypeNode::leastCommonTypeNode(a[1].getType(),
                                              bvl[bvlIndex].getType())
                    .isNull());
        Assert(!TypeNode::leastCommonTypeNode(thenBranch.getType(),
                                              elseBranch.getType())
                    .isNull());
        Node guard = bvl[bvlIndex].eqNode(a[1]);
        result = NodeManager::currentNM()->mkNode(
            kind::ITE, guard, thenBranch, elseBranch);
      }
    }
    else if (k == kind::STORE_ALL)
    {
      // the default value is converted at the next level
      Node defaultValue = a.getConst<ArrayStoreAll>().getValue();
      result = getLambdaForArrayRepresentationRec(
          defaultValue, bvl, bvlIndex + 1, visited);
    }
  }
  visited[a] = result;
  return result;
}
107 |
|
|
108 |
1396 |
/**
 * Converts array term a into a lambda over the bound variable list bvl.
 * The body of the lambda is rewritten before the lambda is built.
 *
 * @param a an array-typed term
 * @param bvl the bound variable list of the resulting lambda
 * @return the lambda, or the null node if a could not be converted
 */
Node FunctionConst::getLambdaForArrayRepresentation(TNode a, TNode bvl)
{
  Assert(a.getType().isArray());
  std::unordered_map<TNode, Node> visited;
  Trace("builtin-rewrite-debug")
      << "Get lambda for : " << a << ", with variables " << bvl << std::endl;
  Node body = getLambdaForArrayRepresentationRec(a, bvl, 0, visited);
  if (body.isNull())
  {
    Trace("builtin-rewrite-debug")
        << "...failed to get lambda body" << std::endl;
    return Node::null();
  }
  Node rewrittenBody = Rewriter::rewrite(body);
  Trace("builtin-rewrite-debug")
      << "...got lambda body " << rewrittenBody << std::endl;
  return NodeManager::currentNM()->mkNode(kind::LAMBDA, bvl, rewrittenBody);
}
125 |
|
|
126 |
12399 |
/**
 * Recursive worker for getArrayRepresentationForLambda.
 *
 * Tries to convert lambda n into a chain of STORE terms over an
 * ARRAY_STORE_ALL base. The body of n is processed as a sequence of entries
 * of the form "if first_arg = c then value", where first_arg is the first
 * bound variable of n and c is a constant. Any remaining bound variables are
 * handled by recursing on each entry value and on the final default value.
 *
 * @param n a LAMBDA node
 * @param retType the return type of the outermost lambda; it is used as the
 *        innermost constituent type of every array type built here
 * @return the array term, or the null node if n is not in a convertible form
 *         (non-equality condition, wrongly oriented equality, non-constant
 *         index, failed recursion, or non-constant default value)
 *
 * NOTE(review): entry values taken from ITE branches are not checked for
 * constness when n has a single bound variable, so the returned store chain
 * may contain non-constant values; callers presumably check isConst() on the
 * result.
 */
Node FunctionConst::getArrayRepresentationForLambdaRec(TNode n,
                                                       TypeNode retType)
{
  Assert(n.getKind() == kind::LAMBDA);
  NodeManager* nm = NodeManager::currentNM();
  Trace("builtin-rewrite-debug")
      << "Get array representation for : " << n << std::endl;

  // first_arg indexes the outermost array level; the remaining bound
  // variables (if any) form rec_bvl and are handled recursively
  Node first_arg = n[0][0];
  Node rec_bvl;
  size_t size = n[0].getNumChildren();
  if (size > 1)
  {
    std::vector<Node> args;
    for (size_t i = 1; i < size; i++)
    {
      args.push_back(n[0][i]);
    }
    rec_bvl = nm->mkNode(kind::BOUND_VAR_LIST, args);
  }

  Trace("builtin-rewrite-debug2") << " process body..." << std::endl;
  // conds[i] / vals[i] is the i-th (index, value) entry found, outermost first
  std::vector<Node> conds;
  std::vector<Node> vals;
  Node curr = n[1];
  Kind ck = curr.getKind();
  while (ck == kind::ITE || ck == kind::OR || ck == kind::AND
         || ck == kind::EQUAL || ck == kind::NOT || ck == kind::BOUND_VARIABLE)
  {
    Node index_eq;
    Node curr_val;
    Node next;
    // Each iteration of this loop infers an entry in the function, i.e. it
    // has a value under some condition.

    // [1] We infer that the entry has value "curr_val" under condition
    // "index_eq". We set "next" to the node that is the remainder of the
    // function to process.
    if (ck == kind::ITE)
    {
      Trace("builtin-rewrite-debug2")
          << " process condition : " << curr[0] << std::endl;
      index_eq = curr[0];
      curr_val = curr[1];
      next = curr[2];
    }
    else if (ck == kind::OR || ck == kind::AND)
    {
      Trace("builtin-rewrite-debug2")
          << " process base : " << curr << std::endl;
      // Complex Boolean return cases, in which
      // (1) lambda x. (= x v1) v ... becomes
      //     lambda x. (ite (= x v1) true [...])
      //
      // (2) lambda x. (not (= x v1)) ^ ... becomes
      //     lambda x. (ite (= x v1) false [...])
      //
      // Note the negated cases of the lhs of the OR/AND operators above are
      // handled by pushing the recursion to the then-branch, with the
      // else-branch being the constant value. For example, the negated (1)
      // would be
      // (1') lambda x. (not (= x v1)) v ... becomes
      //      lambda x. (ite (= x v1) [...] true)
      // thus requiring the rest of the disjunction to be further processed in
      // the then-branch as the current value.
      bool pol = curr[0].getKind() != kind::NOT;
      bool inverted = (pol && ck == kind::AND) || (!pol && ck == kind::OR);
      index_eq = pol ? curr[0] : curr[0][0];
      // processed : the value that is determined by the first child of curr
      // remainder : the remaining children of curr
      Node processed, remainder;
      // the value is the polarity of the first child or its inverse if we are
      // in the inverted case
      processed = nm->mkConst(!inverted ? pol : !pol);
      // build an OR/AND with the remaining components
      if (curr.getNumChildren() == 2)
      {
        remainder = curr[1];
      }
      else
      {
        // parentheses: iterator-range construction, not an initializer list
        std::vector<Node> remainderNodes(curr.begin() + 1, curr.end());
        remainder = nm->mkNode(ck, remainderNodes);
      }
      if (inverted)
      {
        curr_val = remainder;
        next = processed;
        // If the lambda contains more variables than the one being currently
        // processed, the current value can be non-constant, since it'll be
        // processed recursively below. Otherwise we fail.
        if (rec_bvl.isNull() && !curr_val.isConst())
        {
          Trace("builtin-rewrite-debug2")
              << "...non-const curr_val " << curr_val << "\n";
          return Node::null();
        }
      }
      else
      {
        curr_val = processed;
        next = remainder;
      }
      Trace("builtin-rewrite-debug2") << " index_eq : " << index_eq << "\n";
      Trace("builtin-rewrite-debug2") << " curr_val : " << curr_val << "\n";
      Trace("builtin-rewrite-debug2") << " next : " << next << std::endl;
    }
    else
    {
      Trace("builtin-rewrite-debug2")
          << " process base : " << curr << std::endl;
      // Simple Boolean return cases, in which
      // (1) lambda x. (= x v) becomes lambda x. (ite (= x v) true false)
      // (2) lambda x. x (x Boolean) becomes
      //     lambda x. (ite (= x true) true false)
      // Note the negated cases of the bodies above are also handled.
      bool pol = ck != kind::NOT;
      index_eq = pol ? curr : curr[0];
      curr_val = nm->mkConst(pol);
      next = nm->mkConst(!pol);
    }

    // [2] We ensure that "index_eq" is an equality, if possible.
    if (index_eq.getKind() != kind::EQUAL)
    {
      bool pol = index_eq.getKind() != kind::NOT;
      Node indexEqAtom = pol ? index_eq : index_eq[0];
      if (indexEqAtom.getKind() == kind::BOUND_VARIABLE)
      {
        if (!indexEqAtom.getType().isBoolean())
        {
          // Catches default case of non-Boolean variable, e.g.
          // lambda x : Int. x. In this case, it is not canonical and we fail.
          Trace("builtin-rewrite-debug2")
              << " ...non-Boolean variable." << std::endl;
          return Node::null();
        }
        // Boolean argument case, e.g. lambda x. ite( x, t, s ) is processed as
        // lambda x. (ite (= x true) t s)
        index_eq = indexEqAtom.eqNode(nm->mkConst(pol));
      }
      else
      {
        // non-equality condition
        Trace("builtin-rewrite-debug2")
            << " ...non-equality condition." << std::endl;
        return Node::null();
      }
    }
    else if (Rewriter::rewrite(index_eq) != index_eq)
    {
      // equality must be oriented correctly based on rewriter
      Trace("builtin-rewrite-debug2")
          << " ...equality not oriented properly." << std::endl;
      return Node::null();
    }

    // [3] We ensure that "index_eq" is an equality that is equivalent to
    // "first_arg" = "curr_index", where curr_index is a constant, and
    // "first_arg" is the current argument we are processing, if possible.
    Node curr_index;
    for (unsigned r = 0; r < 2; r++)
    {
      Node arg = index_eq[r];
      Node val = index_eq[1 - r];
      if (arg == first_arg)
      {
        if (!val.isConst())
        {
          // non-constant value
          Trace("builtin-rewrite-debug2")
              << " ...non-constant value for argument." << std::endl;
          return Node::null();
        }
        curr_index = val;
        Trace("builtin-rewrite-debug2")
            << " arg " << arg << " -> " << val << std::endl;
        break;
      }
    }
    if (curr_index.isNull())
    {
      Trace("builtin-rewrite-debug2")
          << " ...could not infer index value." << std::endl;
      return Node::null();
    }

    // [4] Recurse to ensure that "curr_val" has been normalized w.r.t. the
    // remaining arguments (rec_bvl).
    if (!rec_bvl.isNull())
    {
      curr_val = nm->mkNode(kind::LAMBDA, rec_bvl, curr_val);
      Trace("builtin-rewrite-debug") << push;
      Trace("builtin-rewrite-debug2") << push;
      curr_val = getArrayRepresentationForLambdaRec(curr_val, retType);
      Trace("builtin-rewrite-debug") << pop;
      Trace("builtin-rewrite-debug2") << pop;
      if (curr_val.isNull())
      {
        Trace("builtin-rewrite-debug2")
            << " ...failed to recursively find value." << std::endl;
        return Node::null();
      }
    }
    Trace("builtin-rewrite-debug2") << " ...condition is index " << curr_index
                                    << " with value " << curr_val << std::endl;

    // [5] Add the entry
    conds.push_back(curr_index);
    vals.push_back(curr_val);

    // we will now process the remainder
    curr = next;
    ck = curr.getKind();
    Trace("builtin-rewrite-debug2")
        << " process remainder : " << curr << std::endl;
  }
  // whatever is left is the default value; normalize it w.r.t. the remaining
  // arguments as well
  if (!rec_bvl.isNull())
  {
    curr = nm->mkNode(kind::LAMBDA, rec_bvl, curr);
    Trace("builtin-rewrite-debug") << push;
    Trace("builtin-rewrite-debug2") << push;
    curr = getArrayRepresentationForLambdaRec(curr, retType);
    Trace("builtin-rewrite-debug") << pop;
    Trace("builtin-rewrite-debug2") << pop;
  }
  if (!curr.isNull() && curr.isConst())
  {
    // compute the return type
    TypeNode array_type = retType;
    for (size_t i = 0; i < size; i++)
    {
      size_t index = (size - 1) - i;
      array_type = nm->mkArrayType(n[0][index].getType(), array_type);
    }
    Trace("builtin-rewrite-debug2")
        << " make array store all " << curr.getType()
        << " annotated : " << array_type << std::endl;
    Assert(curr.getType().isSubtypeOf(array_type.getArrayConstituentType()));
    curr = nm->mkConst(ArrayStoreAll(array_type, curr));
    Trace("builtin-rewrite-debug2") << " build array..." << std::endl;
    // can only build if default value is constant (since array store all must
    // be constant)
    Trace("builtin-rewrite-debug2")
        << " got constant base " << curr << std::endl;
    Trace("builtin-rewrite-debug2") << " conditions " << conds << std::endl;
    Trace("builtin-rewrite-debug2") << " values " << vals << std::endl;
    // construct store chain; the last entry is stored first so that the
    // first entry found (highest priority) ends up outermost
    for (size_t i = 0, numCond = conds.size(); i < numCond; i++)
    {
      size_t ii = (numCond - 1) - i;
      Assert(conds[ii].getType().isSubtypeOf(first_arg.getType()));
      curr = nm->mkNode(kind::STORE, curr, conds[ii], vals[ii]);
    }
    Trace("builtin-rewrite-debug")
        << "...got array " << curr << " for " << n << std::endl;
    return curr;
  }
  Trace("builtin-rewrite-debug")
      << "...failed to get array (cannot get constant default value)"
      << std::endl;
  return Node::null();
}
394 |
|
|
395 |
6925 |
/**
 * Computes the array representation of lambda n, rewritten to canonical
 * form.
 *
 * @param n a LAMBDA node
 * @return the rewritten array term, or the null node if n has no array
 *         representation
 */
Node FunctionConst::getArrayRepresentationForLambda(TNode n)
{
  Assert(n.getKind() == kind::LAMBDA);
  // The overall return type is carried through the recursion to deal with
  // cases like (lambda ((x Int) (y Int)) (ite (= x _) 0.5 0.0)), where the
  // inner construction for the else case should be
  // (arraystoreall (Array Int Real) 0.0).
  Node anode = getArrayRepresentationForLambdaRec(n, n[1].getType());
  // a successful result is rewritten to make it canonical
  return anode.isNull() ? anode : Rewriter::rewrite(anode);
}
409 |
|
|
410 |
|
} // namespace uf |
411 |
|
} // namespace theory |
412 |
31140 |
} // namespace cvc5 |