GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/normal_form.cpp Lines: 257 275 93.5 %
Date: 2021-09-29 Branches: 577 1341 43.0 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Mudathir Mohamed, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Normal form for bag constants.
14
 */
15
#include "normal_form.h"
16
17
#include "expr/emptybag.h"
18
#include "theory/sets/normal_form.h"
19
#include "theory/type_enumerator.h"
20
#include "util/rational.h"
21
22
using namespace cvc5::kind;
23
24
namespace cvc5 {
25
namespace theory {
26
namespace bags {
27
28
106
bool NormalForm::isConstant(TNode n)
{
  const auto k = n.getKind();
  // The empty bag is trivially in normal form.
  if (k == EMPTYBAG)
  {
    return true;
  }
  // A lone MK_BAG is constant exactly when its type rule says so
  // (see MkBagTypeRule::computeIsConst).
  if (k == MK_BAG)
  {
    return n.isConst();
  }
  // Besides EMPTYBAG and MK_BAG, only UNION_DISJOINT chains may be constants.
  if (k != UNION_DISJOINT)
  {
    return false;
  }
  // Accepted shape:
  //   (union_disjoint B1 (union_disjoint B2 ... (union_disjoint Bk-1 Bk)))
  // where every Bi is a constant MK_BAG and the elements of B1, ..., Bk are
  // strictly increasing.
  auto isConstantMkBag = [](TNode b) {
    return b.getKind() == kind::MK_BAG && b.isConst();
  };
  if (!isConstantMkBag(n[0]))
  {
    return false;
  }
  Node prev = n[0][0];
  Node rest = n[1];
  for (; rest.getKind() == UNION_DISJOINT; rest = rest[1])
  {
    Node head = rest[0];
    if (!isConstantMkBag(head) || prev >= head[0])
    {
      // either a non-constant component or the ordering is violated
      return false;
    }
    prev = head[0];
  }
  // The tail of the chain must itself be a constant MK_BAG whose element is
  // strictly larger than every element before it.
  return isConstantMkBag(rest) && !(prev >= rest[0]);
}
83
84
1124
bool NormalForm::areChildrenConstants(TNode n)
{
  // True iff every child of n is a constant (vacuously true for leaves).
  for (TNode child : n)
  {
    if (!child.isConst())
    {
      return false;
    }
  }
  return true;
}
88
89
146
Node NormalForm::evaluate(TNode n)
{
  Assert(areChildrenConstants(n));
  // Constant terms are already normalized; nothing to compute.
  if (n.isConst())
  {
    return n;
  }
  // Dispatch on the operator to the matching evaluation routine.
  const auto k = n.getKind();
  switch (k)
  {
    case MK_BAG:
      return evaluateMakeBag(n);
    case BAG_COUNT:
      return evaluateBagCount(n);
    case DUPLICATE_REMOVAL:
      return evaluateDuplicateRemoval(n);
    case UNION_DISJOINT:
      return evaluateUnionDisjoint(n);
    case UNION_MAX:
      return evaluateUnionMax(n);
    case INTERSECTION_MIN:
      return evaluateIntersectionMin(n);
    case DIFFERENCE_SUBTRACT:
      return evaluateDifferenceSubtract(n);
    case DIFFERENCE_REMOVE:
      return evaluateDifferenceRemove(n);
    case BAG_CHOOSE:
      return evaluateChoose(n);
    case BAG_CARD:
      return evaluateCard(n);
    case BAG_IS_SINGLETON:
      return evaluateIsSingleton(n);
    case BAG_FROM_SET:
      return evaluateFromSet(n);
    case BAG_TO_SET:
      return evaluateToSet(n);
    case BAG_MAP:
      return evaluateBagMap(n);
    default:
      // fall through to the error report below
      break;
  }
  Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
              << std::endl;
}
118
119
template <typename T1, typename T2, typename T3, typename T4, typename T5>
120
34
Node NormalForm::evaluateBinaryOperation(const TNode& n,
121
                                         T1&& equal,
122
                                         T2&& less,
123
                                         T3&& greaterOrEqual,
124
                                         T4&& remainderOfA,
125
                                         T5&& remainderOfB)
126
{
127
68
  std::map<Node, Rational> elementsA = getBagElements(n[0]);
128
68
  std::map<Node, Rational> elementsB = getBagElements(n[1]);
129
68
  std::map<Node, Rational> elements;
130
131
34
  std::map<Node, Rational>::const_iterator itA = elementsA.begin();
132
34
  std::map<Node, Rational>::const_iterator itB = elementsB.begin();
133
134
102
  Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
135
68
                         << n.getKind() << "] " << std::endl
136
34
                         << "elements A: " << elementsA << std::endl
137
34
                         << "elements B: " << elementsB << std::endl;
138
139
118
  while (itA != elementsA.end() && itB != elementsB.end())
140
  {
141
42
    if (itA->first == itB->first)
142
    {
143
22
      equal(elements, itA, itB);
144
22
      itA++;
145
22
      itB++;
146
    }
147
20
    else if (itA->first < itB->first)
148
    {
149
6
      less(elements, itA, itB);
150
6
      itA++;
151
    }
152
    else
153
    {
154
14
      greaterOrEqual(elements, itA, itB);
155
14
      itB++;
156
    }
157
  }
158
159
  // handle the remaining elements from A
160
34
  remainderOfA(elements, elementsA, itA);
161
  // handle the remaining elements from B
162
34
  remainderOfA(elements, elementsB, itB);
163
164
34
  Trace("bags-evaluate") << "elements: " << elements << std::endl;
165
34
  Node bag = constructConstantBagFromElements(n.getType(), elements);
166
34
  Trace("bags-evaluate") << "bag: " << bag << std::endl;
167
68
  return bag;
168
}
169
170
145
std::map<Node, Rational> NormalForm::getBagElements(TNode n)
{
  Assert(n.isConst()) << "node " << n << " is not in a normal form"
                      << std::endl;
  std::map<Node, Rational> elements;
  if (n.getKind() == EMPTYBAG)
  {
    return elements;
  }
  // Walk down the right-nested union_disjoint chain, recording the element
  // and multiplicity of each MK_BAG on the left.
  TNode current = n;
  for (; current.getKind() == kind::UNION_DISJOINT; current = current[1])
  {
    TNode head = current[0];
    Assert(head.getKind() == kind::MK_BAG);
    Node element = head[0];
    elements[element] = head[1].getConst<Rational>();
  }
  // The chain ends with a single MK_BAG.
  Assert(current.getKind() == kind::MK_BAG);
  Node lastElement = current[0];
  elements[lastElement] = current[1].getConst<Rational>();
  return elements;
}
193
194
51
Node NormalForm::constructConstantBagFromElements(
    TypeNode t, const std::map<Node, Rational>& elements)
{
  Assert(t.isBag());
  NodeManager* nm = NodeManager::currentNM();
  if (elements.empty())
  {
    return nm->mkConst(EmptyBag(t));
  }
  TypeNode elementType = t.getBagElementType();
  // Build from the largest element backwards so the result has the shape
  // (union_disjoint B1 (union_disjoint B2 ... Bk)) with B1 holding the
  // smallest element.
  Node result;
  for (auto rit = elements.rbegin(); rit != elements.rend(); ++rit)
  {
    Node single =
        nm->mkBag(elementType, rit->first, nm->mkConst<Rational>(rit->second));
    result = result.isNull() ? single
                             : nm->mkNode(UNION_DISJOINT, single, result);
  }
  return result;
}
215
216
30
Node NormalForm::constructBagFromElements(TypeNode t,
                                          const std::map<Node, Node>& elements)
{
  Assert(t.isBag());
  NodeManager* nm = NodeManager::currentNM();
  if (elements.empty())
  {
    return nm->mkConst(EmptyBag(t));
  }
  TypeNode elementType = t.getBagElementType();
  // Same chain shape as constructConstantBagFromElements, but multiplicities
  // are arbitrary terms rather than rational constants.
  Node result;
  for (auto rit = elements.rbegin(); rit != elements.rend(); ++rit)
  {
    Node single = nm->mkBag(elementType, rit->first, rit->second);
    result = result.isNull() ? single
                             : nm->mkNode(UNION_DISJOINT, single, result);
  }
  return result;
}
235
236
11
Node NormalForm::evaluateMakeBag(TNode n)
{
  // evaluate() returns constant terms before dispatching here, so the only
  // case left is a non-positive multiplicity, which denotes the empty bag.
  Assert(n.getKind() == MK_BAG && !n.isConst()
         && n[1].getConst<Rational>().sgn() < 1);
  return NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
}
245
246
54
Node NormalForm::evaluateBagCount(TNode n)
{
  Assert(n.getKind() == BAG_COUNT);
  // Returns the multiplicity of n[0] in the constant bag n[1], or 0 if absent.
  //
  // Examples
  // --------
  // - (bag.count "x" (emptybag String)) = 0
  // - (bag.count "x" (mkBag "y" 5)) = 0
  // - (bag.count "x" (mkBag "x" 4)) = 4
  // - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
  // - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0

  const std::map<Node, Rational> elements = getBagElements(n[1]);
  const auto found = elements.find(n[0]);
  const Rational count =
      (found == elements.end()) ? Rational(0) : found->second;
  return NodeManager::currentNM()->mkConst(count);
}
268
269
7
Node NormalForm::evaluateDuplicateRemoval(TNode n)
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);
  // Keeps every element of the constant bag n[0], each with multiplicity 1.
  //
  // Examples
  // --------
  //  - (duplicate_removal (emptybag String)) = (emptybag String)
  //  - (duplicate_removal (mkBag "x" 4)) = (mkBag "x" 1)
  //  - (duplicate_removal (union_disjoint (mkBag "x" 3) (mkBag "y" 5))) =
  //     (union_disjoint (mkBag "x" 1) (mkBag "y" 1))

  const Rational one(1);
  std::map<Node, Rational> newElements;
  for (const auto& entry : getBagElements(n[0]))
  {
    // source keys arrive in sorted order, so appending at end() is exact
    newElements.emplace_hint(newElements.end(), entry.first, one);
  }
  return constructConstantBagFromElements(n[0].getType(), newElements);
}
292
293
20
Node NormalForm::evaluateUnionDisjoint(TNode n)
{
  Assert(n.getKind() == UNION_DISJOINT);
  // Multiplicities of shared elements are added; all other elements are
  // copied unchanged.
  //
  // Example
  // -------
  // input: (union_disjoint A B)
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
  // output:
  //    (union_disjoint (MK_BAG "x" 7)
  //                    (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
  using Elements = std::map<Node, Rational>;
  using ConstIt = Elements::const_iterator;

  // element in both operands: sum the multiplicities
  auto addBoth = [](Elements& result, ConstIt& a, ConstIt& b) {
    result[a->first] = a->second + b->second;
  };
  // element only in A
  auto takeA = [](Elements& result, ConstIt& a, ConstIt&) {
    result[a->first] = a->second;
  };
  // element only in B
  auto takeB = [](Elements& result, ConstIt&, ConstIt& b) {
    result[b->first] = b->second;
  };
  // leftover tail of either operand is copied as-is
  auto copyRest = [](Elements& result, Elements& source, ConstIt& it) {
    for (; it != source.end(); ++it)
    {
      result[it->first] = it->second;
    }
  };

  return evaluateBinaryOperation(n, addBoth, takeA, takeB, copyRest, copyRest);
}
352
353
6
Node NormalForm::evaluateUnionMax(TNode n)
{
  Assert(n.getKind() == UNION_MAX);
  // Shared elements keep the larger multiplicity; all other elements are
  // copied unchanged.
  //
  // Example
  // -------
  // input: (union_max A B)
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
  // output:
  //    (union_disjoint (MK_BAG "x" 4)
  //                    (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
  using Elements = std::map<Node, Rational>;
  using ConstIt = Elements::const_iterator;

  // element in both operands: keep the maximum multiplicity
  auto takeMax = [](Elements& result, ConstIt& a, ConstIt& b) {
    result[a->first] = std::max(a->second, b->second);
  };
  // element only in A
  auto takeA = [](Elements& result, ConstIt& a, ConstIt&) {
    result[a->first] = a->second;
  };
  // element only in B
  auto takeB = [](Elements& result, ConstIt&, ConstIt& b) {
    result[b->first] = b->second;
  };
  // leftover tail of either operand is copied as-is
  auto copyRest = [](Elements& result, Elements& source, ConstIt& it) {
    for (; it != source.end(); ++it)
    {
      result[it->first] = it->second;
    }
  };

  return evaluateBinaryOperation(n, takeMax, takeA, takeB, copyRest, copyRest);
}
412
413
3
Node NormalForm::evaluateIntersectionMin(TNode n)
{
  Assert(n.getKind() == INTERSECTION_MIN);
  // Only elements present in both operands survive, with the smaller
  // multiplicity.
  //
  // Example
  // -------
  // input: (intersectionMin A B)
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
  // output:
  //        (MK_BAG "x" 3)
  using Elements = std::map<Node, Rational>;
  using ConstIt = Elements::const_iterator;

  // element in both operands: keep the minimum multiplicity
  auto keepMin = [](Elements& result, ConstIt& a, ConstIt& b) {
    result[a->first] = std::min(a->second, b->second);
  };
  // An element occurring in only one operand is absent from the
  // intersection, so every other callback ignores its arguments.
  auto ignore = [](auto&&...) {};

  return evaluateBinaryOperation(n, keepMin, ignore, ignore, ignore, ignore);
}
458
459
3
Node NormalForm::evaluateDifferenceSubtract(TNode n)
{
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
  // The multiplicity of each element of A is reduced by its multiplicity in
  // B; elements whose multiplicity drops to zero or below are removed.
  //
  // Example
  // -------
  // input: (difference_subtract A B)
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
  // output:
  //    (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))
  // and (difference_subtract (MK_BAG "x" 2) (MK_BAG "x" 3)) = (emptybag)

  auto equal = [](std::map<Node, Rational>& elements,
                  std::map<Node, Rational>::const_iterator& itA,
                  std::map<Node, Rational>::const_iterator& itB) {
    // subtract the multiplicities, keeping the element only if the result is
    // positive; otherwise we would build an MK_BAG with a non-positive
    // multiplicity, which is not a constant bag
    if (itA->second > itB->second)
    {
      elements[itA->first] = itA->second - itB->second;
    }
  };

  auto less = [](std::map<Node, Rational>& elements,
                 std::map<Node, Rational>::const_iterator& itA,
                 std::map<Node, Rational>::const_iterator& itB) {
    // itA->first is not in B, so we add it to the difference subtract
    elements[itA->first] = itA->second;
  };

  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
                           std::map<Node, Rational>::const_iterator& itA,
                           std::map<Node, Rational>::const_iterator& itB) {
    // itB->first is not in A, so we just skip it
  };

  auto remainderOfA = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsA,
                         std::map<Node, Rational>::const_iterator& itA) {
    // append the remainder of A
    while (itA != elementsA.end())
    {
      elements[itA->first] = itA->second;
      itA++;
    }
  };

  auto remainderOfB = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsB,
                         std::map<Node, Rational>::const_iterator& itB) {
    // elements only in B do not contribute to the difference
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}
510
511
2
Node NormalForm::evaluateDifferenceRemove(TNode n)
{
  Assert(n.getKind() == DIFFERENCE_REMOVE);
  // Every element of A that occurs in B (with any multiplicity) is removed;
  // elements of A absent from B keep their multiplicity.
  //
  // Example
  // -------
  // input: (difference_remove A B)
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
  // output:
  //    (MK_BAG "z" 2)

  auto equal = [](std::map<Node, Rational>& elements,
                  std::map<Node, Rational>::const_iterator& itA,
                  std::map<Node, Rational>::const_iterator& itB) {
    // the shared element is removed entirely, so do nothing
  };

  auto less = [](std::map<Node, Rational>& elements,
                 std::map<Node, Rational>::const_iterator& itA,
                 std::map<Node, Rational>::const_iterator& itB) {
    // itA->first is not in B, so we add it to the difference remove
    elements[itA->first] = itA->second;
  };

  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
                           std::map<Node, Rational>::const_iterator& itA,
                           std::map<Node, Rational>::const_iterator& itB) {
    // itB->first is not in A, so we just skip it
  };

  auto remainderOfA = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsA,
                         std::map<Node, Rational>::const_iterator& itA) {
    // append the remainder of A
    while (itA != elementsA.end())
    {
      elements[itA->first] = itA->second;
      itA++;
    }
  };

  auto remainderOfB = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsB,
                         std::map<Node, Rational>::const_iterator& itB) {
    // elements only in B do not contribute to the difference
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}
561
562
6
Node NormalForm::evaluateChoose(TNode n)
{
  Assert(n.getKind() == BAG_CHOOSE);
  // Examples
  // --------
  // - (choose (emptyBag String)) = "" // the empty string which is the first
  //   element returned by the type enumerator
  // - (choose (MK_BAG "x" 4)) = "x"
  // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x"
  //     deterministically return the first element

  TNode bag = n[0];
  switch (bag.getKind())
  {
    case EMPTYBAG:
    {
      // Nothing to pick from: use the first value produced by the type
      // enumerator of the element type.
      TypeEnumerator enumerator(bag.getType().getBagElementType());
      return *enumerator;
    }
    case MK_BAG:
      return bag[0];
    default:
      Assert(bag.getKind() == UNION_DISJOINT);
      // element of the leftmost MK_BAG in the chain
      return bag[0][0];
  }
}
591
592
6
Node NormalForm::evaluateCard(TNode n)
{
  Assert(n.getKind() == BAG_CARD);
  // The cardinality of a constant bag is the sum of its multiplicities.
  //
  // Examples
  // --------
  //  - (card (emptyBag String)) = 0
  //  - (card (MK_BAG "x" 4)) = 4
  //  - (card (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5

  Rational sum(0);
  // bind by reference: copying each pair would bump a Node refcount and copy
  // an arbitrary-precision Rational per element
  for (const auto& element : getBagElements(n[0]))
  {
    sum += element.second;
  }
  return NodeManager::currentNM()->mkConst(sum);
}
612
613
8
Node NormalForm::evaluateIsSingleton(TNode n)
{
  Assert(n.getKind() == BAG_IS_SINGLETON);
  // Examples
  // --------
  // - (bag.is_singleton (emptyBag String)) = false
  // - (bag.is_singleton (MK_BAG "x" 1)) = true
  // - (bag.is_singleton (MK_BAG "x" 4)) = false
  // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false

  TNode bag = n[0];
  // Only a single MK_BAG whose multiplicity is exactly one is a singleton.
  const bool isSingleton =
      bag.getKind() == MK_BAG && bag[1].getConst<Rational>().isOne();
  return NodeManager::currentNM()->mkConst(isSingleton);
}
629
630
6
Node NormalForm::evaluateFromSet(TNode n)
{
  Assert(n.getKind() == BAG_FROM_SET);
  // Every element of the constant set n[0] becomes a bag element with
  // multiplicity one.
  //
  // Examples
  // --------
  //  - (bag.from_set (emptyset String)) = (emptybag String)
  //  - (bag.from_set (singleton "x")) = (mkBag "x" 1)
  //  - (bag.from_set (union (singleton "x") (singleton "y"))) =
  //     (disjoint_union (mkBag "x" 1) (mkBag "y" 1))

  NodeManager* nm = NodeManager::currentNM();
  std::map<Node, Rational> bagElements;
  for (const Node& element :
       sets::NormalForm::getElementsFromNormalConstant(n[0]))
  {
    bagElements[element] = Rational(1);
  }
  TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
  return constructConstantBagFromElements(bagType, bagElements);
}
654
655
6
Node NormalForm::evaluateToSet(TNode n)
{
  Assert(n.getKind() == BAG_TO_SET);
  // The distinct elements of the constant bag n[0] form the result set;
  // multiplicities are dropped.
  //
  // Examples
  // --------
  //  - (bag.to_set (emptybag String)) = (emptyset String)
  //  - (bag.to_set (mkBag "x" 4)) = (singleton "x")
  //  - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
  //     (union (singleton "x") (singleton "y")))

  NodeManager* nm = NodeManager::currentNM();
  std::set<Node> setElements;
  for (const auto& entry : getBagElements(n[0]))
  {
    setElements.insert(entry.first);
  }
  TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
  return sets::NormalForm::elementsToSet(setElements, setType);
}
678
679
680
2
Node NormalForm::evaluateBagMap(TNode n)
{
  Assert(n.getKind() == BAG_MAP);
  // Applies the function n[0] to every element of the constant bag n[1],
  // keeping each element's multiplicity.
  //
  // Examples
  // --------
  // - (bag.map ((lambda ((x String)) "z")
  //            (union_disjoint (bag "a" 2) (bag "b" 3)) =
  //     (union_disjoint
  //       (bag ((lambda ((x String)) "z") "a") 2)
  //       (bag ((lambda ((x String)) "z") "b") 3)) =
  //     (bag "z" 5)
  //
  // NOTE(review): this function only builds the APPLY_UF terms; it does not
  // simplify them, so the final collapse to (bag "z" 5) presumably happens
  // elsewhere (e.g. during rewriting) — confirm.

  NodeManager* nm = NodeManager::currentNM();
  TNode f = n[0];
  std::map<Node, Rational> mappedElements;
  for (const auto& entry : NormalForm::getBagElements(n[1]))
  {
    Node image = nm->mkNode(APPLY_UF, f, entry.first);
    mappedElements[image] = entry.second;
  }
  TypeNode resultType = nm->mkBagType(f.getType().getRangeType());
  return NormalForm::constructConstantBagFromElements(resultType,
                                                      mappedElements);
}
707
708
}  // namespace bags
709
}  // namespace theory
710
22746
}  // namespace cvc5