GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/normal_form.cpp Lines: 243 261 93.1 %
Date: 2021-08-14 Branches: 556 1290 43.1 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Mudathir Mohamed, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Normal form for bag constants.
14
 */
15
#include "normal_form.h"
16
17
#include "expr/emptybag.h"
18
#include "theory/sets/normal_form.h"
19
#include "theory/type_enumerator.h"
20
#include "util/rational.h"
21
22
using namespace cvc5::kind;
23
24
namespace cvc5 {
25
namespace theory {
26
namespace bags {
27
28
117
/**
 * Decides whether n is a bag constant in normal form. Normal-form constants
 * are:
 *  - an EMPTYBAG node,
 *  - a constant MK_BAG node, or
 *  - a right-nested chain of UNION_DISJOINT nodes whose leaves are constant
 *    MK_BAG nodes with strictly increasing elements from left to right.
 */
bool NormalForm::isConstant(TNode n)
{
  switch (n.getKind())
  {
    case EMPTYBAG:
      // the empty bag is trivially in normal form
      return true;
    case MK_BAG:
      // constness of MK_BAG is decided by MkBagTypeRule::computeIsConst
      return n.isConst();
    case UNION_DISJOINT:
      // checked below
      break;
    default:
      // no other kind can form a bag constant
      return false;
  }

  auto isConstantMkBag = [](TNode b) {
    return b.getKind() == MK_BAG && b.isConst();
  };

  if (!isConstantMkBag(n[0]))
  {
    return false;
  }

  // walk the chain, remembering the last element seen so that the ordering
  // of the elements can be checked
  Node previous = n[0][0];
  Node rest = n[1];
  for (; rest.getKind() == UNION_DISJOINT; rest = rest[1])
  {
    Node head = rest[0];
    if (!isConstantMkBag(head) || previous >= head[0])
    {
      return false;
    }
    previous = head[0];
  }

  // the chain must end with a constant MK_BAG whose element is larger than
  // every element before it
  return isConstantMkBag(rest) && !(previous >= rest[0]);
}
83
84
1504
/**
 * Returns true iff every child of n is a constant. A node without children
 * satisfies this vacuously.
 */
bool NormalForm::areChildrenConstants(TNode n)
{
  for (TNode child : n)
  {
    if (!child.isConst())
    {
      return false;
    }
  }
  return true;
}
88
89
146
/**
 * Evaluates a bag term whose children are all constants by dispatching on
 * its kind. A term that is already constant is returned unchanged. Any kind
 * not handled here is reported through Unhandled().
 */
Node NormalForm::evaluate(TNode n)
{
  Assert(areChildrenConstants(n));
  if (n.isConst())
  {
    // nothing to normalize
    return n;
  }

  switch (n.getKind())
  {
    // construction
    case MK_BAG:
      return evaluateMakeBag(n);

    // binary bag operators
    case UNION_DISJOINT:
      return evaluateUnionDisjoint(n);
    case UNION_MAX:
      return evaluateUnionMax(n);
    case INTERSECTION_MIN:
      return evaluateIntersectionMin(n);
    case DIFFERENCE_SUBTRACT:
      return evaluateDifferenceSubtract(n);
    case DIFFERENCE_REMOVE:
      return evaluateDifferenceRemove(n);

    // unary bag operators and queries
    case BAG_COUNT:
      return evaluateBagCount(n);
    case DUPLICATE_REMOVAL:
      return evaluateDuplicateRemoval(n);
    case BAG_CHOOSE:
      return evaluateChoose(n);
    case BAG_CARD:
      return evaluateCard(n);
    case BAG_IS_SINGLETON:
      return evaluateIsSingleton(n);

    // conversions between sets and bags
    case BAG_FROM_SET:
      return evaluateFromSet(n);
    case BAG_TO_SET:
      return evaluateToSet(n);

    default:
      break;
  }
  Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
              << std::endl;
}
117
118
template <typename T1, typename T2, typename T3, typename T4, typename T5>
119
32
Node NormalForm::evaluateBinaryOperation(const TNode& n,
120
                                         T1&& equal,
121
                                         T2&& less,
122
                                         T3&& greaterOrEqual,
123
                                         T4&& remainderOfA,
124
                                         T5&& remainderOfB)
125
{
126
64
  std::map<Node, Rational> elementsA = getBagElements(n[0]);
127
64
  std::map<Node, Rational> elementsB = getBagElements(n[1]);
128
64
  std::map<Node, Rational> elements;
129
130
32
  std::map<Node, Rational>::const_iterator itA = elementsA.begin();
131
32
  std::map<Node, Rational>::const_iterator itB = elementsB.begin();
132
133
96
  Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
134
64
                         << n.getKind() << "] " << std::endl
135
32
                         << "elements A: " << elementsA << std::endl
136
32
                         << "elements B: " << elementsB << std::endl;
137
138
112
  while (itA != elementsA.end() && itB != elementsB.end())
139
  {
140
40
    if (itA->first == itB->first)
141
    {
142
20
      equal(elements, itA, itB);
143
20
      itA++;
144
20
      itB++;
145
    }
146
20
    else if (itA->first < itB->first)
147
    {
148
6
      less(elements, itA, itB);
149
6
      itA++;
150
    }
151
    else
152
    {
153
14
      greaterOrEqual(elements, itA, itB);
154
14
      itB++;
155
    }
156
  }
157
158
  // handle the remaining elements from A
159
32
  remainderOfA(elements, elementsA, itA);
160
  // handle the remaining elements from B
161
32
  remainderOfA(elements, elementsB, itB);
162
163
32
  Trace("bags-evaluate") << "elements: " << elements << std::endl;
164
32
  Node bag = constructConstantBagFromElements(n.getType(), elements);
165
32
  Trace("bags-evaluate") << "bag: " << bag << std::endl;
166
64
  return bag;
167
}
168
169
141
/**
 * Collects the elements of a constant bag together with their
 * multiplicities.
 *
 * @param n a bag constant in normal form (asserted)
 * @return a map from each element of n to its multiplicity; empty for
 *         EMPTYBAG
 */
std::map<Node, Rational> NormalForm::getBagElements(TNode n)
{
  Assert(n.isConst()) << "node " << n << " is not in a normal form"
                      << std::endl;
  std::map<Node, Rational> multiplicities;
  if (n.getKind() == EMPTYBAG)
  {
    // the empty bag has no elements
    return multiplicities;
  }

  // A non-empty constant is a right-nested chain
  //   (union_disjoint b_1 (union_disjoint b_2 ... b_k))
  // of MK_BAG nodes b_i. Record the (element, count) pair of each b_i.
  TNode node = n;
  for (; node.getKind() == kind::UNION_DISJOINT; node = node[1])
  {
    TNode leaf = node[0];
    Assert(leaf.getKind() == kind::MK_BAG);
    Node element = leaf[0];
    multiplicities[element] = leaf[1].getConst<Rational>();
  }

  // the last link of the chain is itself a MK_BAG
  Assert(node.getKind() == kind::MK_BAG);
  Node element = node[0];
  multiplicities[element] = node[1].getConst<Rational>();
  return multiplicities;
}
192
193
45
/**
 * Builds a bag constant in normal form from a map of elements to
 * multiplicities.
 *
 * @param t the bag type of the result (asserted to be a bag type)
 * @param elements the elements and their multiplicities
 * @return the empty bag of type t if elements is empty. Otherwise a
 *         right-nested UNION_DISJOINT chain of MK_BAG nodes, ordered by
 *         increasing element (the order of the map).
 */
Node NormalForm::constructConstantBagFromElements(
    TypeNode t, const std::map<Node, Rational>& elements)
{
  Assert(t.isBag());
  NodeManager* nm = NodeManager::currentNM();
  if (elements.empty())
  {
    return nm->mkConst(EmptyBag(t));
  }

  TypeNode elementType = t.getBagElementType();
  // Start from the largest element and prepend, so that the smallest
  // element ends up as the leftmost child of the chain.
  Node result;
  for (auto it = elements.rbegin(); it != elements.rend(); ++it)
  {
    Node single =
        nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
    result = result.isNull() ? single
                             : nm->mkNode(UNION_DISJOINT, single, result);
  }
  return result;
}
214
215
30
/**
 * Builds a bag term from a map of elements to multiplicity terms. The
 * multiplicities are arbitrary nodes, so the result need not be a constant.
 *
 * @param t the bag type of the result (asserted to be a bag type)
 * @param elements the elements and their multiplicity terms
 * @return the empty bag of type t if elements is empty. Otherwise a
 *         right-nested UNION_DISJOINT chain of MK_BAG nodes in map order.
 */
Node NormalForm::constructBagFromElements(TypeNode t,
                                          const std::map<Node, Node>& elements)
{
  Assert(t.isBag());
  NodeManager* nm = NodeManager::currentNM();
  if (elements.empty())
  {
    return nm->mkConst(EmptyBag(t));
  }

  TypeNode elementType = t.getBagElementType();
  // iterate backwards and prepend so the first map entry is leftmost
  Node result;
  for (auto it = elements.rbegin(); it != elements.rend(); ++it)
  {
    Node single = nm->mkBag(elementType, it->first, it->second);
    result = result.isNull() ? single
                             : nm->mkNode(UNION_DISJOINT, single, result);
  }
  return result;
}
234
235
11
/**
 * Evaluates a non-constant MK_BAG term with constant children. Constant
 * MK_BAG terms are returned by evaluate() before reaching this function, so
 * the only case left is a multiplicity that is zero or negative (asserted).
 * Such a term denotes the empty bag.
 */
Node NormalForm::evaluateMakeBag(TNode n)
{
  Assert(n.getKind() == MK_BAG && !n.isConst()
         && n[1].getConst<Rational>().sgn() < 1);
  return NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
}
244
245
58
/**
 * Evaluates (bag.count x B) for a constant bag B: the multiplicity of x in
 * B, or 0 when x does not occur in B.
 *
 * Examples
 * --------
 * - (bag.count "x" (emptybag String)) = 0
 * - (bag.count "x" (mkBag "y" 5)) = 0
 * - (bag.count "x" (mkBag "x" 4)) = 4
 * - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
 * - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0
 */
Node NormalForm::evaluateBagCount(TNode n)
{
  Assert(n.getKind() == BAG_COUNT);
  std::map<Node, Rational> elements = getBagElements(n[1]);
  auto found = elements.find(n[0]);
  Rational count = (found == elements.end()) ? Rational(0) : found->second;
  return NodeManager::currentNM()->mkConst(count);
}
267
268
7
/**
 * Evaluates (duplicate_removal B) for a constant bag B. The result keeps
 * every element of B, each with multiplicity 1.
 *
 * Examples
 * --------
 *  - (duplicate_removal (emptybag String)) = (emptybag String)
 *  - (duplicate_removal (mkBag "x" 4)) = (mkBag "x" 1)
 *  - (duplicate_removal (union_disjoint (mkBag "x" 3) (mkBag "y" 5))) =
 *     (union_disjoint (mkBag "x" 1) (mkBag "y" 1))
 */
Node NormalForm::evaluateDuplicateRemoval(TNode n)
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);

  // getBagElements returns a fresh map, so it can be updated in place
  // without an extra copy
  std::map<Node, Rational> elements = getBagElements(n[0]);
  const Rational one(1);
  for (std::pair<const Node, Rational>& element : elements)
  {
    element.second = one;
  }
  return constructConstantBagFromElements(n[0].getType(), elements);
}
291
292
18
/**
 * Evaluates (union_disjoint A B) for constant bags A and B. Each element
 * gets the sum of its multiplicities in A and B; an element missing from one
 * bag counts as 0 there.
 *
 * Example
 * -------
 * input: (union_disjoint A B)
 *    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))
 *          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))
 * output:
 *    (union_disjoint (MK_BAG "x" 7)
 *                    (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
 */
Node NormalForm::evaluateUnionDisjoint(TNode n)
{
  Assert(n.getKind() == UNION_DISJOINT);
  using Elements = std::map<Node, Rational>;
  using Iterator = Elements::const_iterator;

  // element in both bags: add up the multiplicities
  auto addBoth = [](Elements& result, Iterator& itA, Iterator& itB) {
    result[itA->first] = itA->second + itB->second;
  };
  // element only in A: copy it
  auto takeFromA = [](Elements& result, Iterator& itA, Iterator&) {
    result[itA->first] = itA->second;
  };
  // element only in B: copy it
  auto takeFromB = [](Elements& result, Iterator&, Iterator& itB) {
    result[itB->first] = itB->second;
  };
  // leftovers of either bag are copied unchanged
  auto copyRest = [](Elements& result, Elements& source, Iterator& it) {
    for (; it != source.end(); ++it)
    {
      result[it->first] = it->second;
    }
  };

  return evaluateBinaryOperation(
      n, addBoth, takeFromA, takeFromB, copyRest, copyRest);
}
351
352
6
/**
 * Evaluates (union_max A B) for constant bags A and B. Each element gets the
 * larger of its multiplicities in A and B; an element missing from one bag
 * keeps its multiplicity from the other.
 *
 * Example
 * -------
 * input: (union_max A B)
 *    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))
 *          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))
 * output:
 *    (union_disjoint (MK_BAG "x" 4)
 *                    (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
 */
Node NormalForm::evaluateUnionMax(TNode n)
{
  Assert(n.getKind() == UNION_MAX);
  using Elements = std::map<Node, Rational>;
  using Iterator = Elements::const_iterator;

  // element in both bags: keep the larger multiplicity
  auto keepMaximum = [](Elements& result, Iterator& itA, Iterator& itB) {
    result[itA->first] = std::max(itA->second, itB->second);
  };
  // element only in A: copy it
  auto takeFromA = [](Elements& result, Iterator& itA, Iterator&) {
    result[itA->first] = itA->second;
  };
  // element only in B: copy it
  auto takeFromB = [](Elements& result, Iterator&, Iterator& itB) {
    result[itB->first] = itB->second;
  };
  // leftovers of either bag are copied unchanged
  auto copyRest = [](Elements& result, Elements& source, Iterator& it) {
    for (; it != source.end(); ++it)
    {
      result[it->first] = it->second;
    }
  };

  return evaluateBinaryOperation(
      n, keepMaximum, takeFromA, takeFromB, copyRest, copyRest);
}
411
412
3
/**
 * Evaluates (intersection_min A B) for constant bags A and B. Only elements
 * occurring in both bags survive, each with the smaller of its two
 * multiplicities.
 *
 * Example
 * -------
 * input: (intersection_min A B)
 *    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))
 *          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))
 * output:
 *        (MK_BAG "x" 3)
 */
Node NormalForm::evaluateIntersectionMin(TNode n)
{
  Assert(n.getKind() == INTERSECTION_MIN);
  using Elements = std::map<Node, Rational>;
  using Iterator = Elements::const_iterator;

  // element in both bags: keep the smaller multiplicity
  auto keepMinimum = [](Elements& result, Iterator& itA, Iterator& itB) {
    result[itA->first] = std::min(itA->second, itB->second);
  };
  // Elements found in only one bag, and the leftovers of either bag, are
  // not part of the intersection, so they are ignored.
  auto ignore = [](auto&, auto&, auto&) {};

  return evaluateBinaryOperation(n, keepMinimum, ignore, ignore, ignore, ignore);
}
457
458
3
/**
 * Evaluates (difference_subtract A B) for constant bags A and B. The
 * multiplicity of an element e in the result is max(0, m_A(e) - m_B(e)).
 * Elements whose multiplicity would be 0 or less are dropped: a MK_BAG with a
 * non-positive multiplicity is not a bag constant.
 *
 * Examples
 * --------
 * input: (difference_subtract A B)
 *    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))
 *          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))
 * output:
 *    (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))
 *
 * input: (difference_subtract (MK_BAG "x" 3) (MK_BAG "x" 4))
 * output: the empty bag, since "x" has 3 - 4 < 1 copies left
 */
Node NormalForm::evaluateDifferenceSubtract(TNode n)
{
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);

  auto equal = [](std::map<Node, Rational>& elements,
                  std::map<Node, Rational>::const_iterator& itA,
                  std::map<Node, Rational>::const_iterator& itB) {
    // Subtract the multiplicities, but keep the element only when A has
    // strictly more copies than B. Otherwise it disappears from the result.
    if (itA->second > itB->second)
    {
      elements[itA->first] = itA->second - itB->second;
    }
  };

  auto less = [](std::map<Node, Rational>& elements,
                 std::map<Node, Rational>::const_iterator& itA,
                 std::map<Node, Rational>::const_iterator& itB) {
    // itA->first is not in B, so we add it to the difference subtract
    elements[itA->first] = itA->second;
  };

  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
                           std::map<Node, Rational>::const_iterator& itA,
                           std::map<Node, Rational>::const_iterator& itB) {
    // itB->first is not in A, so we just skip it
  };

  auto remainderOfA = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsA,
                         std::map<Node, Rational>::const_iterator& itA) {
    // append the remainder of A
    while (itA != elementsA.end())
    {
      elements[itA->first] = itA->second;
      ++itA;
    }
  };

  auto remainderOfB = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsB,
                         std::map<Node, Rational>::const_iterator& itB) {
    // elements only in B never appear in the difference: do nothing
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}
509
510
2
/**
 * Evaluates (difference_remove A B) for constant bags A and B. The result
 * keeps exactly the elements of A that do not occur in B, with their
 * multiplicities from A. Any element of B, whatever its multiplicity, removes
 * all copies of that element from A.
 *
 * Example
 * -------
 * input: (difference_remove A B)
 *    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))
 *          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))
 * output:
 *    (MK_BAG "z" 2)
 */
Node NormalForm::evaluateDifferenceRemove(TNode n)
{
  Assert(n.getKind() == DIFFERENCE_REMOVE);

  auto equal = [](std::map<Node, Rational>& elements,
                  std::map<Node, Rational>::const_iterator& itA,
                  std::map<Node, Rational>::const_iterator& itB) {
    // skip the shared element by doing nothing
  };

  auto less = [](std::map<Node, Rational>& elements,
                 std::map<Node, Rational>::const_iterator& itA,
                 std::map<Node, Rational>::const_iterator& itB) {
    // itA->first is not in B, so we add it to the difference remove
    elements[itA->first] = itA->second;
  };

  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
                           std::map<Node, Rational>::const_iterator& itA,
                           std::map<Node, Rational>::const_iterator& itB) {
    // itB->first is not in A, so we just skip it
  };

  auto remainderOfA = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsA,
                         std::map<Node, Rational>::const_iterator& itA) {
    // append the remainder of A
    while (itA != elementsA.end())
    {
      elements[itA->first] = itA->second;
      ++itA;
    }
  };

  auto remainderOfB = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsB,
                         std::map<Node, Rational>::const_iterator& itB) {
    // elements only in B never appear in the difference: do nothing
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}
560
561
6
/**
 * Evaluates (choose B) for a constant bag B, deterministically.
 *
 * Examples
 * --------
 * - (choose (emptyBag String)) = "" : the first value produced by the type
 *   enumerator of the element type
 * - (choose (MK_BAG "x" 4)) = "x"
 * - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x" : the
 *   element of the leftmost MK_BAG
 */
Node NormalForm::evaluateChoose(TNode n)
{
  Assert(n.getKind() == BAG_CHOOSE);
  TNode bag = n[0];
  switch (bag.getKind())
  {
    case EMPTYBAG:
    {
      TypeNode elementType = bag.getType().getBagElementType();
      TypeEnumerator typeEnumerator(elementType);
      return *typeEnumerator;
    }
    case MK_BAG:
      return bag[0];
    default:
      Assert(bag.getKind() == UNION_DISJOINT);
      return bag[0][0];
  }
}
590
591
6
/**
 * Evaluates (bag.card B) for a constant bag B: the sum of the multiplicities
 * of all elements of B.
 *
 * Examples
 * --------
 *  - (bag.card (emptyBag String)) = 0
 *  - (bag.card (MK_BAG "x" 4)) = 4
 *  - (bag.card (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5
 */
Node NormalForm::evaluateCard(TNode n)
{
  Assert(n.getKind() == BAG_CARD);

  std::map<Node, Rational> elements = getBagElements(n[0]);
  Rational sum(0);
  // bind to the map's value_type by reference to avoid copying every entry
  for (const std::pair<const Node, Rational>& element : elements)
  {
    sum += element.second;
  }
  return NodeManager::currentNM()->mkConst(sum);
}
611
612
8
/**
 * Evaluates (bag.is_singleton B) for a constant bag B. The result is true
 * exactly when B is a single MK_BAG whose multiplicity is one.
 *
 * Examples
 * --------
 * - (bag.is_singleton (emptyBag String)) = false
 * - (bag.is_singleton (MK_BAG "x" 1)) = true
 * - (bag.is_singleton (MK_BAG "x" 4)) = false
 * - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false
 */
Node NormalForm::evaluateIsSingleton(TNode n)
{
  Assert(n.getKind() == BAG_IS_SINGLETON);
  TNode bag = n[0];
  bool singleton =
      bag.getKind() == MK_BAG && bag[1].getConst<Rational>().isOne();
  return NodeManager::currentNM()->mkConst(singleton);
}
628
629
6
/**
 * Evaluates (bag.from_set S) for a constant set S. Every element of S
 * becomes an element of the bag with multiplicity 1.
 *
 * Examples
 * --------
 *  - (bag.from_set (emptyset String)) = (emptybag String)
 *  - (bag.from_set (singleton "x")) = (mkBag "x" 1)
 *  - (bag.from_set (union (singleton "x") (singleton "y"))) =
 *     (union_disjoint (mkBag "x" 1) (mkBag "y" 1))
 */
Node NormalForm::evaluateFromSet(TNode n)
{
  Assert(n.getKind() == BAG_FROM_SET);

  NodeManager* nm = NodeManager::currentNM();
  std::map<Node, Rational> bagElements;
  for (const Node& element :
       sets::NormalForm::getElementsFromNormalConstant(n[0]))
  {
    bagElements[element] = Rational(1);
  }
  TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
  return constructConstantBagFromElements(bagType, bagElements);
}
653
654
6
/**
 * Evaluates (bag.to_set B) for a constant bag B. The result is the set of
 * distinct elements of B; multiplicities are discarded.
 *
 * Examples
 * --------
 *  - (bag.to_set (emptybag String)) = (emptyset String)
 *  - (bag.to_set (mkBag "x" 4)) = (singleton "x")
 *  - (bag.to_set (union_disjoint (mkBag "x" 3) (mkBag "y" 5))) =
 *     (union (singleton "x") (singleton "y"))
 */
Node NormalForm::evaluateToSet(TNode n)
{
  Assert(n.getKind() == BAG_TO_SET);

  NodeManager* nm = NodeManager::currentNM();
  std::map<Node, Rational> bagElements = getBagElements(n[0]);
  // a std::set orders its contents itself, so insertion order is irrelevant
  std::set<Node> setElements;
  for (const std::pair<const Node, Rational>& entry : bagElements)
  {
    setElements.insert(entry.first);
  }
  TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
  return sets::NormalForm::elementsToSet(setElements, setType);
}
677
678
}  // namespace bags
679
}  // namespace theory
680
29340
}  // namespace cvc5