1 |
|
/****************************************************************************** |
2 |
|
* Top contributors (to current version): |
3 |
|
* Mudathir Mohamed, Aina Niemetz |
4 |
|
* |
5 |
|
* This file is part of the cvc5 project. |
6 |
|
* |
7 |
|
* Copyright (c) 2009-2021 by the authors listed in the file AUTHORS |
8 |
|
* in the top-level source directory and their institutional affiliations. |
9 |
|
* All rights reserved. See the file COPYING in the top-level source |
10 |
|
* directory for licensing information. |
11 |
|
* **************************************************************************** |
12 |
|
* |
13 |
|
* Normal form for bag constants. |
14 |
|
*/ |
15 |
|
#include "normal_form.h" |
16 |
|
|
17 |
|
#include "expr/emptybag.h" |
18 |
|
#include "theory/sets/normal_form.h" |
19 |
|
#include "theory/type_enumerator.h" |
20 |
|
#include "util/rational.h" |
21 |
|
|
22 |
|
using namespace cvc5::kind; |
23 |
|
|
24 |
|
namespace cvc5 { |
25 |
|
namespace theory { |
26 |
|
namespace bags { |
27 |
|
|
28 |
117 |
bool NormalForm::isConstant(TNode n) |
29 |
|
{ |
30 |
117 |
if (n.getKind() == EMPTYBAG) |
31 |
|
{ |
32 |
|
// empty bags are already normalized |
33 |
|
return true; |
34 |
|
} |
35 |
117 |
if (n.getKind() == MK_BAG) |
36 |
|
{ |
37 |
|
// see the implementation in MkBagTypeRule::computeIsConst |
38 |
|
return n.isConst(); |
39 |
|
} |
40 |
117 |
if (n.getKind() == UNION_DISJOINT) |
41 |
|
{ |
42 |
117 |
if (!(n[0].getKind() == kind::MK_BAG && n[0].isConst())) |
43 |
|
{ |
44 |
|
// the first child is not a constant |
45 |
59 |
return false; |
46 |
|
} |
47 |
|
// store the previous element to check the ordering of elements |
48 |
116 |
Node previousElement = n[0][0]; |
49 |
116 |
Node current = n[1]; |
50 |
74 |
while (current.getKind() == UNION_DISJOINT) |
51 |
|
{ |
52 |
8 |
if (!(current[0].getKind() == kind::MK_BAG && current[0].isConst())) |
53 |
|
{ |
54 |
|
// the current element is not a constant |
55 |
|
return false; |
56 |
|
} |
57 |
8 |
if (previousElement >= current[0][0]) |
58 |
|
{ |
59 |
|
// the ordering is violated |
60 |
|
return false; |
61 |
|
} |
62 |
8 |
previousElement = current[0][0]; |
63 |
8 |
current = current[1]; |
64 |
|
} |
65 |
|
// check last element |
66 |
58 |
if (!(current.getKind() == kind::MK_BAG && current.isConst())) |
67 |
|
{ |
68 |
|
// the last element is not a constant |
69 |
|
return false; |
70 |
|
} |
71 |
58 |
if (previousElement >= current[0]) |
72 |
|
{ |
73 |
|
// the ordering is violated |
74 |
8 |
return false; |
75 |
|
} |
76 |
50 |
return true; |
77 |
|
} |
78 |
|
|
79 |
|
// only nodes with kinds EMPTY_BAG, MK_BAG, and UNION_DISJOINT can be |
80 |
|
// constants |
81 |
|
return false; |
82 |
|
} |
83 |
|
|
84 |
1504 |
bool NormalForm::areChildrenConstants(TNode n) |
85 |
|
{ |
86 |
3493 |
return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); }); |
87 |
|
} |
88 |
|
|
89 |
146 |
Node NormalForm::evaluate(TNode n) |
90 |
|
{ |
91 |
146 |
Assert(areChildrenConstants(n)); |
92 |
146 |
if (n.isConst()) |
93 |
|
{ |
94 |
|
// a constant node is already in a normal form |
95 |
6 |
return n; |
96 |
|
} |
97 |
140 |
switch (n.getKind()) |
98 |
|
{ |
99 |
11 |
case MK_BAG: return evaluateMakeBag(n); |
100 |
58 |
case BAG_COUNT: return evaluateBagCount(n); |
101 |
7 |
case DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n); |
102 |
18 |
case UNION_DISJOINT: return evaluateUnionDisjoint(n); |
103 |
6 |
case UNION_MAX: return evaluateUnionMax(n); |
104 |
3 |
case INTERSECTION_MIN: return evaluateIntersectionMin(n); |
105 |
3 |
case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n); |
106 |
2 |
case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n); |
107 |
6 |
case BAG_CHOOSE: return evaluateChoose(n); |
108 |
6 |
case BAG_CARD: return evaluateCard(n); |
109 |
8 |
case BAG_IS_SINGLETON: return evaluateIsSingleton(n); |
110 |
6 |
case BAG_FROM_SET: return evaluateFromSet(n); |
111 |
6 |
case BAG_TO_SET: return evaluateToSet(n); |
112 |
|
default: break; |
113 |
|
} |
114 |
|
Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n |
115 |
|
<< std::endl; |
116 |
|
} |
117 |
|
|
118 |
|
template <typename T1, typename T2, typename T3, typename T4, typename T5> |
119 |
32 |
Node NormalForm::evaluateBinaryOperation(const TNode& n, |
120 |
|
T1&& equal, |
121 |
|
T2&& less, |
122 |
|
T3&& greaterOrEqual, |
123 |
|
T4&& remainderOfA, |
124 |
|
T5&& remainderOfB) |
125 |
|
{ |
126 |
64 |
std::map<Node, Rational> elementsA = getBagElements(n[0]); |
127 |
64 |
std::map<Node, Rational> elementsB = getBagElements(n[1]); |
128 |
64 |
std::map<Node, Rational> elements; |
129 |
|
|
130 |
32 |
std::map<Node, Rational>::const_iterator itA = elementsA.begin(); |
131 |
32 |
std::map<Node, Rational>::const_iterator itB = elementsB.begin(); |
132 |
|
|
133 |
96 |
Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation " |
134 |
64 |
<< n.getKind() << "] " << std::endl |
135 |
32 |
<< "elements A: " << elementsA << std::endl |
136 |
32 |
<< "elements B: " << elementsB << std::endl; |
137 |
|
|
138 |
112 |
while (itA != elementsA.end() && itB != elementsB.end()) |
139 |
|
{ |
140 |
40 |
if (itA->first == itB->first) |
141 |
|
{ |
142 |
20 |
equal(elements, itA, itB); |
143 |
20 |
itA++; |
144 |
20 |
itB++; |
145 |
|
} |
146 |
20 |
else if (itA->first < itB->first) |
147 |
|
{ |
148 |
6 |
less(elements, itA, itB); |
149 |
6 |
itA++; |
150 |
|
} |
151 |
|
else |
152 |
|
{ |
153 |
14 |
greaterOrEqual(elements, itA, itB); |
154 |
14 |
itB++; |
155 |
|
} |
156 |
|
} |
157 |
|
|
158 |
|
// handle the remaining elements from A |
159 |
32 |
remainderOfA(elements, elementsA, itA); |
160 |
|
// handle the remaining elements from B |
161 |
32 |
remainderOfA(elements, elementsB, itB); |
162 |
|
|
163 |
32 |
Trace("bags-evaluate") << "elements: " << elements << std::endl; |
164 |
32 |
Node bag = constructConstantBagFromElements(n.getType(), elements); |
165 |
32 |
Trace("bags-evaluate") << "bag: " << bag << std::endl; |
166 |
64 |
return bag; |
167 |
|
} |
168 |
|
|
169 |
141 |
std::map<Node, Rational> NormalForm::getBagElements(TNode n) |
170 |
|
{ |
171 |
141 |
Assert(n.isConst()) << "node " << n << " is not in a normal form" |
172 |
|
<< std::endl; |
173 |
141 |
std::map<Node, Rational> elements; |
174 |
141 |
if (n.getKind() == EMPTYBAG) |
175 |
|
{ |
176 |
26 |
return elements; |
177 |
|
} |
178 |
231 |
while (n.getKind() == kind::UNION_DISJOINT) |
179 |
|
{ |
180 |
58 |
Assert(n[0].getKind() == kind::MK_BAG); |
181 |
116 |
Node element = n[0][0]; |
182 |
116 |
Rational count = n[0][1].getConst<Rational>(); |
183 |
58 |
elements[element] = count; |
184 |
58 |
n = n[1]; |
185 |
|
} |
186 |
115 |
Assert(n.getKind() == kind::MK_BAG); |
187 |
230 |
Node lastElement = n[0]; |
188 |
230 |
Rational lastCount = n[1].getConst<Rational>(); |
189 |
115 |
elements[lastElement] = lastCount; |
190 |
115 |
return elements; |
191 |
|
} |
192 |
|
|
193 |
45 |
Node NormalForm::constructConstantBagFromElements( |
194 |
|
TypeNode t, const std::map<Node, Rational>& elements) |
195 |
|
{ |
196 |
45 |
Assert(t.isBag()); |
197 |
45 |
NodeManager* nm = NodeManager::currentNM(); |
198 |
45 |
if (elements.empty()) |
199 |
|
{ |
200 |
6 |
return nm->mkConst(EmptyBag(t)); |
201 |
|
} |
202 |
78 |
TypeNode elementType = t.getBagElementType(); |
203 |
39 |
std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin(); |
204 |
|
Node bag = |
205 |
78 |
nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second)); |
206 |
91 |
while (++it != elements.rend()) |
207 |
|
{ |
208 |
|
Node n = |
209 |
52 |
nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second)); |
210 |
26 |
bag = nm->mkNode(UNION_DISJOINT, n, bag); |
211 |
|
} |
212 |
39 |
return bag; |
213 |
|
} |
214 |
|
|
215 |
30 |
Node NormalForm::constructBagFromElements(TypeNode t, |
216 |
|
const std::map<Node, Node>& elements) |
217 |
|
{ |
218 |
30 |
Assert(t.isBag()); |
219 |
30 |
NodeManager* nm = NodeManager::currentNM(); |
220 |
30 |
if (elements.empty()) |
221 |
|
{ |
222 |
2 |
return nm->mkConst(EmptyBag(t)); |
223 |
|
} |
224 |
56 |
TypeNode elementType = t.getBagElementType(); |
225 |
28 |
std::map<Node, Node>::const_reverse_iterator it = elements.rbegin(); |
226 |
56 |
Node bag = nm->mkBag(elementType, it->first, it->second); |
227 |
56 |
while (++it != elements.rend()) |
228 |
|
{ |
229 |
28 |
Node n = nm->mkBag(elementType, it->first, it->second); |
230 |
14 |
bag = nm->mkNode(UNION_DISJOINT, n, bag); |
231 |
|
} |
232 |
28 |
return bag; |
233 |
|
} |
234 |
|
|
235 |
11 |
Node NormalForm::evaluateMakeBag(TNode n) |
236 |
|
{ |
237 |
|
// the case where n is const should be handled earlier. |
238 |
|
// here we handle the case where the multiplicity is zero or negative |
239 |
11 |
Assert(n.getKind() == MK_BAG && !n.isConst() |
240 |
|
&& n[1].getConst<Rational>().sgn() < 1); |
241 |
11 |
Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType())); |
242 |
11 |
return emptybag; |
243 |
|
} |
244 |
|
|
245 |
58 |
Node NormalForm::evaluateBagCount(TNode n) |
246 |
|
{ |
247 |
58 |
Assert(n.getKind() == BAG_COUNT); |
248 |
|
// Examples |
249 |
|
// -------- |
250 |
|
// - (bag.count "x" (emptybag String)) = 0 |
251 |
|
// - (bag.count "x" (mkBag "y" 5)) = 0 |
252 |
|
// - (bag.count "x" (mkBag "x" 4)) = 4 |
253 |
|
// - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4 |
254 |
|
// - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0 |
255 |
|
|
256 |
116 |
std::map<Node, Rational> elements = getBagElements(n[1]); |
257 |
58 |
std::map<Node, Rational>::iterator it = elements.find(n[0]); |
258 |
|
|
259 |
58 |
NodeManager* nm = NodeManager::currentNM(); |
260 |
58 |
if (it != elements.end()) |
261 |
|
{ |
262 |
74 |
Node count = nm->mkConst(it->second); |
263 |
37 |
return count; |
264 |
|
} |
265 |
21 |
return nm->mkConst(Rational(0)); |
266 |
|
} |
267 |
|
|
268 |
7 |
Node NormalForm::evaluateDuplicateRemoval(TNode n) |
269 |
|
{ |
270 |
7 |
Assert(n.getKind() == DUPLICATE_REMOVAL); |
271 |
|
|
272 |
|
// Examples |
273 |
|
// -------- |
274 |
|
// - (duplicate_removal (emptybag String)) = (emptybag String) |
275 |
|
// - (duplicate_removal (mkBag "x" 4)) = (emptybag "x" 1) |
276 |
|
// - (duplicate_removal (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) = |
277 |
|
// (disjoint_union (mkBag "x" 1) (mkBag "y" 1) |
278 |
|
|
279 |
14 |
std::map<Node, Rational> oldElements = getBagElements(n[0]); |
280 |
|
// copy elements from the old bag |
281 |
14 |
std::map<Node, Rational> newElements(oldElements); |
282 |
14 |
Rational one = Rational(1); |
283 |
7 |
std::map<Node, Rational>::iterator it; |
284 |
14 |
for (it = newElements.begin(); it != newElements.end(); it++) |
285 |
|
{ |
286 |
7 |
it->second = one; |
287 |
|
} |
288 |
7 |
Node bag = constructConstantBagFromElements(n[0].getType(), newElements); |
289 |
14 |
return bag; |
290 |
|
} |
291 |
|
|
292 |
18 |
Node NormalForm::evaluateUnionDisjoint(TNode n) |
293 |
|
{ |
294 |
18 |
Assert(n.getKind() == UNION_DISJOINT); |
295 |
|
// Example |
296 |
|
// ------- |
297 |
|
// input: (union_disjoint A B) |
298 |
|
// where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) |
299 |
|
// B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) |
300 |
|
// output: |
301 |
|
// (union_disjoint A B) |
302 |
|
// where A = (MK_BAG "x" 7) |
303 |
|
// B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2))) |
304 |
|
|
305 |
|
auto equal = [](std::map<Node, Rational>& elements, |
306 |
|
std::map<Node, Rational>::const_iterator& itA, |
307 |
6 |
std::map<Node, Rational>::const_iterator& itB) { |
308 |
|
// compute the sum of the multiplicities |
309 |
6 |
elements[itA->first] = itA->second + itB->second; |
310 |
6 |
}; |
311 |
|
|
312 |
|
auto less = [](std::map<Node, Rational>& elements, |
313 |
|
std::map<Node, Rational>::const_iterator& itA, |
314 |
4 |
std::map<Node, Rational>::const_iterator& itB) { |
315 |
|
// add the element to the result |
316 |
4 |
elements[itA->first] = itA->second; |
317 |
4 |
}; |
318 |
|
|
319 |
|
auto greaterOrEqual = [](std::map<Node, Rational>& elements, |
320 |
|
std::map<Node, Rational>::const_iterator& itA, |
321 |
6 |
std::map<Node, Rational>::const_iterator& itB) { |
322 |
|
// add the element to the result |
323 |
6 |
elements[itB->first] = itB->second; |
324 |
6 |
}; |
325 |
|
|
326 |
|
auto remainderOfA = [](std::map<Node, Rational>& elements, |
327 |
|
std::map<Node, Rational>& elementsA, |
328 |
48 |
std::map<Node, Rational>::const_iterator& itA) { |
329 |
|
// append the remainder of A |
330 |
60 |
while (itA != elementsA.end()) |
331 |
|
{ |
332 |
12 |
elements[itA->first] = itA->second; |
333 |
12 |
itA++; |
334 |
|
} |
335 |
36 |
}; |
336 |
|
|
337 |
|
auto remainderOfB = [](std::map<Node, Rational>& elements, |
338 |
|
std::map<Node, Rational>& elementsB, |
339 |
|
std::map<Node, Rational>::const_iterator& itB) { |
340 |
|
// append the remainder of B |
341 |
|
while (itB != elementsB.end()) |
342 |
|
{ |
343 |
|
elements[itB->first] = itB->second; |
344 |
|
itB++; |
345 |
|
} |
346 |
|
}; |
347 |
|
|
348 |
|
return evaluateBinaryOperation( |
349 |
18 |
n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); |
350 |
|
} |
351 |
|
|
352 |
6 |
Node NormalForm::evaluateUnionMax(TNode n) |
353 |
|
{ |
354 |
6 |
Assert(n.getKind() == UNION_MAX); |
355 |
|
// Example |
356 |
|
// ------- |
357 |
|
// input: (union_max A B) |
358 |
|
// where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) |
359 |
|
// B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) |
360 |
|
// output: |
361 |
|
// (union_disjoint A B) |
362 |
|
// where A = (MK_BAG "x" 4) |
363 |
|
// B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2))) |
364 |
|
|
365 |
|
auto equal = [](std::map<Node, Rational>& elements, |
366 |
|
std::map<Node, Rational>::const_iterator& itA, |
367 |
4 |
std::map<Node, Rational>::const_iterator& itB) { |
368 |
|
// compute the maximum multiplicity |
369 |
4 |
elements[itA->first] = std::max(itA->second, itB->second); |
370 |
4 |
}; |
371 |
|
|
372 |
|
auto less = [](std::map<Node, Rational>& elements, |
373 |
|
std::map<Node, Rational>::const_iterator& itA, |
374 |
2 |
std::map<Node, Rational>::const_iterator& itB) { |
375 |
|
// add to the result |
376 |
2 |
elements[itA->first] = itA->second; |
377 |
2 |
}; |
378 |
|
|
379 |
|
auto greaterOrEqual = [](std::map<Node, Rational>& elements, |
380 |
|
std::map<Node, Rational>::const_iterator& itA, |
381 |
2 |
std::map<Node, Rational>::const_iterator& itB) { |
382 |
|
// add to the result |
383 |
2 |
elements[itB->first] = itB->second; |
384 |
2 |
}; |
385 |
|
|
386 |
|
auto remainderOfA = [](std::map<Node, Rational>& elements, |
387 |
|
std::map<Node, Rational>& elementsA, |
388 |
16 |
std::map<Node, Rational>::const_iterator& itA) { |
389 |
|
// append the remainder of A |
390 |
20 |
while (itA != elementsA.end()) |
391 |
|
{ |
392 |
4 |
elements[itA->first] = itA->second; |
393 |
4 |
itA++; |
394 |
|
} |
395 |
12 |
}; |
396 |
|
|
397 |
|
auto remainderOfB = [](std::map<Node, Rational>& elements, |
398 |
|
std::map<Node, Rational>& elementsB, |
399 |
|
std::map<Node, Rational>::const_iterator& itB) { |
400 |
|
// append the remainder of B |
401 |
|
while (itB != elementsB.end()) |
402 |
|
{ |
403 |
|
elements[itB->first] = itB->second; |
404 |
|
itB++; |
405 |
|
} |
406 |
|
}; |
407 |
|
|
408 |
|
return evaluateBinaryOperation( |
409 |
6 |
n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); |
410 |
|
} |
411 |
|
|
412 |
3 |
Node NormalForm::evaluateIntersectionMin(TNode n) |
413 |
|
{ |
414 |
3 |
Assert(n.getKind() == INTERSECTION_MIN); |
415 |
|
// Example |
416 |
|
// ------- |
417 |
|
// input: (intersectionMin A B) |
418 |
|
// where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) |
419 |
|
// B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) |
420 |
|
// output: |
421 |
|
// (MK_BAG "x" 3) |
422 |
|
|
423 |
|
auto equal = [](std::map<Node, Rational>& elements, |
424 |
|
std::map<Node, Rational>::const_iterator& itA, |
425 |
4 |
std::map<Node, Rational>::const_iterator& itB) { |
426 |
|
// compute the minimum multiplicity |
427 |
4 |
elements[itA->first] = std::min(itA->second, itB->second); |
428 |
4 |
}; |
429 |
|
|
430 |
|
auto less = [](std::map<Node, Rational>& elements, |
431 |
|
std::map<Node, Rational>::const_iterator& itA, |
432 |
|
std::map<Node, Rational>::const_iterator& itB) { |
433 |
|
// do nothing |
434 |
|
}; |
435 |
|
|
436 |
|
auto greaterOrEqual = [](std::map<Node, Rational>& elements, |
437 |
|
std::map<Node, Rational>::const_iterator& itA, |
438 |
2 |
std::map<Node, Rational>::const_iterator& itB) { |
439 |
|
// do nothing |
440 |
2 |
}; |
441 |
|
|
442 |
|
auto remainderOfA = [](std::map<Node, Rational>& elements, |
443 |
|
std::map<Node, Rational>& elementsA, |
444 |
6 |
std::map<Node, Rational>::const_iterator& itA) { |
445 |
|
// do nothing |
446 |
6 |
}; |
447 |
|
|
448 |
|
auto remainderOfB = [](std::map<Node, Rational>& elements, |
449 |
|
std::map<Node, Rational>& elementsB, |
450 |
|
std::map<Node, Rational>::const_iterator& itB) { |
451 |
|
// do nothing |
452 |
|
}; |
453 |
|
|
454 |
|
return evaluateBinaryOperation( |
455 |
3 |
n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); |
456 |
|
} |
457 |
|
|
458 |
3 |
Node NormalForm::evaluateDifferenceSubtract(TNode n) |
459 |
|
{ |
460 |
3 |
Assert(n.getKind() == DIFFERENCE_SUBTRACT); |
461 |
|
// Example |
462 |
|
// ------- |
463 |
|
// input: (difference_subtract A B) |
464 |
|
// where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) |
465 |
|
// B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) |
466 |
|
// output: |
467 |
|
// (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2)) |
468 |
|
|
469 |
|
auto equal = [](std::map<Node, Rational>& elements, |
470 |
|
std::map<Node, Rational>::const_iterator& itA, |
471 |
4 |
std::map<Node, Rational>::const_iterator& itB) { |
472 |
|
// subtract the multiplicities |
473 |
4 |
elements[itA->first] = itA->second - itB->second; |
474 |
4 |
}; |
475 |
|
|
476 |
|
auto less = [](std::map<Node, Rational>& elements, |
477 |
|
std::map<Node, Rational>::const_iterator& itA, |
478 |
|
std::map<Node, Rational>::const_iterator& itB) { |
479 |
|
// itA->first is not in B, so we add it to the difference subtract |
480 |
|
elements[itA->first] = itA->second; |
481 |
|
}; |
482 |
|
|
483 |
|
auto greaterOrEqual = [](std::map<Node, Rational>& elements, |
484 |
|
std::map<Node, Rational>::const_iterator& itA, |
485 |
2 |
std::map<Node, Rational>::const_iterator& itB) { |
486 |
|
// itB->first is not in A, so we just skip it |
487 |
2 |
}; |
488 |
|
|
489 |
|
auto remainderOfA = [](std::map<Node, Rational>& elements, |
490 |
|
std::map<Node, Rational>& elementsA, |
491 |
8 |
std::map<Node, Rational>::const_iterator& itA) { |
492 |
|
// append the remainder of A |
493 |
10 |
while (itA != elementsA.end()) |
494 |
|
{ |
495 |
2 |
elements[itA->first] = itA->second; |
496 |
2 |
itA++; |
497 |
|
} |
498 |
6 |
}; |
499 |
|
|
500 |
|
auto remainderOfB = [](std::map<Node, Rational>& elements, |
501 |
|
std::map<Node, Rational>& elementsB, |
502 |
|
std::map<Node, Rational>::const_iterator& itB) { |
503 |
|
// do nothing |
504 |
|
}; |
505 |
|
|
506 |
|
return evaluateBinaryOperation( |
507 |
3 |
n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); |
508 |
|
} |
509 |
|
|
510 |
2 |
Node NormalForm::evaluateDifferenceRemove(TNode n) |
511 |
|
{ |
512 |
2 |
Assert(n.getKind() == DIFFERENCE_REMOVE); |
513 |
|
// Example |
514 |
|
// ------- |
515 |
|
// input: (difference_subtract A B) |
516 |
|
// where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) |
517 |
|
// B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) |
518 |
|
// output: |
519 |
|
// (MK_BAG "z" 2) |
520 |
|
|
521 |
|
auto equal = [](std::map<Node, Rational>& elements, |
522 |
|
std::map<Node, Rational>::const_iterator& itA, |
523 |
2 |
std::map<Node, Rational>::const_iterator& itB) { |
524 |
|
// skip the shared element by doing nothing |
525 |
2 |
}; |
526 |
|
|
527 |
|
auto less = [](std::map<Node, Rational>& elements, |
528 |
|
std::map<Node, Rational>::const_iterator& itA, |
529 |
|
std::map<Node, Rational>::const_iterator& itB) { |
530 |
|
// itA->first is not in B, so we add it to the difference remove |
531 |
|
elements[itA->first] = itA->second; |
532 |
|
}; |
533 |
|
|
534 |
|
auto greaterOrEqual = [](std::map<Node, Rational>& elements, |
535 |
|
std::map<Node, Rational>::const_iterator& itA, |
536 |
2 |
std::map<Node, Rational>::const_iterator& itB) { |
537 |
|
// itB->first is not in A, so we just skip it |
538 |
2 |
}; |
539 |
|
|
540 |
|
auto remainderOfA = [](std::map<Node, Rational>& elements, |
541 |
|
std::map<Node, Rational>& elementsA, |
542 |
6 |
std::map<Node, Rational>::const_iterator& itA) { |
543 |
|
// append the remainder of A |
544 |
8 |
while (itA != elementsA.end()) |
545 |
|
{ |
546 |
2 |
elements[itA->first] = itA->second; |
547 |
2 |
itA++; |
548 |
|
} |
549 |
4 |
}; |
550 |
|
|
551 |
|
auto remainderOfB = [](std::map<Node, Rational>& elements, |
552 |
|
std::map<Node, Rational>& elementsB, |
553 |
|
std::map<Node, Rational>::const_iterator& itB) { |
554 |
|
// do nothing |
555 |
|
}; |
556 |
|
|
557 |
|
return evaluateBinaryOperation( |
558 |
2 |
n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); |
559 |
|
} |
560 |
|
|
561 |
6 |
Node NormalForm::evaluateChoose(TNode n) |
562 |
|
{ |
563 |
6 |
Assert(n.getKind() == BAG_CHOOSE); |
564 |
|
// Examples |
565 |
|
// -------- |
566 |
|
// - (choose (emptyBag String)) = "" // the empty string which is the first |
567 |
|
// element returned by the type enumerator |
568 |
|
// - (choose (MK_BAG "x" 4)) = "x" |
569 |
|
// - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x" |
570 |
|
// deterministically return the first element |
571 |
|
|
572 |
6 |
if (n[0].getKind() == EMPTYBAG) |
573 |
|
{ |
574 |
4 |
TypeNode elementType = n[0].getType().getBagElementType(); |
575 |
4 |
TypeEnumerator typeEnumerator(elementType); |
576 |
|
// get the first value from the typeEnumerator |
577 |
4 |
Node element = *typeEnumerator; |
578 |
2 |
return element; |
579 |
|
} |
580 |
|
|
581 |
4 |
if (n[0].getKind() == MK_BAG) |
582 |
|
{ |
583 |
2 |
return n[0][0]; |
584 |
|
} |
585 |
2 |
Assert(n[0].getKind() == UNION_DISJOINT); |
586 |
|
// return the first element |
587 |
|
// e.g. (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) |
588 |
2 |
return n[0][0][0]; |
589 |
|
} |
590 |
|
|
591 |
6 |
Node NormalForm::evaluateCard(TNode n) |
592 |
|
{ |
593 |
6 |
Assert(n.getKind() == BAG_CARD); |
594 |
|
// Examples |
595 |
|
// -------- |
596 |
|
// - (card (emptyBag String)) = 0 |
597 |
|
// - (choose (MK_BAG "x" 4)) = 4 |
598 |
|
// - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5 |
599 |
|
|
600 |
12 |
std::map<Node, Rational> elements = getBagElements(n[0]); |
601 |
12 |
Rational sum(0); |
602 |
12 |
for (std::pair<Node, Rational> element : elements) |
603 |
|
{ |
604 |
6 |
sum += element.second; |
605 |
|
} |
606 |
|
|
607 |
6 |
NodeManager* nm = NodeManager::currentNM(); |
608 |
6 |
Node sumNode = nm->mkConst(sum); |
609 |
12 |
return sumNode; |
610 |
|
} |
611 |
|
|
612 |
8 |
Node NormalForm::evaluateIsSingleton(TNode n) |
613 |
|
{ |
614 |
8 |
Assert(n.getKind() == BAG_IS_SINGLETON); |
615 |
|
// Examples |
616 |
|
// -------- |
617 |
|
// - (bag.is_singleton (emptyBag String)) = false |
618 |
|
// - (bag.is_singleton (MK_BAG "x" 1)) = true |
619 |
|
// - (bag.is_singleton (MK_BAG "x" 4)) = false |
620 |
|
// - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false |
621 |
|
|
622 |
8 |
if (n[0].getKind() == MK_BAG && n[0][1].getConst<Rational>().isOne()) |
623 |
|
{ |
624 |
4 |
return NodeManager::currentNM()->mkConst(true); |
625 |
|
} |
626 |
4 |
return NodeManager::currentNM()->mkConst(false); |
627 |
|
} |
628 |
|
|
629 |
6 |
Node NormalForm::evaluateFromSet(TNode n) |
630 |
|
{ |
631 |
6 |
Assert(n.getKind() == BAG_FROM_SET); |
632 |
|
|
633 |
|
// Examples |
634 |
|
// -------- |
635 |
|
// - (bag.from_set (emptyset String)) = (emptybag String) |
636 |
|
// - (bag.from_set (singleton "x")) = (mkBag "x" 1) |
637 |
|
// - (bag.from_set (union (singleton "x") (singleton "y"))) = |
638 |
|
// (disjoint_union (mkBag "x" 1) (mkBag "y" 1)) |
639 |
|
|
640 |
6 |
NodeManager* nm = NodeManager::currentNM(); |
641 |
|
std::set<Node> setElements = |
642 |
12 |
sets::NormalForm::getElementsFromNormalConstant(n[0]); |
643 |
12 |
Rational one = Rational(1); |
644 |
12 |
std::map<Node, Rational> bagElements; |
645 |
12 |
for (const Node& element : setElements) |
646 |
|
{ |
647 |
6 |
bagElements[element] = one; |
648 |
|
} |
649 |
12 |
TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType()); |
650 |
6 |
Node bag = constructConstantBagFromElements(bagType, bagElements); |
651 |
12 |
return bag; |
652 |
|
} |
653 |
|
|
654 |
6 |
Node NormalForm::evaluateToSet(TNode n) |
655 |
|
{ |
656 |
6 |
Assert(n.getKind() == BAG_TO_SET); |
657 |
|
|
658 |
|
// Examples |
659 |
|
// -------- |
660 |
|
// - (bag.to_set (emptybag String)) = (emptyset String) |
661 |
|
// - (bag.to_set (mkBag "x" 4)) = (singleton "x") |
662 |
|
// - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) = |
663 |
|
// (union (singleton "x") (singleton "y"))) |
664 |
|
|
665 |
6 |
NodeManager* nm = NodeManager::currentNM(); |
666 |
12 |
std::map<Node, Rational> bagElements = getBagElements(n[0]); |
667 |
12 |
std::set<Node> setElements; |
668 |
6 |
std::map<Node, Rational>::const_reverse_iterator it; |
669 |
12 |
for (it = bagElements.rbegin(); it != bagElements.rend(); it++) |
670 |
|
{ |
671 |
6 |
setElements.insert(it->first); |
672 |
|
} |
673 |
12 |
TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType()); |
674 |
6 |
Node set = sets::NormalForm::elementsToSet(setElements, setType); |
675 |
12 |
return set; |
676 |
|
} |
677 |
|
|
678 |
|
} // namespace bags |
679 |
|
} // namespace theory |
680 |
29349 |
} // namespace cvc5 |