1 |
|
/****************************************************************************** |
2 |
|
* Top contributors (to current version): |
3 |
|
* Mudathir Mohamed, Aina Niemetz |
4 |
|
* |
5 |
|
* This file is part of the cvc5 project. |
6 |
|
* |
7 |
|
* Copyright (c) 2009-2021 by the authors listed in the file AUTHORS |
8 |
|
* in the top-level source directory and their institutional affiliations. |
9 |
|
* All rights reserved. See the file COPYING in the top-level source |
10 |
|
* directory for licensing information. |
11 |
|
* **************************************************************************** |
12 |
|
* |
13 |
|
* Normal form for bag constants. |
14 |
|
*/ |
15 |
|
#include "normal_form.h" |
16 |
|
|
17 |
|
#include "theory/sets/normal_form.h" |
18 |
|
#include "theory/type_enumerator.h" |
19 |
|
|
20 |
|
using namespace cvc5::kind; |
21 |
|
|
22 |
|
namespace cvc5 { |
23 |
|
namespace theory { |
24 |
|
namespace bags { |
25 |
|
|
26 |
117 |
bool NormalForm::isConstant(TNode n)
{
  // A bag term is constant (i.e. already in normal form) iff it is one of:
  //  - an EMPTYBAG node;
  //  - an MK_BAG node that is itself constant (MkBagTypeRule::computeIsConst);
  //  - a right-nested chain (union_disjoint b1 (union_disjoint b2 ... bk))
  //    of constant MK_BAG nodes whose elements are strictly increasing.
  const auto k = n.getKind();
  if (k == EMPTYBAG)
  {
    // empty bags are already normalized
    return true;
  }
  if (k == MK_BAG)
  {
    // see the implementation in MkBagTypeRule::computeIsConst
    return n.isConst();
  }
  if (k != UNION_DISJOINT)
  {
    // only nodes with kinds EMPTY_BAG, MK_BAG, and UNION_DISJOINT can be
    // constants
    return false;
  }

  // true iff b is an MK_BAG node that is a constant
  auto isConstantMkBag = [](TNode b) {
    return b.getKind() == kind::MK_BAG && b.isConst();
  };

  if (!isConstantMkBag(n[0]))
  {
    return false;
  }
  // Walk the right spine of the union. Every leading child must be a constant
  // MK_BAG whose element is strictly greater than the previous element.
  Node previous = n[0][0];
  Node rest = n[1];
  for (; rest.getKind() == UNION_DISJOINT; rest = rest[1])
  {
    Node next = rest[0];
    if (!isConstantMkBag(next) || previous >= next[0])
    {
      return false;
    }
    previous = next[0];
  }
  // The chain must end in a constant MK_BAG whose element is still larger.
  return isConstantMkBag(rest) && !(previous >= rest[0]);
}
81 |
|
|
82 |
1509 |
bool NormalForm::areChildrenConstants(TNode n)
{
  // Returns true iff every direct child of n is a constant term. This is
  // vacuously true for a node without children.
  for (TNode child : n)
  {
    if (!child.isConst())
    {
      return false;
    }
  }
  return true;
}
86 |
|
|
87 |
146 |
Node NormalForm::evaluate(TNode n)
{
  // Precondition: every child of n is a constant bag/term.
  Assert(areChildrenConstants(n));
  if (n.isConst())
  {
    // a constant node is already in a normal form
    return n;
  }
  // Dispatch to the evaluation routine that matches the operator kind.
  switch (n.getKind())
  {
    case MK_BAG: return evaluateMakeBag(n);
    case BAG_COUNT: return evaluateBagCount(n);
    case DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
    case UNION_DISJOINT: return evaluateUnionDisjoint(n);
    case UNION_MAX: return evaluateUnionMax(n);
    case INTERSECTION_MIN: return evaluateIntersectionMin(n);
    case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
    case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
    case BAG_CHOOSE: return evaluateChoose(n);
    case BAG_CARD: return evaluateCard(n);
    case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
    case BAG_FROM_SET: return evaluateFromSet(n);
    case BAG_TO_SET: return evaluateToSet(n);
    default:
      // any other kind is not a bag operator this routine knows about
      Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node "
                  << n << std::endl;
  }
}
115 |
|
|
116 |
|
template <typename T1, typename T2, typename T3, typename T4, typename T5> |
117 |
32 |
Node NormalForm::evaluateBinaryOperation(const TNode& n, |
118 |
|
T1&& equal, |
119 |
|
T2&& less, |
120 |
|
T3&& greaterOrEqual, |
121 |
|
T4&& remainderOfA, |
122 |
|
T5&& remainderOfB) |
123 |
|
{ |
124 |
64 |
std::map<Node, Rational> elementsA = getBagElements(n[0]); |
125 |
64 |
std::map<Node, Rational> elementsB = getBagElements(n[1]); |
126 |
64 |
std::map<Node, Rational> elements; |
127 |
|
|
128 |
32 |
std::map<Node, Rational>::const_iterator itA = elementsA.begin(); |
129 |
32 |
std::map<Node, Rational>::const_iterator itB = elementsB.begin(); |
130 |
|
|
131 |
96 |
Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation " |
132 |
64 |
<< n.getKind() << "] " << std::endl |
133 |
32 |
<< "elements A: " << elementsA << std::endl |
134 |
32 |
<< "elements B: " << elementsB << std::endl; |
135 |
|
|
136 |
112 |
while (itA != elementsA.end() && itB != elementsB.end()) |
137 |
|
{ |
138 |
40 |
if (itA->first == itB->first) |
139 |
|
{ |
140 |
20 |
equal(elements, itA, itB); |
141 |
20 |
itA++; |
142 |
20 |
itB++; |
143 |
|
} |
144 |
20 |
else if (itA->first < itB->first) |
145 |
|
{ |
146 |
6 |
less(elements, itA, itB); |
147 |
6 |
itA++; |
148 |
|
} |
149 |
|
else |
150 |
|
{ |
151 |
14 |
greaterOrEqual(elements, itA, itB); |
152 |
14 |
itB++; |
153 |
|
} |
154 |
|
} |
155 |
|
|
156 |
|
// handle the remaining elements from A |
157 |
32 |
remainderOfA(elements, elementsA, itA); |
158 |
|
// handle the remaining elements from B |
159 |
32 |
remainderOfA(elements, elementsB, itB); |
160 |
|
|
161 |
32 |
Trace("bags-evaluate") << "elements: " << elements << std::endl; |
162 |
32 |
Node bag = constructConstantBagFromElements(n.getType(), elements); |
163 |
32 |
Trace("bags-evaluate") << "bag: " << bag << std::endl; |
164 |
64 |
return bag; |
165 |
|
} |
166 |
|
|
167 |
141 |
std::map<Node, Rational> NormalForm::getBagElements(TNode n)
{
  Assert(n.isConst()) << "node " << n << " is not in a normal form"
                      << std::endl;
  // Collects the (element -> multiplicity) pairs of a constant bag. A
  // non-empty constant bag is a right-nested chain of UNION_DISJOINT nodes
  // whose left children, as well as the final right child, are MK_BAG nodes.
  std::map<Node, Rational> result;
  if (n.getKind() != EMPTYBAG)
  {
    TNode node = n;
    for (; node.getKind() == kind::UNION_DISJOINT; node = node[1])
    {
      TNode single = node[0];
      Assert(single.getKind() == kind::MK_BAG);
      result[single[0]] = single[1].getConst<Rational>();
    }
    // the last link of the chain (or the whole bag) is a single MK_BAG
    Assert(node.getKind() == kind::MK_BAG);
    result[node[0]] = node[1].getConst<Rational>();
  }
  return result;
}
190 |
|
|
191 |
45 |
Node NormalForm::constructConstantBagFromElements(
    TypeNode t, const std::map<Node, Rational>& elements)
{
  // Builds the normal-form bag of type t from (element -> multiplicity)
  // pairs. An empty map yields the empty bag.
  Assert(t.isBag());
  NodeManager* nm = NodeManager::currentNM();
  if (elements.empty())
  {
    return nm->mkConst(EmptyBag(t));
  }
  // Traverse from the largest element to the smallest. The first MK_BAG
  // becomes the innermost right child, and each smaller element is wrapped
  // around it with UNION_DISJOINT, so elements read left-to-right in map
  // order.
  TypeNode elementType = t.getBagElementType();
  Node bag;
  for (auto it = elements.rbegin(); it != elements.rend(); ++it)
  {
    Node single =
        nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
    bag = bag.isNull() ? single : nm->mkNode(UNION_DISJOINT, single, bag);
  }
  return bag;
}
212 |
|
|
213 |
30 |
Node NormalForm::constructBagFromElements(TypeNode t,
                                          const std::map<Node, Node>& elements)
{
  // Builds a bag of type t from (element -> multiplicity term) pairs. The
  // multiplicities are arbitrary nodes, so the result need not be a constant.
  // An empty map yields the empty bag.
  Assert(t.isBag());
  NodeManager* nm = NodeManager::currentNM();
  if (elements.empty())
  {
    return nm->mkConst(EmptyBag(t));
  }
  // Nest from the largest element outwards so that elements read
  // left-to-right in map order.
  TypeNode elementType = t.getBagElementType();
  Node bag;
  for (auto it = elements.rbegin(); it != elements.rend(); ++it)
  {
    Node single = nm->mkBag(elementType, it->first, it->second);
    bag = bag.isNull() ? single : nm->mkNode(UNION_DISJOINT, single, bag);
  }
  return bag;
}
232 |
|
|
233 |
11 |
Node NormalForm::evaluateMakeBag(TNode n)
{
  // Constant MK_BAG nodes are returned unchanged by evaluate() before this
  // routine is reached. The only remaining case is a zero or negative
  // multiplicity, which evaluates to the empty bag of n's type.
  Assert(n.getKind() == MK_BAG && !n.isConst()
         && n[1].getConst<Rational>().sgn() < 1);
  return NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
}
242 |
|
|
243 |
58 |
Node NormalForm::evaluateBagCount(TNode n)
{
  Assert(n.getKind() == BAG_COUNT);
  // Examples
  // --------
  // - (bag.count "x" (emptybag String)) = 0
  // - (bag.count "x" (mkBag "y" 5)) = 0
  // - (bag.count "x" (mkBag "x" 4)) = 4
  // - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
  // - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0

  // The multiplicity of n[0] in the constant bag n[1], or 0 if absent.
  const std::map<Node, Rational> elements = getBagElements(n[1]);
  const auto found = elements.find(n[0]);
  const Rational count =
      (found == elements.end()) ? Rational(0) : found->second;
  return NodeManager::currentNM()->mkConst(count);
}
265 |
|
|
266 |
7 |
Node NormalForm::evaluateDuplicateRemoval(TNode n)
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);

  // Examples
  // --------
  // - (duplicate_removal (emptybag String)) = (emptybag String)
  // - (duplicate_removal (mkBag "x" 4)) = (mkBag "x" 1)
  // - (duplicate_removal (union_disjoint (mkBag "x" 3) (mkBag "y" 5)) =
  //     (union_disjoint (mkBag "x" 1) (mkBag "y" 1))

  // Every element of the constant bag n[0] is kept with multiplicity one.
  // The map returned by getBagElements is already a fresh copy, so it is
  // updated in place instead of being copied a second time.
  std::map<Node, Rational> elements = getBagElements(n[0]);
  const Rational one(1);
  for (auto& element : elements)
  {
    element.second = one;
  }
  return constructConstantBagFromElements(n[0].getType(), elements);
}
289 |
|
|
290 |
18 |
Node NormalForm::evaluateUnionDisjoint(TNode n)
{
  Assert(n.getKind() == UNION_DISJOINT);
  // Example
  // -------
  // input: (union_disjoint A B)
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
  // output:
  //    (union_disjoint A B)
  //        where A = (MK_BAG "x" 7)
  //              B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))

  using Elements = std::map<Node, Rational>;
  using Iterator = Elements::const_iterator;

  // shared element: the multiplicities are added
  auto equal = [](Elements& result, Iterator& itA, Iterator& itB) {
    result[itA->first] = itA->second + itB->second;
  };
  // element only in A: copied with its multiplicity
  auto less = [](Elements& result, Iterator& itA, Iterator&) {
    result[itA->first] = itA->second;
  };
  // element only in B: copied with its multiplicity
  auto greaterOrEqual = [](Elements& result, Iterator&, Iterator& itB) {
    result[itB->first] = itB->second;
  };
  // leftovers of A are copied unchanged
  auto remainderOfA = [](Elements& result, Elements& elementsA, Iterator& itA) {
    for (; itA != elementsA.end(); ++itA)
    {
      result[itA->first] = itA->second;
    }
  };
  // leftovers of B are copied unchanged
  auto remainderOfB = [](Elements& result, Elements& elementsB, Iterator& itB) {
    for (; itB != elementsB.end(); ++itB)
    {
      result[itB->first] = itB->second;
    }
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}
349 |
|
|
350 |
6 |
Node NormalForm::evaluateUnionMax(TNode n)
{
  Assert(n.getKind() == UNION_MAX);
  // Example
  // -------
  // input: (union_max A B)
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
  // output:
  //    (union_disjoint A B)
  //        where A = (MK_BAG "x" 4)
  //              B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))

  using Elements = std::map<Node, Rational>;
  using Iterator = Elements::const_iterator;

  // shared element: keep the larger multiplicity
  auto equal = [](Elements& result, Iterator& itA, Iterator& itB) {
    result[itA->first] = std::max(itA->second, itB->second);
  };
  // element only in A: copied with its multiplicity
  auto less = [](Elements& result, Iterator& itA, Iterator&) {
    result[itA->first] = itA->second;
  };
  // element only in B: copied with its multiplicity
  auto greaterOrEqual = [](Elements& result, Iterator&, Iterator& itB) {
    result[itB->first] = itB->second;
  };
  // leftovers of A are copied unchanged
  auto remainderOfA = [](Elements& result, Elements& elementsA, Iterator& itA) {
    for (; itA != elementsA.end(); ++itA)
    {
      result[itA->first] = itA->second;
    }
  };
  // leftovers of B are copied unchanged
  auto remainderOfB = [](Elements& result, Elements& elementsB, Iterator& itB) {
    for (; itB != elementsB.end(); ++itB)
    {
      result[itB->first] = itB->second;
    }
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}
409 |
|
|
410 |
3 |
Node NormalForm::evaluateIntersectionMin(TNode n)
{
  Assert(n.getKind() == INTERSECTION_MIN);
  // Example
  // -------
  // input: (intersectionMin A B)
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
  // output:
  //    (MK_BAG "x" 3)

  using Elements = std::map<Node, Rational>;
  using Iterator = Elements::const_iterator;

  // Only shared elements survive, with the smaller multiplicity.
  auto equal = [](Elements& result, Iterator& itA, Iterator& itB) {
    result[itA->first] = std::min(itA->second, itB->second);
  };
  // Elements present in only one of the bags are ignored.
  auto ignore = [](auto&...) {};

  return evaluateBinaryOperation(n, equal, ignore, ignore, ignore, ignore);
}
455 |
|
|
456 |
3 |
Node NormalForm::evaluateDifferenceSubtract(TNode n)
{
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
  // Example
  // -------
  // input: (difference_subtract A B)
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
  // output:
  //    (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))

  using Elements = std::map<Node, Rational>;
  using Iterator = Elements::const_iterator;

  auto equal = [](Elements& elements, Iterator& itA, Iterator& itB) {
    // Subtract the multiplicities, keeping the element only if the result is
    // positive. An MK_BAG with a zero or negative multiplicity is not a
    // constant, so such an element must simply be absent from the result.
    if (itA->second > itB->second)
    {
      elements[itA->first] = itA->second - itB->second;
    }
  };

  auto less = [](Elements& elements, Iterator& itA, Iterator&) {
    // itA->first is not in B, so we add it to the difference subtract
    elements[itA->first] = itA->second;
  };

  auto greaterOrEqual = [](Elements&, Iterator&, Iterator&) {
    // itB->first is not in A, so we just skip it
  };

  auto remainderOfA = [](Elements& elements, Elements& elementsA,
                         Iterator& itA) {
    // append the remainder of A
    for (; itA != elementsA.end(); ++itA)
    {
      elements[itA->first] = itA->second;
    }
  };

  auto remainderOfB = [](Elements&, Elements&, Iterator&) {
    // elements that are only in B do not contribute to A - B
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}
507 |
|
|
508 |
2 |
Node NormalForm::evaluateDifferenceRemove(TNode n)
{
  Assert(n.getKind() == DIFFERENCE_REMOVE);
  // Example
  // -------
  // input: (difference_remove A B)
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
  // output:
  //    (MK_BAG "z" 2)
  //
  // Every element of A that also occurs in B (with any multiplicity) is
  // removed. The other elements of A keep their multiplicity.

  using Elements = std::map<Node, Rational>;
  using Iterator = Elements::const_iterator;

  auto equal = [](Elements&, Iterator&, Iterator&) {
    // skip the shared element by doing nothing
  };

  auto less = [](Elements& elements, Iterator& itA, Iterator&) {
    // itA->first is not in B, so we add it to the difference remove
    elements[itA->first] = itA->second;
  };

  auto greaterOrEqual = [](Elements&, Iterator&, Iterator&) {
    // itB->first is not in A, so we just skip it
  };

  auto remainderOfA = [](Elements& elements, Elements& elementsA,
                         Iterator& itA) {
    // append the remainder of A
    for (; itA != elementsA.end(); ++itA)
    {
      elements[itA->first] = itA->second;
    }
  };

  auto remainderOfB = [](Elements&, Elements&, Iterator&) {
    // elements that are only in B are not part of the result
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}
558 |
|
|
559 |
6 |
Node NormalForm::evaluateChoose(TNode n)
{
  Assert(n.getKind() == BAG_CHOOSE);
  // Examples
  // --------
  // - (choose (emptyBag String)) = "" // the empty string which is the first
  //   element returned by the type enumerator
  // - (choose (MK_BAG "x" 4)) = "x"
  // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x"
  //   deterministically return the first element

  TNode bag = n[0];
  switch (bag.getKind())
  {
    case EMPTYBAG:
    {
      // Nothing to choose from: return the first value produced by the type
      // enumerator for the bag's element type.
      TypeEnumerator enumerator(bag.getType().getBagElementType());
      return *enumerator;
    }
    case MK_BAG: return bag[0];
    default:
      // the element of the leftmost MK_BAG of the union
      Assert(bag.getKind() == UNION_DISJOINT);
      return bag[0][0];
  }
}
588 |
|
|
589 |
6 |
Node NormalForm::evaluateCard(TNode n)
{
  Assert(n.getKind() == BAG_CARD);
  // Examples
  // --------
  // - (card (emptyBag String)) = 0
  // - (card (MK_BAG "x" 4)) = 4
  // - (card (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5

  // The cardinality is the sum of the multiplicities of all elements. The
  // entries are visited by const reference so that no (Node, Rational) pair
  // is copied per iteration.
  std::map<Node, Rational> elements = getBagElements(n[0]);
  Rational sum(0);
  for (const auto& element : elements)
  {
    sum += element.second;
  }
  return NodeManager::currentNM()->mkConst(sum);
}
609 |
|
|
610 |
8 |
Node NormalForm::evaluateIsSingleton(TNode n)
{
  Assert(n.getKind() == BAG_IS_SINGLETON);
  // Examples
  // --------
  // - (bag.is_singleton (emptyBag String)) = false
  // - (bag.is_singleton (MK_BAG "x" 1)) = true
  // - (bag.is_singleton (MK_BAG "x" 4)) = false
  // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false

  // True only for a single MK_BAG whose multiplicity is exactly one. Every
  // other form of constant bag yields false.
  TNode bag = n[0];
  const bool singleton =
      bag.getKind() == MK_BAG && bag[1].getConst<Rational>().isOne();
  return NodeManager::currentNM()->mkConst(singleton);
}
626 |
|
|
627 |
6 |
Node NormalForm::evaluateFromSet(TNode n)
{
  Assert(n.getKind() == BAG_FROM_SET);

  // Examples
  // --------
  //  - (bag.from_set (emptyset String)) = (emptybag String)
  //  - (bag.from_set (singleton "x")) = (mkBag "x" 1)
  //  - (bag.from_set (union (singleton "x") (singleton "y"))) =
  //     (disjoint_union (mkBag "x" 1) (mkBag "y" 1))

  NodeManager* nm = NodeManager::currentNM();
  // Every element of the constant set n[0] becomes a bag element with
  // multiplicity one. Set elements are distinct, so each key is inserted once.
  std::map<Node, Rational> bagElements;
  for (const Node& element :
       sets::NormalForm::getElementsFromNormalConstant(n[0]))
  {
    bagElements.emplace(element, Rational(1));
  }
  TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
  return constructConstantBagFromElements(bagType, bagElements);
}
651 |
|
|
652 |
6 |
Node NormalForm::evaluateToSet(TNode n)
{
  Assert(n.getKind() == BAG_TO_SET);

  // Examples
  // --------
  //  - (bag.to_set (emptybag String)) = (emptyset String)
  //  - (bag.to_set (mkBag "x" 4)) = (singleton "x")
  //  - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
  //     (union (singleton "x") (singleton "y")))

  NodeManager* nm = NodeManager::currentNM();
  // The set contains every distinct element of the bag; multiplicities are
  // dropped. std::set orders its contents itself, so the visiting order of
  // the bag's elements is irrelevant.
  std::set<Node> setElements;
  for (const auto& element : getBagElements(n[0]))
  {
    setElements.insert(element.first);
  }
  TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
  return sets::NormalForm::elementsToSet(setElements, setType);
}
675 |
|
|
676 |
|
} // namespace bags |
677 |
|
} // namespace theory |
678 |
28191 |
} // namespace cvc5 |