GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/normal_form.cpp Lines: 244 261 93.5 %
Date: 2021-03-23 Branches: 557 1292 43.1 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file normal_form.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Mudathir Mohamed
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Normal form for bag constants.
13
 **/
14
#include "normal_form.h"
15
16
#include "theory/sets/normal_form.h"
17
#include "theory/type_enumerator.h"
18
19
using namespace CVC4::kind;
20
21
namespace CVC4 {
22
namespace theory {
23
namespace bags {
24
25
121
bool NormalForm::isConstant(TNode n)
26
{
27
121
  if (n.getKind() == EMPTYBAG)
28
  {
29
    // empty bags are already normalized
30
    return true;
31
  }
32
121
  if (n.getKind() == MK_BAG)
33
  {
34
    // see the implementation in MkBagTypeRule::computeIsConst
35
    return n.isConst();
36
  }
37
121
  if (n.getKind() == UNION_DISJOINT)
38
  {
39
121
    if (!(n[0].getKind() == kind::MK_BAG && n[0].isConst()))
40
    {
41
      // the first child is not a constant
42
59
      return false;
43
    }
44
    // store the previous element to check the ordering of elements
45
124
    Node previousElement = n[0][0];
46
124
    Node current = n[1];
47
78
    while (current.getKind() == UNION_DISJOINT)
48
    {
49
10
      if (!(current[0].getKind() == kind::MK_BAG && current[0].isConst()))
50
      {
51
        // the current element is not a constant
52
        return false;
53
      }
54
10
      if (previousElement >= current[0][0])
55
      {
56
        // the ordering is violated
57
2
        return false;
58
      }
59
8
      previousElement = current[0][0];
60
8
      current = current[1];
61
    }
62
    // check last element
63
60
    if (!(current.getKind() == kind::MK_BAG && current.isConst()))
64
    {
65
      // the last element is not a constant
66
      return false;
67
    }
68
60
    if (previousElement >= current[0])
69
    {
70
      // the ordering is violated
71
8
      return false;
72
    }
73
52
    return true;
74
  }
75
76
  // only nodes with kinds EMPTY_BAG, MK_BAG, and UNION_DISJOINT can be
77
  // constants
78
  return false;
79
}
80
81
1474
bool NormalForm::areChildrenConstants(TNode n)
82
{
83
3433
  return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
84
}
85
86
147
Node NormalForm::evaluate(TNode n)
87
{
88
147
  Assert(areChildrenConstants(n));
89
147
  if (n.isConst())
90
  {
91
    // a constant node is already in a normal form
92
6
    return n;
93
  }
94
141
  switch (n.getKind())
95
  {
96
11
    case MK_BAG: return evaluateMakeBag(n);
97
57
    case BAG_COUNT: return evaluateBagCount(n);
98
7
    case DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
99
20
    case UNION_DISJOINT: return evaluateUnionDisjoint(n);
100
6
    case UNION_MAX: return evaluateUnionMax(n);
101
3
    case INTERSECTION_MIN: return evaluateIntersectionMin(n);
102
3
    case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
103
2
    case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
104
6
    case BAG_CHOOSE: return evaluateChoose(n);
105
6
    case BAG_CARD: return evaluateCard(n);
106
8
    case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
107
6
    case BAG_FROM_SET: return evaluateFromSet(n);
108
6
    case BAG_TO_SET: return evaluateToSet(n);
109
    default: break;
110
  }
111
  Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
112
              << std::endl;
113
}
114
115
template <typename T1, typename T2, typename T3, typename T4, typename T5>
116
34
Node NormalForm::evaluateBinaryOperation(const TNode& n,
117
                                         T1&& equal,
118
                                         T2&& less,
119
                                         T3&& greaterOrEqual,
120
                                         T4&& remainderOfA,
121
                                         T5&& remainderOfB)
122
{
123
68
  std::map<Node, Rational> elementsA = getBagElements(n[0]);
124
68
  std::map<Node, Rational> elementsB = getBagElements(n[1]);
125
68
  std::map<Node, Rational> elements;
126
127
34
  std::map<Node, Rational>::const_iterator itA = elementsA.begin();
128
34
  std::map<Node, Rational>::const_iterator itB = elementsB.begin();
129
130
102
  Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
131
68
                         << n.getKind() << "] " << std::endl
132
34
                         << "elements A: " << elementsA << std::endl
133
34
                         << "elements B: " << elementsB << std::endl;
134
135
122
  while (itA != elementsA.end() && itB != elementsB.end())
136
  {
137
44
    if (itA->first == itB->first)
138
    {
139
20
      equal(elements, itA, itB);
140
20
      itA++;
141
20
      itB++;
142
    }
143
24
    else if (itA->first < itB->first)
144
    {
145
8
      less(elements, itA, itB);
146
8
      itA++;
147
    }
148
    else
149
    {
150
16
      greaterOrEqual(elements, itA, itB);
151
16
      itB++;
152
    }
153
  }
154
155
  // handle the remaining elements from A
156
34
  remainderOfA(elements, elementsA, itA);
157
  // handle the remaining elements from B
158
34
  remainderOfA(elements, elementsB, itB);
159
160
34
  Trace("bags-evaluate") << "elements: " << elements << std::endl;
161
34
  Node bag = constructConstantBagFromElements(n.getType(), elements);
162
34
  Trace("bags-evaluate") << "bag: " << bag << std::endl;
163
68
  return bag;
164
}
165
166
144
std::map<Node, Rational> NormalForm::getBagElements(TNode n)
167
{
168
144
  Assert(n.isConst()) << "node " << n << " is not in a normal form"
169
                      << std::endl;
170
144
  std::map<Node, Rational> elements;
171
144
  if (n.getKind() == EMPTYBAG)
172
  {
173
25
    return elements;
174
  }
175
239
  while (n.getKind() == kind::UNION_DISJOINT)
176
  {
177
60
    Assert(n[0].getKind() == kind::MK_BAG);
178
120
    Node element = n[0][0];
179
120
    Rational count = n[0][1].getConst<Rational>();
180
60
    elements[element] = count;
181
60
    n = n[1];
182
  }
183
119
  Assert(n.getKind() == kind::MK_BAG);
184
238
  Node lastElement = n[0];
185
238
  Rational lastCount = n[1].getConst<Rational>();
186
119
  elements[lastElement] = lastCount;
187
119
  return elements;
188
}
189
190
47
Node NormalForm::constructConstantBagFromElements(
191
    TypeNode t, const std::map<Node, Rational>& elements)
192
{
193
47
  Assert(t.isBag());
194
47
  NodeManager* nm = NodeManager::currentNM();
195
47
  if (elements.empty())
196
  {
197
6
    return nm->mkConst(EmptyBag(t));
198
  }
199
82
  TypeNode elementType = t.getBagElementType();
200
41
  std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
201
  Node bag =
202
82
      nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
203
101
  while (++it != elements.rend())
204
  {
205
    Node n =
206
60
        nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
207
30
    bag = nm->mkNode(UNION_DISJOINT, n, bag);
208
  }
209
41
  return bag;
210
}
211
212
30
Node NormalForm::constructBagFromElements(TypeNode t,
213
                                          const std::map<Node, Node>& elements)
214
{
215
30
  Assert(t.isBag());
216
30
  NodeManager* nm = NodeManager::currentNM();
217
30
  if (elements.empty())
218
  {
219
2
    return nm->mkConst(EmptyBag(t));
220
  }
221
56
  TypeNode elementType = t.getBagElementType();
222
28
  std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
223
56
  Node bag = nm->mkBag(elementType, it->first, it->second);
224
56
  while (++it != elements.rend())
225
  {
226
28
    Node n = nm->mkBag(elementType, it->first, it->second);
227
14
    bag = nm->mkNode(UNION_DISJOINT, n, bag);
228
  }
229
28
  return bag;
230
}
231
232
11
Node NormalForm::evaluateMakeBag(TNode n)
233
{
234
  // the case where n is const should be handled earlier.
235
  // here we handle the case where the multiplicity is zero or negative
236
11
  Assert(n.getKind() == MK_BAG && !n.isConst()
237
         && n[1].getConst<Rational>().sgn() < 1);
238
11
  Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
239
11
  return emptybag;
240
}
241
242
57
Node NormalForm::evaluateBagCount(TNode n)
243
{
244
57
  Assert(n.getKind() == BAG_COUNT);
245
  // Examples
246
  // --------
247
  // - (bag.count "x" (emptybag String)) = 0
248
  // - (bag.count "x" (mkBag "y" 5)) = 0
249
  // - (bag.count "x" (mkBag "x" 4)) = 4
250
  // - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
251
  // - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0
252
253
114
  std::map<Node, Rational> elements = getBagElements(n[1]);
254
57
  std::map<Node, Rational>::iterator it = elements.find(n[0]);
255
256
57
  NodeManager* nm = NodeManager::currentNM();
257
57
  if (it != elements.end())
258
  {
259
74
    Node count = nm->mkConst(it->second);
260
37
    return count;
261
  }
262
20
  return nm->mkConst(Rational(0));
263
}
264
265
7
Node NormalForm::evaluateDuplicateRemoval(TNode n)
266
{
267
7
  Assert(n.getKind() == DUPLICATE_REMOVAL);
268
269
  // Examples
270
  // --------
271
  //  - (duplicate_removal (emptybag String)) = (emptybag String)
272
  //  - (duplicate_removal (mkBag "x" 4)) = (emptybag "x" 1)
273
  //  - (duplicate_removal (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
274
  //     (disjoint_union (mkBag "x" 1) (mkBag "y" 1)
275
276
14
  std::map<Node, Rational> oldElements = getBagElements(n[0]);
277
  // copy elements from the old bag
278
14
  std::map<Node, Rational> newElements(oldElements);
279
14
  Rational one = Rational(1);
280
7
  std::map<Node, Rational>::iterator it;
281
14
  for (it = newElements.begin(); it != newElements.end(); it++)
282
  {
283
7
    it->second = one;
284
  }
285
7
  Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
286
14
  return bag;
287
}
288
289
20
Node NormalForm::evaluateUnionDisjoint(TNode n)
290
{
291
20
  Assert(n.getKind() == UNION_DISJOINT);
292
  // Example
293
  // -------
294
  // input: (union_disjoint A B)
295
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
296
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
297
  // output:
298
  //    (union_disjoint A B)
299
  //        where A = (MK_BAG "x" 7)
300
  //              B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
301
302
  auto equal = [](std::map<Node, Rational>& elements,
303
                  std::map<Node, Rational>::const_iterator& itA,
304
6
                  std::map<Node, Rational>::const_iterator& itB) {
305
    // compute the sum of the multiplicities
306
6
    elements[itA->first] = itA->second + itB->second;
307
6
  };
308
309
  auto less = [](std::map<Node, Rational>& elements,
310
                 std::map<Node, Rational>::const_iterator& itA,
311
6
                 std::map<Node, Rational>::const_iterator& itB) {
312
    // add the element to the result
313
6
    elements[itA->first] = itA->second;
314
6
  };
315
316
  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
317
                           std::map<Node, Rational>::const_iterator& itA,
318
8
                           std::map<Node, Rational>::const_iterator& itB) {
319
    // add the element to the result
320
8
    elements[itB->first] = itB->second;
321
8
  };
322
323
  auto remainderOfA = [](std::map<Node, Rational>& elements,
324
                         std::map<Node, Rational>& elementsA,
325
54
                         std::map<Node, Rational>::const_iterator& itA) {
326
    // append the remainder of A
327
68
    while (itA != elementsA.end())
328
    {
329
14
      elements[itA->first] = itA->second;
330
14
      itA++;
331
    }
332
40
  };
333
334
  auto remainderOfB = [](std::map<Node, Rational>& elements,
335
                         std::map<Node, Rational>& elementsB,
336
                         std::map<Node, Rational>::const_iterator& itB) {
337
    // append the remainder of B
338
    while (itB != elementsB.end())
339
    {
340
      elements[itB->first] = itB->second;
341
      itB++;
342
    }
343
  };
344
345
  return evaluateBinaryOperation(
346
20
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
347
}
348
349
6
Node NormalForm::evaluateUnionMax(TNode n)
350
{
351
6
  Assert(n.getKind() == UNION_MAX);
352
  // Example
353
  // -------
354
  // input: (union_max A B)
355
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
356
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
357
  // output:
358
  //    (union_disjoint A B)
359
  //        where A = (MK_BAG "x" 4)
360
  //              B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
361
362
  auto equal = [](std::map<Node, Rational>& elements,
363
                  std::map<Node, Rational>::const_iterator& itA,
364
4
                  std::map<Node, Rational>::const_iterator& itB) {
365
    // compute the maximum multiplicity
366
4
    elements[itA->first] = std::max(itA->second, itB->second);
367
4
  };
368
369
  auto less = [](std::map<Node, Rational>& elements,
370
                 std::map<Node, Rational>::const_iterator& itA,
371
2
                 std::map<Node, Rational>::const_iterator& itB) {
372
    // add to the result
373
2
    elements[itA->first] = itA->second;
374
2
  };
375
376
  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
377
                           std::map<Node, Rational>::const_iterator& itA,
378
2
                           std::map<Node, Rational>::const_iterator& itB) {
379
    // add to the result
380
2
    elements[itB->first] = itB->second;
381
2
  };
382
383
  auto remainderOfA = [](std::map<Node, Rational>& elements,
384
                         std::map<Node, Rational>& elementsA,
385
16
                         std::map<Node, Rational>::const_iterator& itA) {
386
    // append the remainder of A
387
20
    while (itA != elementsA.end())
388
    {
389
4
      elements[itA->first] = itA->second;
390
4
      itA++;
391
    }
392
12
  };
393
394
  auto remainderOfB = [](std::map<Node, Rational>& elements,
395
                         std::map<Node, Rational>& elementsB,
396
                         std::map<Node, Rational>::const_iterator& itB) {
397
    // append the remainder of B
398
    while (itB != elementsB.end())
399
    {
400
      elements[itB->first] = itB->second;
401
      itB++;
402
    }
403
  };
404
405
  return evaluateBinaryOperation(
406
6
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
407
}
408
409
3
Node NormalForm::evaluateIntersectionMin(TNode n)
410
{
411
3
  Assert(n.getKind() == INTERSECTION_MIN);
412
  // Example
413
  // -------
414
  // input: (intersectionMin A B)
415
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
416
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
417
  // output:
418
  //        (MK_BAG "x" 3)
419
420
  auto equal = [](std::map<Node, Rational>& elements,
421
                  std::map<Node, Rational>::const_iterator& itA,
422
4
                  std::map<Node, Rational>::const_iterator& itB) {
423
    // compute the minimum multiplicity
424
4
    elements[itA->first] = std::min(itA->second, itB->second);
425
4
  };
426
427
  auto less = [](std::map<Node, Rational>& elements,
428
                 std::map<Node, Rational>::const_iterator& itA,
429
                 std::map<Node, Rational>::const_iterator& itB) {
430
    // do nothing
431
  };
432
433
  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
434
                           std::map<Node, Rational>::const_iterator& itA,
435
2
                           std::map<Node, Rational>::const_iterator& itB) {
436
    // do nothing
437
2
  };
438
439
  auto remainderOfA = [](std::map<Node, Rational>& elements,
440
                         std::map<Node, Rational>& elementsA,
441
6
                         std::map<Node, Rational>::const_iterator& itA) {
442
    // do nothing
443
6
  };
444
445
  auto remainderOfB = [](std::map<Node, Rational>& elements,
446
                         std::map<Node, Rational>& elementsB,
447
                         std::map<Node, Rational>::const_iterator& itB) {
448
    // do nothing
449
  };
450
451
  return evaluateBinaryOperation(
452
3
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
453
}
454
455
3
Node NormalForm::evaluateDifferenceSubtract(TNode n)
456
{
457
3
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
458
  // Example
459
  // -------
460
  // input: (difference_subtract A B)
461
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
462
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
463
  // output:
464
  //    (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))
465
466
  auto equal = [](std::map<Node, Rational>& elements,
467
                  std::map<Node, Rational>::const_iterator& itA,
468
4
                  std::map<Node, Rational>::const_iterator& itB) {
469
    // subtract the multiplicities
470
4
    elements[itA->first] = itA->second - itB->second;
471
4
  };
472
473
  auto less = [](std::map<Node, Rational>& elements,
474
                 std::map<Node, Rational>::const_iterator& itA,
475
                 std::map<Node, Rational>::const_iterator& itB) {
476
    // itA->first is not in B, so we add it to the difference subtract
477
    elements[itA->first] = itA->second;
478
  };
479
480
  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
481
                           std::map<Node, Rational>::const_iterator& itA,
482
2
                           std::map<Node, Rational>::const_iterator& itB) {
483
    // itB->first is not in A, so we just skip it
484
2
  };
485
486
  auto remainderOfA = [](std::map<Node, Rational>& elements,
487
                         std::map<Node, Rational>& elementsA,
488
8
                         std::map<Node, Rational>::const_iterator& itA) {
489
    // append the remainder of A
490
10
    while (itA != elementsA.end())
491
    {
492
2
      elements[itA->first] = itA->second;
493
2
      itA++;
494
    }
495
6
  };
496
497
  auto remainderOfB = [](std::map<Node, Rational>& elements,
498
                         std::map<Node, Rational>& elementsB,
499
                         std::map<Node, Rational>::const_iterator& itB) {
500
    // do nothing
501
  };
502
503
  return evaluateBinaryOperation(
504
3
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
505
}
506
507
2
Node NormalForm::evaluateDifferenceRemove(TNode n)
508
{
509
2
  Assert(n.getKind() == DIFFERENCE_REMOVE);
510
  // Example
511
  // -------
512
  // input: (difference_subtract A B)
513
  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
514
  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
515
  // output:
516
  //    (MK_BAG "z" 2)
517
518
  auto equal = [](std::map<Node, Rational>& elements,
519
                  std::map<Node, Rational>::const_iterator& itA,
520
2
                  std::map<Node, Rational>::const_iterator& itB) {
521
    // skip the shared element by doing nothing
522
2
  };
523
524
  auto less = [](std::map<Node, Rational>& elements,
525
                 std::map<Node, Rational>::const_iterator& itA,
526
                 std::map<Node, Rational>::const_iterator& itB) {
527
    // itA->first is not in B, so we add it to the difference remove
528
    elements[itA->first] = itA->second;
529
  };
530
531
  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
532
                           std::map<Node, Rational>::const_iterator& itA,
533
2
                           std::map<Node, Rational>::const_iterator& itB) {
534
    // itB->first is not in A, so we just skip it
535
2
  };
536
537
  auto remainderOfA = [](std::map<Node, Rational>& elements,
538
                         std::map<Node, Rational>& elementsA,
539
6
                         std::map<Node, Rational>::const_iterator& itA) {
540
    // append the remainder of A
541
8
    while (itA != elementsA.end())
542
    {
543
2
      elements[itA->first] = itA->second;
544
2
      itA++;
545
    }
546
4
  };
547
548
  auto remainderOfB = [](std::map<Node, Rational>& elements,
549
                         std::map<Node, Rational>& elementsB,
550
                         std::map<Node, Rational>::const_iterator& itB) {
551
    // do nothing
552
  };
553
554
  return evaluateBinaryOperation(
555
2
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
556
}
557
558
6
Node NormalForm::evaluateChoose(TNode n)
559
{
560
6
  Assert(n.getKind() == BAG_CHOOSE);
561
  // Examples
562
  // --------
563
  // - (choose (emptyBag String)) = "" // the empty string which is the first
564
  //   element returned by the type enumerator
565
  // - (choose (MK_BAG "x" 4)) = "x"
566
  // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x"
567
  //     deterministically return the first element
568
569
6
  if (n[0].getKind() == EMPTYBAG)
570
  {
571
4
    TypeNode elementType = n[0].getType().getBagElementType();
572
4
    TypeEnumerator typeEnumerator(elementType);
573
    // get the first value from the typeEnumerator
574
4
    Node element = *typeEnumerator;
575
2
    return element;
576
  }
577
578
4
  if (n[0].getKind() == MK_BAG)
579
  {
580
2
    return n[0][0];
581
  }
582
2
  Assert(n[0].getKind() == UNION_DISJOINT);
583
  // return the first element
584
  // e.g. (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1)))
585
2
  return n[0][0][0];
586
}
587
588
6
Node NormalForm::evaluateCard(TNode n)
589
{
590
6
  Assert(n.getKind() == BAG_CARD);
591
  // Examples
592
  // --------
593
  //  - (card (emptyBag String)) = 0
594
  //  - (choose (MK_BAG "x" 4)) = 4
595
  //  - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5
596
597
12
  std::map<Node, Rational> elements = getBagElements(n[0]);
598
12
  Rational sum(0);
599
12
  for (std::pair<Node, Rational> element : elements)
600
  {
601
6
    sum += element.second;
602
  }
603
604
6
  NodeManager* nm = NodeManager::currentNM();
605
6
  Node sumNode = nm->mkConst(sum);
606
12
  return sumNode;
607
}
608
609
8
Node NormalForm::evaluateIsSingleton(TNode n)
610
{
611
8
  Assert(n.getKind() == BAG_IS_SINGLETON);
612
  // Examples
613
  // --------
614
  // - (bag.is_singleton (emptyBag String)) = false
615
  // - (bag.is_singleton (MK_BAG "x" 1)) = true
616
  // - (bag.is_singleton (MK_BAG "x" 4)) = false
617
  // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false
618
619
8
  if (n[0].getKind() == MK_BAG && n[0][1].getConst<Rational>().isOne())
620
  {
621
4
    return NodeManager::currentNM()->mkConst(true);
622
  }
623
4
  return NodeManager::currentNM()->mkConst(false);
624
}
625
626
6
Node NormalForm::evaluateFromSet(TNode n)
627
{
628
6
  Assert(n.getKind() == BAG_FROM_SET);
629
630
  // Examples
631
  // --------
632
  //  - (bag.from_set (emptyset String)) = (emptybag String)
633
  //  - (bag.from_set (singleton "x")) = (mkBag "x" 1)
634
  //  - (bag.from_set (union (singleton "x") (singleton "y"))) =
635
  //     (disjoint_union (mkBag "x" 1) (mkBag "y" 1))
636
637
6
  NodeManager* nm = NodeManager::currentNM();
638
  std::set<Node> setElements =
639
12
      sets::NormalForm::getElementsFromNormalConstant(n[0]);
640
12
  Rational one = Rational(1);
641
12
  std::map<Node, Rational> bagElements;
642
12
  for (const Node& element : setElements)
643
  {
644
6
    bagElements[element] = one;
645
  }
646
12
  TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
647
6
  Node bag = constructConstantBagFromElements(bagType, bagElements);
648
12
  return bag;
649
}
650
651
6
Node NormalForm::evaluateToSet(TNode n)
652
{
653
6
  Assert(n.getKind() == BAG_TO_SET);
654
655
  // Examples
656
  // --------
657
  //  - (bag.to_set (emptybag String)) = (emptyset String)
658
  //  - (bag.to_set (mkBag "x" 4)) = (singleton "x")
659
  //  - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
660
  //     (union (singleton "x") (singleton "y")))
661
662
6
  NodeManager* nm = NodeManager::currentNM();
663
12
  std::map<Node, Rational> bagElements = getBagElements(n[0]);
664
12
  std::set<Node> setElements;
665
6
  std::map<Node, Rational>::const_reverse_iterator it;
666
12
  for (it = bagElements.rbegin(); it != bagElements.rend(); it++)
667
  {
668
6
    setElements.insert(it->first);
669
  }
670
12
  TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
671
6
  Node set = sets::NormalForm::elementsToSet(setElements, setType);
672
12
  return set;
673
}
674
675
}  // namespace bags
676
}  // namespace theory
677
26685
}  // namespace CVC4