GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/normal_form.cpp Lines: 243 261 93.1 %
Date: 2021-05-22 Branches: 556 1292 43.0 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Mudathir Mohamed, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Normal form for bag constants.
14
 */
15
#include "normal_form.h"
16
17
#include "theory/sets/normal_form.h"
18
#include "theory/type_enumerator.h"
19
20
using namespace cvc5::kind;
21
22
namespace cvc5 {
23
namespace theory {
24
namespace bags {
25
26
117
bool NormalForm::isConstant(TNode n)
27
{
28
117
  if (n.getKind() == EMPTYBAG)
29
  {
30
    // empty bags are already normalized
31
    return true;
32
  }
33
117
  if (n.getKind() == MK_BAG)
34
  {
35
    // see the implementation in MkBagTypeRule::computeIsConst
36
    return n.isConst();
37
  }
38
117
  if (n.getKind() == UNION_DISJOINT)
39
  {
40
117
    if (!(n[0].getKind() == kind::MK_BAG && n[0].isConst()))
41
    {
42
      // the first child is not a constant
43
59
      return false;
44
    }
45
    // store the previous element to check the ordering of elements
46
116
    Node previousElement = n[0][0];
47
116
    Node current = n[1];
48
74
    while (current.getKind() == UNION_DISJOINT)
49
    {
50
8
      if (!(current[0].getKind() == kind::MK_BAG && current[0].isConst()))
51
      {
52
        // the current element is not a constant
53
        return false;
54
      }
55
8
      if (previousElement >= current[0][0])
56
      {
57
        // the ordering is violated
58
        return false;
59
      }
60
8
      previousElement = current[0][0];
61
8
      current = current[1];
62
    }
63
    // check last element
64
58
    if (!(current.getKind() == kind::MK_BAG && current.isConst()))
65
    {
66
      // the last element is not a constant
67
      return false;
68
    }
69
58
    if (previousElement >= current[0])
70
    {
71
      // the ordering is violated
72
8
      return false;
73
    }
74
50
    return true;
75
  }
76
77
  // only nodes with kinds EMPTY_BAG, MK_BAG, and UNION_DISJOINT can be
78
  // constants
79
  return false;
80
}
81
82
1509
bool NormalForm::areChildrenConstants(TNode n)
83
{
84
3503
  return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
85
}
86
87
146
Node NormalForm::evaluate(TNode n)
88
{
89
146
  Assert(areChildrenConstants(n));
90
146
  if (n.isConst())
91
  {
92
    // a constant node is already in a normal form
93
6
    return n;
94
  }
95
140
  switch (n.getKind())
96
  {
97
11
    case MK_BAG: return evaluateMakeBag(n);
98
58
    case BAG_COUNT: return evaluateBagCount(n);
99
7
    case DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
100
18
    case UNION_DISJOINT: return evaluateUnionDisjoint(n);
101
6
    case UNION_MAX: return evaluateUnionMax(n);
102
3
    case INTERSECTION_MIN: return evaluateIntersectionMin(n);
103
3
    case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
104
2
    case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
105
6
    case BAG_CHOOSE: return evaluateChoose(n);
106
6
    case BAG_CARD: return evaluateCard(n);
107
8
    case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
108
6
    case BAG_FROM_SET: return evaluateFromSet(n);
109
6
    case BAG_TO_SET: return evaluateToSet(n);
110
    default: break;
111
  }
112
  Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
113
              << std::endl;
114
}
115
116
template <typename T1, typename T2, typename T3, typename T4, typename T5>
117
32
Node NormalForm::evaluateBinaryOperation(const TNode& n,
118
                                         T1&& equal,
119
                                         T2&& less,
120
                                         T3&& greaterOrEqual,
121
                                         T4&& remainderOfA,
122
                                         T5&& remainderOfB)
123
{
124
64
  std::map<Node, Rational> elementsA = getBagElements(n[0]);
125
64
  std::map<Node, Rational> elementsB = getBagElements(n[1]);
126
64
  std::map<Node, Rational> elements;
127
128
32
  std::map<Node, Rational>::const_iterator itA = elementsA.begin();
129
32
  std::map<Node, Rational>::const_iterator itB = elementsB.begin();
130
131
96
  Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
132
64
                         << n.getKind() << "] " << std::endl
133
32
                         << "elements A: " << elementsA << std::endl
134
32
                         << "elements B: " << elementsB << std::endl;
135
136
112
  while (itA != elementsA.end() && itB != elementsB.end())
137
  {
138
40
    if (itA->first == itB->first)
139
    {
140
20
      equal(elements, itA, itB);
141
20
      itA++;
142
20
      itB++;
143
    }
144
20
    else if (itA->first < itB->first)
145
    {
146
6
      less(elements, itA, itB);
147
6
      itA++;
148
    }
149
    else
150
    {
151
14
      greaterOrEqual(elements, itA, itB);
152
14
      itB++;
153
    }
154
  }
155
156
  // handle the remaining elements from A
157
32
  remainderOfA(elements, elementsA, itA);
158
  // handle the remaining elements from B
159
32
  remainderOfA(elements, elementsB, itB);
160
161
32
  Trace("bags-evaluate") << "elements: " << elements << std::endl;
162
32
  Node bag = constructConstantBagFromElements(n.getType(), elements);
163
32
  Trace("bags-evaluate") << "bag: " << bag << std::endl;
164
64
  return bag;
165
}
166
167
141
std::map<Node, Rational> NormalForm::getBagElements(TNode n)
168
{
169
141
  Assert(n.isConst()) << "node " << n << " is not in a normal form"
170
                      << std::endl;
171
141
  std::map<Node, Rational> elements;
172
141
  if (n.getKind() == EMPTYBAG)
173
  {
174
26
    return elements;
175
  }
176
231
  while (n.getKind() == kind::UNION_DISJOINT)
177
  {
178
58
    Assert(n[0].getKind() == kind::MK_BAG);
179
116
    Node element = n[0][0];
180
116
    Rational count = n[0][1].getConst<Rational>();
181
58
    elements[element] = count;
182
58
    n = n[1];
183
  }
184
115
  Assert(n.getKind() == kind::MK_BAG);
185
230
  Node lastElement = n[0];
186
230
  Rational lastCount = n[1].getConst<Rational>();
187
115
  elements[lastElement] = lastCount;
188
115
  return elements;
189
}
190
191
45
Node NormalForm::constructConstantBagFromElements(
192
    TypeNode t, const std::map<Node, Rational>& elements)
193
{
194
45
  Assert(t.isBag());
195
45
  NodeManager* nm = NodeManager::currentNM();
196
45
  if (elements.empty())
197
  {
198
6
    return nm->mkConst(EmptyBag(t));
199
  }
200
78
  TypeNode elementType = t.getBagElementType();
201
39
  std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
202
  Node bag =
203
78
      nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
204
91
  while (++it != elements.rend())
205
  {
206
    Node n =
207
52
        nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
208
26
    bag = nm->mkNode(UNION_DISJOINT, n, bag);
209
  }
210
39
  return bag;
211
}
212
213
30
Node NormalForm::constructBagFromElements(TypeNode t,
214
                                          const std::map<Node, Node>& elements)
215
{
216
30
  Assert(t.isBag());
217
30
  NodeManager* nm = NodeManager::currentNM();
218
30
  if (elements.empty())
219
  {
220
2
    return nm->mkConst(EmptyBag(t));
221
  }
222
56
  TypeNode elementType = t.getBagElementType();
223
28
  std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
224
56
  Node bag = nm->mkBag(elementType, it->first, it->second);
225
56
  while (++it != elements.rend())
226
  {
227
28
    Node n = nm->mkBag(elementType, it->first, it->second);
228
14
    bag = nm->mkNode(UNION_DISJOINT, n, bag);
229
  }
230
28
  return bag;
231
}
232
233
11
Node NormalForm::evaluateMakeBag(TNode n)
234
{
235
  // the case where n is const should be handled earlier.
236
  // here we handle the case where the multiplicity is zero or negative
237
11
  Assert(n.getKind() == MK_BAG && !n.isConst()
238
         && n[1].getConst<Rational>().sgn() < 1);
239
11
  Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
240
11
  return emptybag;
241
}
242
243
58
Node NormalForm::evaluateBagCount(TNode n)
244
{
245
58
  Assert(n.getKind() == BAG_COUNT);
246
  // Examples
247
  // --------
248
  // - (bag.count "x" (emptybag String)) = 0
249
  // - (bag.count "x" (mkBag "y" 5)) = 0
250
  // - (bag.count "x" (mkBag "x" 4)) = 4
251
  // - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
252
  // - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0
253
254
116
  std::map<Node, Rational> elements = getBagElements(n[1]);
255
58
  std::map<Node, Rational>::iterator it = elements.find(n[0]);
256
257
58
  NodeManager* nm = NodeManager::currentNM();
258
58
  if (it != elements.end())
259
  {
260
74
    Node count = nm->mkConst(it->second);
261
37
    return count;
262
  }
263
21
  return nm->mkConst(Rational(0));
264
}
265
266
7
Node NormalForm::evaluateDuplicateRemoval(TNode n)
267
{
268
7
  Assert(n.getKind() == DUPLICATE_REMOVAL);
269
270
  // Examples
271
  // --------
272
  //  - (duplicate_removal (emptybag String)) = (emptybag String)
273
  //  - (duplicate_removal (mkBag "x" 4)) = (emptybag "x" 1)
274
  //  - (duplicate_removal (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
275
  //     (disjoint_union (mkBag "x" 1) (mkBag "y" 1)
276
277
14
  std::map<Node, Rational> oldElements = getBagElements(n[0]);
278
  // copy elements from the old bag
279
14
  std::map<Node, Rational> newElements(oldElements);
280
14
  Rational one = Rational(1);
281
7
  std::map<Node, Rational>::iterator it;
282
14
  for (it = newElements.begin(); it != newElements.end(); it++)
283
  {
284
7
    it->second = one;
285
  }
286
7
  Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
287
14
  return bag;
288
}
289
290
18
Node NormalForm::evaluateUnionDisjoint(TNode n)
291
{
292
18
  Assert(n.getKind() == UNION_DISJOINT);
293
  // Example
294
  // -------
295
  // input: (union_disjoint A B)
296
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
297
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
298
  // output:
299
  //    (union_disjoint A B)
300
  //        where A = (MK_BAG "x" 7)
301
  //              B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
302
303
  auto equal = [](std::map<Node, Rational>& elements,
304
                  std::map<Node, Rational>::const_iterator& itA,
305
6
                  std::map<Node, Rational>::const_iterator& itB) {
306
    // compute the sum of the multiplicities
307
6
    elements[itA->first] = itA->second + itB->second;
308
6
  };
309
310
  auto less = [](std::map<Node, Rational>& elements,
311
                 std::map<Node, Rational>::const_iterator& itA,
312
4
                 std::map<Node, Rational>::const_iterator& itB) {
313
    // add the element to the result
314
4
    elements[itA->first] = itA->second;
315
4
  };
316
317
  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
318
                           std::map<Node, Rational>::const_iterator& itA,
319
6
                           std::map<Node, Rational>::const_iterator& itB) {
320
    // add the element to the result
321
6
    elements[itB->first] = itB->second;
322
6
  };
323
324
  auto remainderOfA = [](std::map<Node, Rational>& elements,
325
                         std::map<Node, Rational>& elementsA,
326
48
                         std::map<Node, Rational>::const_iterator& itA) {
327
    // append the remainder of A
328
60
    while (itA != elementsA.end())
329
    {
330
12
      elements[itA->first] = itA->second;
331
12
      itA++;
332
    }
333
36
  };
334
335
  auto remainderOfB = [](std::map<Node, Rational>& elements,
336
                         std::map<Node, Rational>& elementsB,
337
                         std::map<Node, Rational>::const_iterator& itB) {
338
    // append the remainder of B
339
    while (itB != elementsB.end())
340
    {
341
      elements[itB->first] = itB->second;
342
      itB++;
343
    }
344
  };
345
346
  return evaluateBinaryOperation(
347
18
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
348
}
349
350
6
Node NormalForm::evaluateUnionMax(TNode n)
351
{
352
6
  Assert(n.getKind() == UNION_MAX);
353
  // Example
354
  // -------
355
  // input: (union_max A B)
356
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
357
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
358
  // output:
359
  //    (union_disjoint A B)
360
  //        where A = (MK_BAG "x" 4)
361
  //              B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
362
363
  auto equal = [](std::map<Node, Rational>& elements,
364
                  std::map<Node, Rational>::const_iterator& itA,
365
4
                  std::map<Node, Rational>::const_iterator& itB) {
366
    // compute the maximum multiplicity
367
4
    elements[itA->first] = std::max(itA->second, itB->second);
368
4
  };
369
370
  auto less = [](std::map<Node, Rational>& elements,
371
                 std::map<Node, Rational>::const_iterator& itA,
372
2
                 std::map<Node, Rational>::const_iterator& itB) {
373
    // add to the result
374
2
    elements[itA->first] = itA->second;
375
2
  };
376
377
  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
378
                           std::map<Node, Rational>::const_iterator& itA,
379
2
                           std::map<Node, Rational>::const_iterator& itB) {
380
    // add to the result
381
2
    elements[itB->first] = itB->second;
382
2
  };
383
384
  auto remainderOfA = [](std::map<Node, Rational>& elements,
385
                         std::map<Node, Rational>& elementsA,
386
16
                         std::map<Node, Rational>::const_iterator& itA) {
387
    // append the remainder of A
388
20
    while (itA != elementsA.end())
389
    {
390
4
      elements[itA->first] = itA->second;
391
4
      itA++;
392
    }
393
12
  };
394
395
  auto remainderOfB = [](std::map<Node, Rational>& elements,
396
                         std::map<Node, Rational>& elementsB,
397
                         std::map<Node, Rational>::const_iterator& itB) {
398
    // append the remainder of B
399
    while (itB != elementsB.end())
400
    {
401
      elements[itB->first] = itB->second;
402
      itB++;
403
    }
404
  };
405
406
  return evaluateBinaryOperation(
407
6
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
408
}
409
410
3
Node NormalForm::evaluateIntersectionMin(TNode n)
411
{
412
3
  Assert(n.getKind() == INTERSECTION_MIN);
413
  // Example
414
  // -------
415
  // input: (intersectionMin A B)
416
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
417
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
418
  // output:
419
  //        (MK_BAG "x" 3)
420
421
  auto equal = [](std::map<Node, Rational>& elements,
422
                  std::map<Node, Rational>::const_iterator& itA,
423
4
                  std::map<Node, Rational>::const_iterator& itB) {
424
    // compute the minimum multiplicity
425
4
    elements[itA->first] = std::min(itA->second, itB->second);
426
4
  };
427
428
  auto less = [](std::map<Node, Rational>& elements,
429
                 std::map<Node, Rational>::const_iterator& itA,
430
                 std::map<Node, Rational>::const_iterator& itB) {
431
    // do nothing
432
  };
433
434
  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
435
                           std::map<Node, Rational>::const_iterator& itA,
436
2
                           std::map<Node, Rational>::const_iterator& itB) {
437
    // do nothing
438
2
  };
439
440
  auto remainderOfA = [](std::map<Node, Rational>& elements,
441
                         std::map<Node, Rational>& elementsA,
442
6
                         std::map<Node, Rational>::const_iterator& itA) {
443
    // do nothing
444
6
  };
445
446
  auto remainderOfB = [](std::map<Node, Rational>& elements,
447
                         std::map<Node, Rational>& elementsB,
448
                         std::map<Node, Rational>::const_iterator& itB) {
449
    // do nothing
450
  };
451
452
  return evaluateBinaryOperation(
453
3
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
454
}
455
456
3
Node NormalForm::evaluateDifferenceSubtract(TNode n)
457
{
458
3
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
459
  // Example
460
  // -------
461
  // input: (difference_subtract A B)
462
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
463
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
464
  // output:
465
  //    (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))
466
467
  auto equal = [](std::map<Node, Rational>& elements,
468
                  std::map<Node, Rational>::const_iterator& itA,
469
4
                  std::map<Node, Rational>::const_iterator& itB) {
470
    // subtract the multiplicities
471
4
    elements[itA->first] = itA->second - itB->second;
472
4
  };
473
474
  auto less = [](std::map<Node, Rational>& elements,
475
                 std::map<Node, Rational>::const_iterator& itA,
476
                 std::map<Node, Rational>::const_iterator& itB) {
477
    // itA->first is not in B, so we add it to the difference subtract
478
    elements[itA->first] = itA->second;
479
  };
480
481
  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
482
                           std::map<Node, Rational>::const_iterator& itA,
483
2
                           std::map<Node, Rational>::const_iterator& itB) {
484
    // itB->first is not in A, so we just skip it
485
2
  };
486
487
  auto remainderOfA = [](std::map<Node, Rational>& elements,
488
                         std::map<Node, Rational>& elementsA,
489
8
                         std::map<Node, Rational>::const_iterator& itA) {
490
    // append the remainder of A
491
10
    while (itA != elementsA.end())
492
    {
493
2
      elements[itA->first] = itA->second;
494
2
      itA++;
495
    }
496
6
  };
497
498
  auto remainderOfB = [](std::map<Node, Rational>& elements,
499
                         std::map<Node, Rational>& elementsB,
500
                         std::map<Node, Rational>::const_iterator& itB) {
501
    // do nothing
502
  };
503
504
  return evaluateBinaryOperation(
505
3
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
506
}
507
508
2
Node NormalForm::evaluateDifferenceRemove(TNode n)
509
{
510
2
  Assert(n.getKind() == DIFFERENCE_REMOVE);
511
  // Example
512
  // -------
513
  // input: (difference_subtract A B)
514
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
515
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
516
  // output:
517
  //    (MK_BAG "z" 2)
518
519
  auto equal = [](std::map<Node, Rational>& elements,
520
                  std::map<Node, Rational>::const_iterator& itA,
521
2
                  std::map<Node, Rational>::const_iterator& itB) {
522
    // skip the shared element by doing nothing
523
2
  };
524
525
  auto less = [](std::map<Node, Rational>& elements,
526
                 std::map<Node, Rational>::const_iterator& itA,
527
                 std::map<Node, Rational>::const_iterator& itB) {
528
    // itA->first is not in B, so we add it to the difference remove
529
    elements[itA->first] = itA->second;
530
  };
531
532
  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
533
                           std::map<Node, Rational>::const_iterator& itA,
534
2
                           std::map<Node, Rational>::const_iterator& itB) {
535
    // itB->first is not in A, so we just skip it
536
2
  };
537
538
  auto remainderOfA = [](std::map<Node, Rational>& elements,
539
                         std::map<Node, Rational>& elementsA,
540
6
                         std::map<Node, Rational>::const_iterator& itA) {
541
    // append the remainder of A
542
8
    while (itA != elementsA.end())
543
    {
544
2
      elements[itA->first] = itA->second;
545
2
      itA++;
546
    }
547
4
  };
548
549
  auto remainderOfB = [](std::map<Node, Rational>& elements,
550
                         std::map<Node, Rational>& elementsB,
551
                         std::map<Node, Rational>::const_iterator& itB) {
552
    // do nothing
553
  };
554
555
  return evaluateBinaryOperation(
556
2
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
557
}
558
559
6
Node NormalForm::evaluateChoose(TNode n)
560
{
561
6
  Assert(n.getKind() == BAG_CHOOSE);
562
  // Examples
563
  // --------
564
  // - (choose (emptyBag String)) = "" // the empty string which is the first
565
  //   element returned by the type enumerator
566
  // - (choose (MK_BAG "x" 4)) = "x"
567
  // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x"
568
  //     deterministically return the first element
569
570
6
  if (n[0].getKind() == EMPTYBAG)
571
  {
572
4
    TypeNode elementType = n[0].getType().getBagElementType();
573
4
    TypeEnumerator typeEnumerator(elementType);
574
    // get the first value from the typeEnumerator
575
4
    Node element = *typeEnumerator;
576
2
    return element;
577
  }
578
579
4
  if (n[0].getKind() == MK_BAG)
580
  {
581
2
    return n[0][0];
582
  }
583
2
  Assert(n[0].getKind() == UNION_DISJOINT);
584
  // return the first element
585
  // e.g. (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1)))
586
2
  return n[0][0][0];
587
}
588
589
6
Node NormalForm::evaluateCard(TNode n)
590
{
591
6
  Assert(n.getKind() == BAG_CARD);
592
  // Examples
593
  // --------
594
  //  - (card (emptyBag String)) = 0
595
  //  - (choose (MK_BAG "x" 4)) = 4
596
  //  - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5
597
598
12
  std::map<Node, Rational> elements = getBagElements(n[0]);
599
12
  Rational sum(0);
600
12
  for (std::pair<Node, Rational> element : elements)
601
  {
602
6
    sum += element.second;
603
  }
604
605
6
  NodeManager* nm = NodeManager::currentNM();
606
6
  Node sumNode = nm->mkConst(sum);
607
12
  return sumNode;
608
}
609
610
8
Node NormalForm::evaluateIsSingleton(TNode n)
611
{
612
8
  Assert(n.getKind() == BAG_IS_SINGLETON);
613
  // Examples
614
  // --------
615
  // - (bag.is_singleton (emptyBag String)) = false
616
  // - (bag.is_singleton (MK_BAG "x" 1)) = true
617
  // - (bag.is_singleton (MK_BAG "x" 4)) = false
618
  // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false
619
620
8
  if (n[0].getKind() == MK_BAG && n[0][1].getConst<Rational>().isOne())
621
  {
622
4
    return NodeManager::currentNM()->mkConst(true);
623
  }
624
4
  return NodeManager::currentNM()->mkConst(false);
625
}
626
627
6
Node NormalForm::evaluateFromSet(TNode n)
628
{
629
6
  Assert(n.getKind() == BAG_FROM_SET);
630
631
  // Examples
632
  // --------
633
  //  - (bag.from_set (emptyset String)) = (emptybag String)
634
  //  - (bag.from_set (singleton "x")) = (mkBag "x" 1)
635
  //  - (bag.from_set (union (singleton "x") (singleton "y"))) =
636
  //     (disjoint_union (mkBag "x" 1) (mkBag "y" 1))
637
638
6
  NodeManager* nm = NodeManager::currentNM();
639
  std::set<Node> setElements =
640
12
      sets::NormalForm::getElementsFromNormalConstant(n[0]);
641
12
  Rational one = Rational(1);
642
12
  std::map<Node, Rational> bagElements;
643
12
  for (const Node& element : setElements)
644
  {
645
6
    bagElements[element] = one;
646
  }
647
12
  TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
648
6
  Node bag = constructConstantBagFromElements(bagType, bagElements);
649
12
  return bag;
650
}
651
652
6
Node NormalForm::evaluateToSet(TNode n)
653
{
654
6
  Assert(n.getKind() == BAG_TO_SET);
655
656
  // Examples
657
  // --------
658
  //  - (bag.to_set (emptybag String)) = (emptyset String)
659
  //  - (bag.to_set (mkBag "x" 4)) = (singleton "x")
660
  //  - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
661
  //     (union (singleton "x") (singleton "y")))
662
663
6
  NodeManager* nm = NodeManager::currentNM();
664
12
  std::map<Node, Rational> bagElements = getBagElements(n[0]);
665
12
  std::set<Node> setElements;
666
6
  std::map<Node, Rational>::const_reverse_iterator it;
667
12
  for (it = bagElements.rbegin(); it != bagElements.rend(); it++)
668
  {
669
6
    setElements.insert(it->first);
670
  }
671
12
  TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
672
6
  Node set = sets::NormalForm::elementsToSet(setElements, setType);
673
12
  return set;
674
}
675
676
}  // namespace bags
677
}  // namespace theory
678
28191
}  // namespace cvc5