1 |
|
/****************************************************************************** |
2 |
|
* Top contributors (to current version): |
3 |
|
* Mudathir Mohamed, Aina Niemetz |
4 |
|
* |
5 |
|
* This file is part of the cvc5 project. |
6 |
|
* |
7 |
|
* Copyright (c) 2009-2021 by the authors listed in the file AUTHORS |
8 |
|
* in the top-level source directory and their institutional affiliations. |
9 |
|
* All rights reserved. See the file COPYING in the top-level source |
10 |
|
* directory for licensing information. |
11 |
|
* **************************************************************************** |
12 |
|
* |
13 |
|
* Normal form for bag constants. |
14 |
|
*/ |
15 |
|
#include "normal_form.h" |
16 |
|
|
17 |
|
#include "expr/emptybag.h" |
18 |
|
#include "theory/sets/normal_form.h" |
19 |
|
#include "theory/type_enumerator.h" |
20 |
|
#include "util/rational.h" |
21 |
|
|
22 |
|
using namespace cvc5::kind; |
23 |
|
|
24 |
|
namespace cvc5 { |
25 |
|
namespace theory { |
26 |
|
namespace bags { |
27 |
|
|
28 |
121 |
/**
 * Checks whether n is a bag constant in normal form. Accepted shapes are:
 *  - an EMPTYBAG node;
 *  - an MK_BAG node that is constant (see MkBagTypeRule::computeIsConst);
 *  - a right-nested chain (union_disjoint b1 (union_disjoint b2 ... bk))
 *    where every bi is a constant MK_BAG node and the elements of
 *    b1, ..., bk are strictly increasing with respect to Node ordering.
 * Nodes of any other kind are never constants.
 */
bool NormalForm::isConstant(TNode n)
{
  // true iff bag is an MK_BAG node that is itself constant
  auto isConstantMkBag = [](TNode bag) {
    return bag.getKind() == MK_BAG && bag.isConst();
  };

  switch (n.getKind())
  {
    case EMPTYBAG:
      // empty bags are already normalized
      return true;
    case MK_BAG:
      // see the implementation in MkBagTypeRule::computeIsConst
      return n.isConst();
    case UNION_DISJOINT:
    {
      if (!isConstantMkBag(n[0]))
      {
        return false;
      }
      // walk down the chain, requiring strictly increasing elements
      Node previous = n[0][0];
      Node rest = n[1];
      while (rest.getKind() == UNION_DISJOINT)
      {
        Node head = rest[0];
        if (!isConstantMkBag(head) || previous >= head[0])
        {
          return false;
        }
        previous = head[0];
        rest = rest[1];
      }
      // the chain must end in a constant MK_BAG holding the largest element
      return isConstantMkBag(rest) && !(previous >= rest[0]);
    }
    default:
      // only EMPTYBAG, MK_BAG, and UNION_DISJOINT nodes can be constants
      return false;
  }
}
83 |
|
|
84 |
1514 |
/** Returns true iff every child of n is a constant node. */
bool NormalForm::areChildrenConstants(TNode n)
{
  for (TNode child : n)
  {
    if (!child.isConst())
    {
      return false;
    }
  }
  return true;
}
88 |
|
|
89 |
150 |
/**
 * Returns the normal form of the bag term n, all of whose children must be
 * constants. A constant n is returned unchanged; each other supported kind
 * is dispatched to its dedicated evaluation routine, and any remaining kind
 * is reported through Unhandled().
 */
Node NormalForm::evaluate(TNode n)
{
  Assert(areChildrenConstants(n));
  if (n.isConst())
  {
    // a constant node is already in a normal form
    return n;
  }

  switch (n.getKind())
  {
    // bag construction
    case MK_BAG: return evaluateMakeBag(n);
    // binary bag operators
    case UNION_DISJOINT: return evaluateUnionDisjoint(n);
    case UNION_MAX: return evaluateUnionMax(n);
    case INTERSECTION_MIN: return evaluateIntersectionMin(n);
    case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
    case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
    // unary bag operators
    case DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
    case BAG_MAP: return evaluateBagMap(n);
    // queries on a bag
    case BAG_COUNT: return evaluateBagCount(n);
    case BAG_CHOOSE: return evaluateChoose(n);
    case BAG_CARD: return evaluateCard(n);
    case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
    // conversions between sets and bags
    case BAG_FROM_SET: return evaluateFromSet(n);
    case BAG_TO_SET: return evaluateToSet(n);
    default: break;
  }
  Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
              << std::endl;
}
118 |
|
|
119 |
|
template <typename T1, typename T2, typename T3, typename T4, typename T5> |
120 |
34 |
Node NormalForm::evaluateBinaryOperation(const TNode& n, |
121 |
|
T1&& equal, |
122 |
|
T2&& less, |
123 |
|
T3&& greaterOrEqual, |
124 |
|
T4&& remainderOfA, |
125 |
|
T5&& remainderOfB) |
126 |
|
{ |
127 |
68 |
std::map<Node, Rational> elementsA = getBagElements(n[0]); |
128 |
68 |
std::map<Node, Rational> elementsB = getBagElements(n[1]); |
129 |
68 |
std::map<Node, Rational> elements; |
130 |
|
|
131 |
34 |
std::map<Node, Rational>::const_iterator itA = elementsA.begin(); |
132 |
34 |
std::map<Node, Rational>::const_iterator itB = elementsB.begin(); |
133 |
|
|
134 |
102 |
Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation " |
135 |
68 |
<< n.getKind() << "] " << std::endl |
136 |
34 |
<< "elements A: " << elementsA << std::endl |
137 |
34 |
<< "elements B: " << elementsB << std::endl; |
138 |
|
|
139 |
118 |
while (itA != elementsA.end() && itB != elementsB.end()) |
140 |
|
{ |
141 |
42 |
if (itA->first == itB->first) |
142 |
|
{ |
143 |
22 |
equal(elements, itA, itB); |
144 |
22 |
itA++; |
145 |
22 |
itB++; |
146 |
|
} |
147 |
20 |
else if (itA->first < itB->first) |
148 |
|
{ |
149 |
6 |
less(elements, itA, itB); |
150 |
6 |
itA++; |
151 |
|
} |
152 |
|
else |
153 |
|
{ |
154 |
14 |
greaterOrEqual(elements, itA, itB); |
155 |
14 |
itB++; |
156 |
|
} |
157 |
|
} |
158 |
|
|
159 |
|
// handle the remaining elements from A |
160 |
34 |
remainderOfA(elements, elementsA, itA); |
161 |
|
// handle the remaining elements from B |
162 |
34 |
remainderOfA(elements, elementsB, itB); |
163 |
|
|
164 |
34 |
Trace("bags-evaluate") << "elements: " << elements << std::endl; |
165 |
34 |
Node bag = constructConstantBagFromElements(n.getType(), elements); |
166 |
34 |
Trace("bags-evaluate") << "bag: " << bag << std::endl; |
167 |
68 |
return bag; |
168 |
|
} |
169 |
|
|
170 |
149 |
/**
 * Returns the element -> multiplicity map of the constant bag n, which must
 * be in normal form. An EMPTYBAG yields an empty map; otherwise the MK_BAG
 * leaves of the right-nested UNION_DISJOINT chain are collected.
 */
std::map<Node, Rational> NormalForm::getBagElements(TNode n)
{
  Assert(n.isConst()) << "node " << n << " is not in a normal form"
                      << std::endl;
  std::map<Node, Rational> result;
  if (n.getKind() != EMPTYBAG)
  {
    TNode node = n;
    for (; node.getKind() == kind::UNION_DISJOINT; node = node[1])
    {
      TNode single = node[0];
      Assert(single.getKind() == kind::MK_BAG);
      result[single[0]] = single[1].getConst<Rational>();
    }
    // the chain always terminates with an MK_BAG node
    Assert(node.getKind() == kind::MK_BAG);
    result[node[0]] = node[1].getConst<Rational>();
  }
  return result;
}
193 |
|
|
194 |
51 |
/**
 * Builds a bag of type t from an element -> multiplicity map. An empty map
 * gives the empty bag. Otherwise the result is a right-nested UNION_DISJOINT
 * chain of MK_BAG nodes in ascending element order; it is assembled from the
 * largest element backwards so that each new MK_BAG is prepended.
 */
Node NormalForm::constructConstantBagFromElements(
    TypeNode t, const std::map<Node, Rational>& elements)
{
  Assert(t.isBag());
  NodeManager* nm = NodeManager::currentNM();
  if (elements.empty())
  {
    return nm->mkConst(EmptyBag(t));
  }
  TypeNode elementType = t.getBagElementType();
  Node result;
  for (auto it = elements.rbegin(); it != elements.rend(); ++it)
  {
    Node single =
        nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
    result = result.isNull() ? single
                             : nm->mkNode(UNION_DISJOINT, single, result);
  }
  return result;
}
215 |
|
|
216 |
30 |
/**
 * Builds a bag of type t from an element -> multiplicity-term map. The
 * multiplicities are arbitrary nodes, so the result need not be constant.
 * An empty map gives the empty bag; otherwise a right-nested UNION_DISJOINT
 * chain of MK_BAG nodes in ascending element order is returned.
 */
Node NormalForm::constructBagFromElements(TypeNode t,
                                          const std::map<Node, Node>& elements)
{
  Assert(t.isBag());
  NodeManager* nm = NodeManager::currentNM();
  if (elements.empty())
  {
    return nm->mkConst(EmptyBag(t));
  }
  TypeNode elementType = t.getBagElementType();
  Node result;
  for (auto it = elements.rbegin(); it != elements.rend(); ++it)
  {
    Node single = nm->mkBag(elementType, it->first, it->second);
    result = result.isNull() ? single
                             : nm->mkNode(UNION_DISJOINT, single, result);
  }
  return result;
}
235 |
|
|
236 |
11 |
/**
 * Evaluates a non-constant MK_BAG node. evaluate() returns constant nodes
 * before dispatching here, so the only case left is a multiplicity that is
 * zero or negative, which denotes the empty bag of n's type.
 */
Node NormalForm::evaluateMakeBag(TNode n)
{
  Assert(n.getKind() == MK_BAG && !n.isConst()
         && n[1].getConst<Rational>().sgn() < 1);
  return NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
}
245 |
|
|
246 |
58 |
/**
 * Evaluates (bag.count e B) for a constant bag B: the multiplicity of e in
 * B, or 0 when e does not occur in B.
 *
 * Examples
 * --------
 * - (bag.count "x" (emptybag String)) = 0
 * - (bag.count "x" (mkBag "y" 5)) = 0
 * - (bag.count "x" (mkBag "x" 4)) = 4
 * - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
 * - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0
 */
Node NormalForm::evaluateBagCount(TNode n)
{
  Assert(n.getKind() == BAG_COUNT);
  const std::map<Node, Rational> elements = getBagElements(n[1]);
  const auto found = elements.find(n[0]);
  NodeManager* nm = NodeManager::currentNM();
  return nm->mkConst(found == elements.end() ? Rational(0) : found->second);
}
268 |
|
|
269 |
7 |
/**
 * Evaluates (duplicate_removal B) for a constant bag B: every element of B
 * is kept, with its multiplicity reset to one.
 */
Node NormalForm::evaluateDuplicateRemoval(TNode n)
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);

  // Examples
  // --------
  // - (duplicate_removal (emptybag String)) = (emptybag String)
  // - (duplicate_removal (mkBag "x" 4)) = (mkBag "x" 1)
  // - (duplicate_removal (disjoint_union (mkBag "x" 3) (mkBag "y" 5))) =
  //     (disjoint_union (mkBag "x" 1) (mkBag "y" 1))

  // the map returned by getBagElements is our own copy, so it is updated in
  // place rather than copied a second time
  std::map<Node, Rational> elements = getBagElements(n[0]);
  const Rational one(1);
  for (std::pair<const Node, Rational>& element : elements)
  {
    element.second = one;
  }
  return constructConstantBagFromElements(n[0].getType(), elements);
}
292 |
|
|
293 |
20 |
/**
 * Evaluates (union_disjoint A B) for constant bags A and B: the multiplicity
 * of each element is the sum of its multiplicities in A and B.
 *
 * Example
 * -------
 * input:  (union_disjoint A B)
 *    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))
 *          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))
 * output: the normal form of the bag {x: 7, y: 1, z: 2}, i.e. a chain of
 *         MK_BAG nodes ordered like the keys of std::map<Node, Rational>.
 */
Node NormalForm::evaluateUnionDisjoint(TNode n)
{
  Assert(n.getKind() == UNION_DISJOINT);

  using ElementMap = std::map<Node, Rational>;
  using ElementIt = ElementMap::const_iterator;

  // element occurs in both bags: add the multiplicities
  auto addCounts = [](ElementMap& result, ElementIt& itA, ElementIt& itB) {
    result[itA->first] = itA->second + itB->second;
  };
  // element occurs only in A: copy it
  auto takeFromA = [](ElementMap& result, ElementIt& itA, ElementIt&) {
    result[itA->first] = itA->second;
  };
  // element occurs only in B: copy it
  auto takeFromB = [](ElementMap& result, ElementIt&, ElementIt& itB) {
    result[itB->first] = itB->second;
  };
  // copy every element that is left in either bag
  auto copyRest = [](ElementMap& result, ElementMap& source, ElementIt& it) {
    for (; it != source.end(); ++it)
    {
      result[it->first] = it->second;
    }
  };

  return evaluateBinaryOperation(
      n, addCounts, takeFromA, takeFromB, copyRest, copyRest);
}
352 |
|
|
353 |
6 |
/**
 * Evaluates (union_max A B) for constant bags A and B: the multiplicity of
 * each element is the larger of its multiplicities in A and B.
 *
 * Example
 * -------
 * input:  (union_max A B)
 *    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))
 *          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))
 * output: the normal form of the bag {x: 4, y: 1, z: 2}.
 */
Node NormalForm::evaluateUnionMax(TNode n)
{
  Assert(n.getKind() == UNION_MAX);

  using ElementMap = std::map<Node, Rational>;
  using ElementIt = ElementMap::const_iterator;

  // element occurs in both bags: keep the larger multiplicity
  auto maxCount = [](ElementMap& result, ElementIt& itA, ElementIt& itB) {
    result[itA->first] = std::max(itA->second, itB->second);
  };
  // element occurs only in A: copy it
  auto takeFromA = [](ElementMap& result, ElementIt& itA, ElementIt&) {
    result[itA->first] = itA->second;
  };
  // element occurs only in B: copy it
  auto takeFromB = [](ElementMap& result, ElementIt&, ElementIt& itB) {
    result[itB->first] = itB->second;
  };
  // copy every element that is left in either bag
  auto copyRest = [](ElementMap& result, ElementMap& source, ElementIt& it) {
    for (; it != source.end(); ++it)
    {
      result[it->first] = it->second;
    }
  };

  return evaluateBinaryOperation(
      n, maxCount, takeFromA, takeFromB, copyRest, copyRest);
}
412 |
|
|
413 |
3 |
/**
 * Evaluates (intersection_min A B) for constant bags A and B: only elements
 * occurring in both bags survive, with the smaller of their multiplicities.
 *
 * Example
 * -------
 * input:  (intersection_min A B)
 *    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))
 *          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))
 * output: (MK_BAG "x" 3)
 */
Node NormalForm::evaluateIntersectionMin(TNode n)
{
  Assert(n.getKind() == INTERSECTION_MIN);

  using ElementMap = std::map<Node, Rational>;
  using ElementIt = ElementMap::const_iterator;

  // element occurs in both bags: keep the smaller multiplicity
  auto minCount = [](ElementMap& result, ElementIt& itA, ElementIt& itB) {
    result[itA->first] = std::min(itA->second, itB->second);
  };
  // elements occurring in only one bag are dropped
  auto skipOne = [](ElementMap&, ElementIt&, ElementIt&) {};
  // leftovers of either bag occur in only that bag, so they are dropped too
  auto skipRest = [](ElementMap&, ElementMap&, ElementIt&) {};

  return evaluateBinaryOperation(
      n, minCount, skipOne, skipOne, skipRest, skipRest);
}
458 |
|
|
459 |
3 |
/**
 * Evaluates (difference_subtract A B) for constant bags A and B: the
 * multiplicity of each element is max(0, count_A - count_B). Elements whose
 * resulting multiplicity is not positive are absent from the result.
 *
 * Examples
 * --------
 * input:  (difference_subtract A B)
 *    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))
 *          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))
 * output: (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))
 *
 * input:  (difference_subtract (MK_BAG "x" 2) (MK_BAG "x" 3))
 * output: the empty bag
 */
Node NormalForm::evaluateDifferenceSubtract(TNode n)
{
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);

  using ElementMap = std::map<Node, Rational>;
  using ElementIt = ElementMap::const_iterator;

  // element occurs in both bags: subtract the multiplicities. Storing a
  // difference <= 0 would create an MK_BAG with a non-positive multiplicity,
  // which is not a constant bag, so such elements are left out instead.
  auto subtract = [](ElementMap& result, ElementIt& itA, ElementIt& itB) {
    if (itA->second > itB->second)
    {
      result[itA->first] = itA->second - itB->second;
    }
  };
  // itA->first is not in B, so it is kept unchanged
  auto keepFromA = [](ElementMap& result, ElementIt& itA, ElementIt&) {
    result[itA->first] = itA->second;
  };
  // itB->first is not in A, so there is nothing to subtract from
  auto skipFromB = [](ElementMap&, ElementIt&, ElementIt&) {};
  // append the remainder of A
  auto copyRestOfA =
      [](ElementMap& result, ElementMap& elementsA, ElementIt& itA) {
        for (; itA != elementsA.end(); ++itA)
        {
          result[itA->first] = itA->second;
        }
      };
  // the remainder of B does not contribute to the result
  auto ignoreRestOfB = [](ElementMap&, ElementMap&, ElementIt&) {};

  return evaluateBinaryOperation(
      n, subtract, keepFromA, skipFromB, copyRestOfA, ignoreRestOfB);
}
510 |
|
|
511 |
2 |
/**
 * Evaluates (difference_remove A B) for constant bags A and B. Every element
 * of A that occurs in B, with any multiplicity, is removed entirely. The
 * remaining elements of A keep their multiplicities.
 *
 * Example
 * -------
 * input:  (difference_remove A B)
 *    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))
 *          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))
 * output: (MK_BAG "z" 2)
 */
Node NormalForm::evaluateDifferenceRemove(TNode n)
{
  Assert(n.getKind() == DIFFERENCE_REMOVE);

  using ElementMap = std::map<Node, Rational>;
  using ElementIt = ElementMap::const_iterator;

  // skip the shared element by doing nothing
  auto dropShared = [](ElementMap&, ElementIt&, ElementIt&) {};
  // itA->first is not in B, so we add it to the difference remove
  auto keepFromA = [](ElementMap& result, ElementIt& itA, ElementIt&) {
    result[itA->first] = itA->second;
  };
  // itB->first is not in A, so we just skip it
  auto skipFromB = [](ElementMap&, ElementIt&, ElementIt&) {};
  // append the remainder of A
  auto copyRestOfA =
      [](ElementMap& result, ElementMap& elementsA, ElementIt& itA) {
        for (; itA != elementsA.end(); ++itA)
        {
          result[itA->first] = itA->second;
        }
      };
  // the remainder of B does not contribute to the result
  auto ignoreRestOfB = [](ElementMap&, ElementMap&, ElementIt&) {};

  return evaluateBinaryOperation(
      n, dropShared, keepFromA, skipFromB, copyRestOfA, ignoreRestOfB);
}
561 |
|
|
562 |
6 |
/**
 * Evaluates (choose B) for a constant bag B, deterministically.
 *
 * Examples
 * --------
 * - (choose (emptyBag String)) = "" : the first value produced by the type
 *   enumerator of the element type
 * - (choose (MK_BAG "x" 4)) = "x"
 * - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x" : the
 *   element of the first MK_BAG in the chain
 */
Node NormalForm::evaluateChoose(TNode n)
{
  Assert(n.getKind() == BAG_CHOOSE);
  TNode bag = n[0];
  switch (bag.getKind())
  {
    case EMPTYBAG:
    {
      TypeEnumerator enumerator(bag.getType().getBagElementType());
      return *enumerator;
    }
    case MK_BAG: return bag[0];
    default:
      Assert(bag.getKind() == UNION_DISJOINT);
      return bag[0][0];
  }
}
591 |
|
|
592 |
6 |
/**
 * Evaluates (card B) for a constant bag B: the sum of the multiplicities of
 * all its elements.
 *
 * Examples
 * --------
 * - (card (emptyBag String)) = 0
 * - (card (MK_BAG "x" 4)) = 4
 * - (card (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5
 */
Node NormalForm::evaluateCard(TNode n)
{
  Assert(n.getKind() == BAG_CARD);

  const std::map<Node, Rational> elements = getBagElements(n[0]);
  Rational sum(0);
  // bind by reference: copying each (Node, Rational) pair is unnecessary
  for (const auto& element : elements)
  {
    sum += element.second;
  }

  NodeManager* nm = NodeManager::currentNM();
  return nm->mkConst(sum);
}
612 |
|
|
613 |
8 |
/**
 * Evaluates (bag.is_singleton B) for a constant bag B. The result is true
 * exactly when B is a single MK_BAG node with multiplicity one.
 *
 * Examples
 * --------
 * - (bag.is_singleton (emptyBag String)) = false
 * - (bag.is_singleton (MK_BAG "x" 1)) = true
 * - (bag.is_singleton (MK_BAG "x" 4)) = false
 * - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false
 */
Node NormalForm::evaluateIsSingleton(TNode n)
{
  Assert(n.getKind() == BAG_IS_SINGLETON);
  TNode bag = n[0];
  const bool singleton =
      bag.getKind() == MK_BAG && bag[1].getConst<Rational>().isOne();
  return NodeManager::currentNM()->mkConst(singleton);
}
629 |
|
|
630 |
6 |
/**
 * Evaluates (bag.from_set S) for a constant set S. Each element of S becomes
 * a bag element with multiplicity one.
 *
 * Examples
 * --------
 * - (bag.from_set (emptyset String)) = (emptybag String)
 * - (bag.from_set (singleton "x")) = (mkBag "x" 1)
 * - (bag.from_set (union (singleton "x") (singleton "y"))) =
 *     (disjoint_union (mkBag "x" 1) (mkBag "y" 1))
 */
Node NormalForm::evaluateFromSet(TNode n)
{
  Assert(n.getKind() == BAG_FROM_SET);

  std::map<Node, Rational> bagElements;
  for (const Node& element :
       sets::NormalForm::getElementsFromNormalConstant(n[0]))
  {
    bagElements[element] = Rational(1);
  }

  NodeManager* nm = NodeManager::currentNM();
  TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
  return constructConstantBagFromElements(bagType, bagElements);
}
654 |
|
|
655 |
6 |
/**
 * Evaluates (bag.to_set B) for a constant bag B. The result is the set of
 * distinct elements of B; multiplicities are discarded.
 *
 * Examples
 * --------
 * - (bag.to_set (emptybag String)) = (emptyset String)
 * - (bag.to_set (mkBag "x" 4)) = (singleton "x")
 * - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5))) =
 *     (union (singleton "x") (singleton "y"))
 */
Node NormalForm::evaluateToSet(TNode n)
{
  Assert(n.getKind() == BAG_TO_SET);

  // std::set orders its contents itself, so the insertion order is irrelevant
  std::set<Node> setElements;
  for (const auto& entry : getBagElements(n[0]))
  {
    setElements.insert(entry.first);
  }

  NodeManager* nm = NodeManager::currentNM();
  TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
  return sets::NormalForm::elementsToSet(setElements, setType);
}
678 |
|
|
679 |
|
|
680 |
2 |
/**
 * Evaluates (bag.map f B) for a constant bag B. Each element e of B is
 * replaced by the (unevaluated) application (f e), keeping e's multiplicity.
 * The result has bag type over f's range type.
 *
 * Examples
 * --------
 * - (bag.map ((lambda ((x String)) "z")
 *            (union_disjoint (bag "a" 2) (bag "b" 3)) =
 *     (union_disjoint
 *       (bag ((lambda ((x String)) "z") "a") 2)
 *       (bag ((lambda ((x String)) "z") "b") 3)) =
 *     (bag "z" 5)
 */
Node NormalForm::evaluateBagMap(TNode n)
{
  Assert(n.getKind() == BAG_MAP);

  NodeManager* nm = NodeManager::currentNM();
  TNode function = n[0];
  std::map<Node, Rational> mappedElements;
  for (const auto& entry : NormalForm::getBagElements(n[1]))
  {
    mappedElements[nm->mkNode(APPLY_UF, function, entry.first)] = entry.second;
  }

  TypeNode bagType = nm->mkBagType(function.getType().getRangeType());
  return NormalForm::constructConstantBagFromElements(bagType, mappedElements);
}
707 |
|
|
708 |
|
} // namespace bags |
709 |
|
} // namespace theory |
710 |
29517 |
} // namespace cvc5 |