GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/bags_rewriter.cpp Lines: 227 248 91.5 %
Date: 2021-11-07 Branches: 864 1925 44.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Mudathir Mohamed, Gereon Kremer, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Bags theory rewriter.
14
 */
15
16
#include "theory/bags/bags_rewriter.h"
17
18
#include "expr/emptybag.h"
19
#include "theory/bags/normal_form.h"
20
#include "util/rational.h"
21
#include "util/statistics_registry.h"
22
23
using namespace cvc5::kind;
24
25
namespace cvc5 {
26
namespace theory {
27
namespace bags {
28
29
4648
/** Creates an empty response: a null node tagged with Rewrite::NONE. */
BagsRewriteResponse::BagsRewriteResponse()
    : d_node(Node::null()),
      d_rewrite(Rewrite::NONE)
{
}
33
34
4648
/**
 * Creates a response holding a result node and the rule tag attached to it.
 * @param n the node stored in d_node
 * @param rewrite the rule tag stored in d_rewrite
 */
BagsRewriteResponse::BagsRewriteResponse(Node n, Rewrite rewrite)
    : d_node(n),
      d_rewrite(rewrite)
{
}
38
39
/**
 * Copy constructor: duplicates the node and the rule tag of another response.
 * @param r the response to copy from
 */
BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r)
    : d_node(r.d_node),
      d_rewrite(r.d_rewrite)
{
}
43
44
15337
/**
 * Sets up the rewriter.
 * @param statistics histogram that preRewrite/postRewrite feed with the tag
 * of each response; both methods skip the recording when it is nullptr
 */
BagsRewriter::BagsRewriter(HistogramStat<Rewrite>* statistics)
    : d_statistics(statistics)
{
  // cache the node manager together with the constants 0 and 1, which the
  // individual rewrite rules reuse
  auto nm = NodeManager::currentNM();
  d_nm = nm;
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
}
51
52
2974
/**
 * Post-rewrite entry point.
 *
 * Constant terms are returned untouched, equalities are handed to
 * postRewriteEqual, terms for which NormalForm::areChildrenConstants holds are
 * replaced by NormalForm::evaluate, and every remaining term is dispatched on
 * its kind to the matching rewrite method. The outcome is traced under
 * "bags-rewrite" and its rule tag is pushed to the statistics histogram when
 * one was supplied.
 *
 * @param n the term to rewrite
 * @return REWRITE_AGAIN_FULL with the resulting node when it differs from n,
 * REWRITE_DONE with n otherwise
 */
RewriteResponse BagsRewriter::postRewrite(TNode n)
{
  BagsRewriteResponse response;
  if (n.isConst())
  {
    // a constant is already in normal form, nothing to do
    response = BagsRewriteResponse(n, Rewrite::NONE);
  }
  else if (n.getKind() == EQUAL)
  {
    response = postRewriteEqual(n);
  }
  else if (NormalForm::areChildrenConstants(n))
  {
    response = BagsRewriteResponse(NormalForm::evaluate(n),
                                   Rewrite::CONSTANT_EVALUATION);
  }
  else
  {
    switch (n.getKind())
    {
      case MK_BAG:
        response = rewriteMakeBag(n);
        break;
      case BAG_COUNT:
        response = rewriteBagCount(n);
        break;
      case DUPLICATE_REMOVAL:
        response = rewriteDuplicateRemoval(n);
        break;
      case UNION_MAX:
        response = rewriteUnionMax(n);
        break;
      case UNION_DISJOINT:
        response = rewriteUnionDisjoint(n);
        break;
      case INTERSECTION_MIN:
        response = rewriteIntersectionMin(n);
        break;
      case DIFFERENCE_SUBTRACT:
        response = rewriteDifferenceSubtract(n);
        break;
      case DIFFERENCE_REMOVE:
        response = rewriteDifferenceRemove(n);
        break;
      case BAG_CHOOSE:
        response = rewriteChoose(n);
        break;
      case BAG_CARD:
        response = rewriteCard(n);
        break;
      case BAG_IS_SINGLETON:
        response = rewriteIsSingleton(n);
        break;
      case BAG_FROM_SET:
        response = rewriteFromSet(n);
        break;
      case BAG_TO_SET:
        response = rewriteToSet(n);
        break;
      case BAG_MAP:
        response = postRewriteMap(n);
        break;
      default:
        // kinds without a dedicated rule are left as they are
        response = BagsRewriteResponse(n, Rewrite::NONE);
        break;
    }
  }

  Trace("bags-rewrite") << "postRewrite " << n << " to " << response.d_node
                        << " by " << response.d_rewrite << "." << std::endl;

  if (d_statistics != nullptr)
  {
    (*d_statistics) << response.d_rewrite;
  }

  const bool unchanged = !(response.d_node != n);
  if (unchanged)
  {
    return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
  }
  return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
}
105
106
1674
/**
 * Pre-rewrite entry point.
 *
 * Only EQUAL (via preRewriteEqual) and SUBBAG (via rewriteSubBag) terms are
 * handled here; any other kind is passed through with Rewrite::NONE. The
 * outcome is traced under "bags-rewrite" and its rule tag is pushed to the
 * statistics histogram when one was supplied.
 *
 * @param n the term to rewrite
 * @return REWRITE_AGAIN_FULL with the resulting node when it differs from n,
 * REWRITE_DONE with n otherwise
 */
RewriteResponse BagsRewriter::preRewrite(TNode n)
{
  BagsRewriteResponse response;
  const Kind kindOfN = n.getKind();
  if (kindOfN == EQUAL)
  {
    response = preRewriteEqual(n);
  }
  else if (kindOfN == SUBBAG)
  {
    response = rewriteSubBag(n);
  }
  else
  {
    response = BagsRewriteResponse(n, Rewrite::NONE);
  }

  Trace("bags-rewrite") << "preRewrite " << n << " to " << response.d_node
                        << " by " << response.d_rewrite << "." << std::endl;

  if (d_statistics != nullptr)
  {
    (*d_statistics) << response.d_rewrite;
  }

  const bool unchanged = !(response.d_node != n);
  if (unchanged)
  {
    return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
  }
  return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
}
130
131
488
/**
 * Pre-rewrite rule for equalities.
 * @param n a term of kind EQUAL
 * @return the constant true (IDENTICAL_NODES) when both sides are the same
 * node, n itself (NONE) otherwise
 */
BagsRewriteResponse BagsRewriter::preRewriteEqual(const TNode& n) const
{
  Assert(n.getKind() == EQUAL);
  const bool identicalSides = (n[0] == n[1]);
  if (!identicalSides)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (= A A) = true where A is a bag
  return BagsRewriteResponse(d_nm->mkConst(true), Rewrite::IDENTICAL_NODES);
}
141
142
12
/**
 * Eliminates the sub-bag predicate.
 * @param n a term of kind SUBBAG
 * @return the equality (difference_subtract A B) = emptybag, tagged SUB_BAG,
 * where A and B are the two children of n
 */
BagsRewriteResponse BagsRewriter::rewriteSubBag(const TNode& n) const
{
  Assert(n.getKind() == SUBBAG);

  // (bag.is_included A B) = ((difference_subtract A B) == emptybag)
  // the empty bag is built with the type of the first operand
  Node emptyOfA = d_nm->mkConst(EmptyBag(n[0].getType()));
  Node aMinusB = d_nm->mkNode(DIFFERENCE_SUBTRACT, n[0], n[1]);
  return BagsRewriteResponse(aMinusB.eqNode(emptyOfA), Rewrite::SUB_BAG);
}
152
153
185
/**
 * Rewrite rule for bag construction.
 * @param n a term of kind MK_BAG
 * @return the empty bag of n's type (MK_BAG_COUNT_NEGATIVE) when the
 * multiplicity is a constant that is zero or negative, n itself (NONE)
 * otherwise
 */
BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const
{
  Assert(n.getKind() == MK_BAG);
  // nothing to do for a symbolic or strictly positive multiplicity
  if (!n[1].isConst() || n[1].getConst<Rational>().sgn() == 1)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (mkBag x c) = emptybag where c <= 0
  Node empty = d_nm->mkConst(EmptyBag(n.getType()));
  return BagsRewriteResponse(empty, Rewrite::MK_BAG_COUNT_NEGATIVE);
}
165
166
1387
/**
 * Rewrite rules for the multiplicity operator.
 * @param n a term of kind BAG_COUNT whose children are the element and the bag
 * @return 0 (COUNT_EMPTY) for the constant empty bag, (ite (>= c 1) c 0)
 * (COUNT_MK_BAG) when the bag is (mkBag x c) built from the same element,
 * n itself (NONE) otherwise
 */
BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
{
  Assert(n.getKind() == BAG_COUNT);
  TNode element = n[0];
  TNode bag = n[1];

  if (bag.isConst() && bag.getKind() == EMPTYBAG)
  {
    // (bag.count x emptybag) = 0
    return BagsRewriteResponse(d_zero, Rewrite::COUNT_EMPTY);
  }

  if (bag.getKind() == MK_BAG && element == bag[0])
  {
    // (bag.count x (mkBag x c)) = (ite (>= c 1) c 0)
    Node multiplicity = bag[1];
    Node atLeastOne = d_nm->mkNode(GEQ, multiplicity, d_one);
    Node guarded = d_nm->mkNode(ITE, atLeastOne, multiplicity, d_zero);
    return BagsRewriteResponse(guarded, Rewrite::COUNT_MK_BAG);
  }

  return BagsRewriteResponse(n, Rewrite::NONE);
}
184
185
14
/**
 * Rewrite rule for duplicate removal.
 * @param n a term of kind DUPLICATE_REMOVAL
 * @return (mkBag x 1) (DUPLICATE_REMOVAL_MK_BAG) when the operand is
 * (mkBag x c) with c a positive constant, n itself (NONE) otherwise
 */
BagsRewriteResponse BagsRewriter::rewriteDuplicateRemoval(const TNode& n) const
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);
  TNode operand = n[0];
  const bool positiveConstBag = operand.getKind() == MK_BAG
                                && operand[1].isConst()
                                && operand[1].getConst<Rational>().sgn() == 1;
  if (!positiveConstBag)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (duplicate_removal (mkBag x n)) = (mkBag x 1)
  //  where n is a positive constant
  Node single = d_nm->mkBag(operand[0].getType(), operand[0], d_one);
  return BagsRewriteResponse(single, Rewrite::DUPLICATE_REMOVAL_MK_BAG);
}
198
199
76
/**
 * Rewrite rules for max union.
 *
 * Handles an empty operand, two identical operands, and the case where one
 * operand is itself a union (max or disjoint) that already contains the other
 * operand as a direct child.
 *
 * @param n a term of kind UNION_MAX
 * @return the simplified operand with the matching rule tag, or n itself
 * (NONE) when no rule applies
 */
BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const
{
  Assert(n.getKind() == UNION_MAX);
  TNode lhs = n[0];
  TNode rhs = n[1];
  const Kind lhsKind = lhs.getKind();
  const Kind rhsKind = rhs.getKind();

  if (rhsKind == EMPTYBAG || lhs == rhs)
  {
    // (union_max A A) = A
    // (union_max A emptybag) = A
    return BagsRewriteResponse(lhs, Rewrite::UNION_MAX_SAME_OR_EMPTY);
  }
  if (lhsKind == EMPTYBAG)
  {
    // (union_max emptybag A) = A
    return BagsRewriteResponse(rhs, Rewrite::UNION_MAX_EMPTY);
  }

  const bool rhsIsUnion = (rhsKind == UNION_MAX || rhsKind == UNION_DISJOINT);
  if (rhsIsUnion && (lhs == rhs[0] || lhs == rhs[1]))
  {
    // (union_max A (union_max A B)) = (union_max A B)
    // (union_max A (union_max B A)) = (union_max B A)
    // (union_max A (union_disjoint A B)) = (union_disjoint A B)
    // (union_max A (union_disjoint B A)) = (union_disjoint B A)
    return BagsRewriteResponse(rhs, Rewrite::UNION_MAX_UNION_LEFT);
  }

  const bool lhsIsUnion = (lhsKind == UNION_MAX || lhsKind == UNION_DISJOINT);
  if (lhsIsUnion && (lhs[0] == rhs || lhs[1] == rhs))
  {
    // (union_max (union_max A B) A)) = (union_max A B)
    // (union_max (union_max B A) A)) = (union_max B A)
    // (union_max (union_disjoint A B) A)) = (union_disjoint A B)
    // (union_max (union_disjoint B A) A)) = (union_disjoint B A)
    return BagsRewriteResponse(lhs, Rewrite::UNION_MAX_UNION_RIGHT);
  }
  return BagsRewriteResponse(n, Rewrite::NONE);
}
235
236
80
/**
 * Rewrite rules for disjoint union.
 *
 * Drops an empty operand, and collapses the disjoint union of a union_max and
 * an intersection_min over the same operands into the disjoint union of those
 * operands.
 *
 * @param n a term of kind UNION_DISJOINT
 * @return the simplified term with the matching rule tag, or n itself (NONE)
 * when no rule applies
 */
BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
{
  Assert(n.getKind() == UNION_DISJOINT);
  TNode lhs = n[0];
  TNode rhs = n[1];
  const Kind lhsKind = lhs.getKind();
  const Kind rhsKind = rhs.getKind();

  if (rhsKind == EMPTYBAG)
  {
    // (union_disjoint A emptybag) = A
    return BagsRewriteResponse(lhs, Rewrite::UNION_DISJOINT_EMPTY_RIGHT);
  }
  if (lhsKind == EMPTYBAG)
  {
    // (union_disjoint emptybag A) = A
    return BagsRewriteResponse(rhs, Rewrite::UNION_DISJOINT_EMPTY_LEFT);
  }

  const bool maxThenMin =
      (lhsKind == UNION_MAX && rhsKind == INTERSECTION_MIN);
  const bool minThenMax =
      (rhsKind == UNION_MAX && lhsKind == INTERSECTION_MIN);
  if (maxThenMin || minThenMax)
  {
    // (union_disjoint (union_max A B) (intersection_min A B)) =
    //         (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
    // the rule only fires when both sides range over the same operands,
    // which is checked by comparing their children as sets
    std::set<Node> lhsChildren(lhs.begin(), lhs.end());
    std::set<Node> rhsChildren(rhs.begin(), rhs.end());
    if (lhsChildren == rhsChildren)
    {
      Node merged = d_nm->mkNode(UNION_DISJOINT, lhs[0], lhs[1]);
      return BagsRewriteResponse(merged, Rewrite::UNION_DISJOINT_MAX_MIN);
    }
  }
  return BagsRewriteResponse(n, Rewrite::NONE);
}
266
267
42
/**
 * Rewrite rules for min intersection.
 *
 * Handles an empty operand, two identical operands, and the case where one
 * operand is a union (disjoint or max) that has the other operand as a direct
 * child.
 *
 * @param n a term of kind INTERSECTION_MIN
 * @return the simplified operand with the matching rule tag, or n itself
 * (NONE) when no rule applies
 */
BagsRewriteResponse BagsRewriter::rewriteIntersectionMin(const TNode& n) const
{
  Assert(n.getKind() == INTERSECTION_MIN);
  TNode lhs = n[0];
  TNode rhs = n[1];
  const Kind lhsKind = lhs.getKind();
  const Kind rhsKind = rhs.getKind();

  if (lhsKind == EMPTYBAG)
  {
    // (intersection_min emptybag A) = emptybag
    return BagsRewriteResponse(lhs, Rewrite::INTERSECTION_EMPTY_LEFT);
  }
  if (rhsKind == EMPTYBAG)
  {
    // (intersection_min A emptybag) = emptybag
    return BagsRewriteResponse(rhs, Rewrite::INTERSECTION_EMPTY_RIGHT);
  }
  if (lhs == rhs)
  {
    // (intersection_min A A) = A
    return BagsRewriteResponse(lhs, Rewrite::INTERSECTION_SAME);
  }

  const bool rhsIsUnion = (rhsKind == UNION_DISJOINT || rhsKind == UNION_MAX);
  if (rhsIsUnion && (lhs == rhs[0] || lhs == rhs[1]))
  {
    // (intersection_min A (union_disjoint A B)) = A
    // (intersection_min A (union_disjoint B A)) = A
    // (intersection_min A (union_max A B)) = A
    // (intersection_min A (union_max B A)) = A
    return BagsRewriteResponse(lhs, Rewrite::INTERSECTION_SHARED_LEFT);
  }

  const bool lhsIsUnion = (lhsKind == UNION_DISJOINT || lhsKind == UNION_MAX);
  if (lhsIsUnion && (rhs == lhs[0] || rhs == lhs[1]))
  {
    // (intersection_min (union_disjoint A B) A) = A
    // (intersection_min (union_disjoint B A) A) = A
    // (intersection_min (union_max A B) A) = A
    // (intersection_min (union_max B A) A) = A
    return BagsRewriteResponse(rhs, Rewrite::INTERSECTION_SHARED_RIGHT);
  }

  return BagsRewriteResponse(n, Rewrite::NONE);
}
311
312
62
/**
 * Rewrite rules for subtracting difference.
 *
 * Handles an empty operand, two identical operands, removing one component of
 * a disjoint union, subtracting a union that contains the left operand, and
 * subtracting an operand from an intersection that contains it.
 *
 * @param n a term of kind DIFFERENCE_SUBTRACT
 * @return the simplified term with the matching rule tag, or n itself (NONE)
 * when no rule applies
 */
BagsRewriteResponse BagsRewriter::rewriteDifferenceSubtract(
    const TNode& n) const
{
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
  TNode lhs = n[0];
  TNode rhs = n[1];
  const Kind lhsKind = lhs.getKind();
  const Kind rhsKind = rhs.getKind();

  if (lhsKind == EMPTYBAG || rhsKind == EMPTYBAG)
  {
    // (difference_subtract A emptybag) = A
    // (difference_subtract emptybag A) = emptybag
    // in both cases the answer is the left operand
    return BagsRewriteResponse(lhs, Rewrite::SUBTRACT_RETURN_LEFT);
  }
  if (lhs == rhs)
  {
    // (difference_subtract A A) = emptybag
    Node empty = d_nm->mkConst(EmptyBag(n.getType()));
    return BagsRewriteResponse(empty, Rewrite::SUBTRACT_SAME);
  }

  if (lhsKind == UNION_DISJOINT)
  {
    if (rhs == lhs[0])
    {
      // (difference_subtract (union_disjoint A B) A) = B
      return BagsRewriteResponse(lhs[1],
                                 Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT);
    }
    if (rhs == lhs[1])
    {
      // (difference_subtract (union_disjoint B A) A) = B
      return BagsRewriteResponse(lhs[0],
                                 Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT);
    }
  }

  const bool rhsIsUnion = (rhsKind == UNION_DISJOINT || rhsKind == UNION_MAX);
  if (rhsIsUnion && (lhs == rhs[0] || lhs == rhs[1]))
  {
    // (difference_subtract A (union_disjoint A B)) = emptybag
    // (difference_subtract A (union_disjoint B A)) = emptybag
    // (difference_subtract A (union_max A B)) = emptybag
    // (difference_subtract A (union_max B A)) = emptybag
    Node empty = d_nm->mkConst(EmptyBag(n.getType()));
    return BagsRewriteResponse(empty, Rewrite::SUBTRACT_FROM_UNION);
  }

  if (lhsKind == INTERSECTION_MIN && (rhs == lhs[0] || rhs == lhs[1]))
  {
    // (difference_subtract (intersection_min A B) A) = emptybag
    // (difference_subtract (intersection_min B A) A) = emptybag
    Node empty = d_nm->mkConst(EmptyBag(n.getType()));
    return BagsRewriteResponse(empty, Rewrite::SUBTRACT_MIN);
  }

  return BagsRewriteResponse(n, Rewrite::NONE);
}
371
372
30
/**
 * Rewrite rules for removing difference.
 *
 * Handles an empty operand, two identical operands, removing a union that
 * contains the left operand, and removing an operand from an intersection
 * that contains it.
 *
 * @param n a term of kind DIFFERENCE_REMOVE
 * @return the simplified term with the matching rule tag, or n itself (NONE)
 * when no rule applies
 */
BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const
{
  Assert(n.getKind() == DIFFERENCE_REMOVE);
  TNode lhs = n[0];
  TNode rhs = n[1];
  const Kind lhsKind = lhs.getKind();
  const Kind rhsKind = rhs.getKind();

  if (lhsKind == EMPTYBAG || rhsKind == EMPTYBAG)
  {
    // (difference_remove A emptybag) = A
    // (difference_remove emptybag B) = emptybag
    // in both cases the answer is the left operand
    return BagsRewriteResponse(lhs, Rewrite::REMOVE_RETURN_LEFT);
  }

  if (lhs == rhs)
  {
    // (difference_remove A A) = emptybag
    Node empty = d_nm->mkConst(EmptyBag(n.getType()));
    return BagsRewriteResponse(empty, Rewrite::REMOVE_SAME);
  }

  const bool rhsIsUnion = (rhsKind == UNION_DISJOINT || rhsKind == UNION_MAX);
  if (rhsIsUnion && (lhs == rhs[0] || lhs == rhs[1]))
  {
    // (difference_remove A (union_disjoint A B)) = emptybag
    // (difference_remove A (union_disjoint B A)) = emptybag
    // (difference_remove A (union_max A B)) = emptybag
    // (difference_remove A (union_max B A)) = emptybag
    Node empty = d_nm->mkConst(EmptyBag(n.getType()));
    return BagsRewriteResponse(empty, Rewrite::REMOVE_FROM_UNION);
  }

  if (lhsKind == INTERSECTION_MIN && (rhs == lhs[0] || rhs == lhs[1]))
  {
    // (difference_remove (intersection_min A B) A) = emptybag
    // (difference_remove (intersection_min B A) A) = emptybag
    Node empty = d_nm->mkConst(EmptyBag(n.getType()));
    return BagsRewriteResponse(empty, Rewrite::REMOVE_MIN);
  }

  return BagsRewriteResponse(n, Rewrite::NONE);
}
416
417
2
/**
 * Rewrite rule for the choose operator.
 *
 * The rule only fires when the multiplicity is a strictly positive constant.
 * With a zero or negative constant the bag (mkBag x c) is empty (see
 * rewriteMakeBag), so returning x for it would be unsound. The sign check
 * mirrors the one in rewriteDuplicateRemoval and rewriteToSet.
 *
 * @param n a term of kind BAG_CHOOSE
 * @return x (CHOOSE_MK_BAG) when the operand is (mkBag x c) with c a positive
 * constant, n itself (NONE) otherwise
 */
BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const
{
  Assert(n.getKind() == BAG_CHOOSE);
  if (n[0].getKind() == MK_BAG && n[0][1].isConst()
      && n[0][1].getConst<Rational>().sgn() == 1)
  {
    // (bag.choose (mkBag x c)) = x where c is a constant > 0
    return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_MK_BAG);
  }
  return BagsRewriteResponse(n, Rewrite::NONE);
}
427
428
4
/**
 * Rewrite rules for bag cardinality.
 *
 * The mkBag rule only fires when the multiplicity is a strictly positive
 * constant. With a zero or negative constant the bag (mkBag x c) is empty
 * (see rewriteMakeBag), so returning c as the cardinality would be wrong. The
 * sign check mirrors the one in rewriteDuplicateRemoval and rewriteToSet.
 *
 * @param n a term of kind BAG_CARD
 * @return c (CARD_MK_BAG) when the operand is (mkBag x c) with c a positive
 * constant, the sum of the operands' cardinalities (CARD_DISJOINT) when the
 * operand is a disjoint union, n itself (NONE) otherwise
 */
BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const
{
  Assert(n.getKind() == BAG_CARD);
  if (n[0].getKind() == MK_BAG && n[0][1].isConst()
      && n[0][1].getConst<Rational>().sgn() == 1)
  {
    // (bag.card (mkBag x c)) = c where c is a constant > 0
    return BagsRewriteResponse(n[0][1], Rewrite::CARD_MK_BAG);
  }

  if (n[0].getKind() == UNION_DISJOINT)
  {
    // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
    Node A = d_nm->mkNode(BAG_CARD, n[0][0]);
    Node B = d_nm->mkNode(BAG_CARD, n[0][1]);
    Node plus = d_nm->mkNode(PLUS, A, B);
    return BagsRewriteResponse(plus, Rewrite::CARD_DISJOINT);
  }

  return BagsRewriteResponse(n, Rewrite::NONE);
}
448
449
2
/**
 * Rewrite rule for the singleton test.
 * @param n a term of kind BAG_IS_SINGLETON
 * @return the equality c = 1 (IS_SINGLETON_MK_BAG) when the operand is
 * (mkBag x c), n itself (NONE) otherwise
 */
BagsRewriteResponse BagsRewriter::rewriteIsSingleton(const TNode& n) const
{
  Assert(n.getKind() == BAG_IS_SINGLETON);
  if (n[0].getKind() != MK_BAG)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (bag.is_singleton (mkBag x c)) = (c == 1)
  Node countIsOne = n[0][1].eqNode(d_one);
  return BagsRewriteResponse(countIsOne, Rewrite::IS_SINGLETON_MK_BAG);
}
460
461
2
/**
 * Rewrite rule for converting a set into a bag.
 * @param n a term of kind BAG_FROM_SET
 * @return (mkBag x 1) (FROM_SINGLETON) when the operand is a SINGLETON set
 * holding x, n itself (NONE) otherwise
 */
BagsRewriteResponse BagsRewriter::rewriteFromSet(const TNode& n) const
{
  Assert(n.getKind() == BAG_FROM_SET);
  if (n[0].getKind() != SINGLETON)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
  // the bag's element type is taken from the set type of the operand
  TypeNode elementType = n[0].getType().getSetElementType();
  Node single = d_nm->mkBag(elementType, n[0][0], d_one);
  return BagsRewriteResponse(single, Rewrite::FROM_SINGLETON);
}
473
474
2
/**
 * Rewrite rule for converting a bag into a set.
 * @param n a term of kind BAG_TO_SET
 * @return the singleton set holding x (TO_SINGLETON) when the operand is
 * (mkBag x c) with c a positive constant, n itself (NONE) otherwise
 */
BagsRewriteResponse BagsRewriter::rewriteToSet(const TNode& n) const
{
  Assert(n.getKind() == BAG_TO_SET);
  TNode operand = n[0];
  const bool positiveConstBag = operand.getKind() == MK_BAG
                                && operand[1].isConst()
                                && operand[1].getConst<Rational>().sgn() == 1;
  if (!positiveConstBag)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
  // where n is a positive constant and T is the type of the bag's elements
  Node single = d_nm->mkSingleton(operand[0].getType(), operand[0]);
  return BagsRewriteResponse(single, Rewrite::TO_SINGLETON);
}
487
488
721
/**
 * Post-rewrite rules for equalities.
 * @param n a term of kind EQUAL
 * @return true (EQ_REFL) when both sides are the same node, false
 * (EQ_CONST_FALSE) when both sides are constants that are not the same node,
 * the equality with its sides swapped (EQ_SYM) when the left side compares
 * greater than the right side, n itself (NONE) otherwise
 */
BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const
{
  Assert(n.getKind() == kind::EQUAL);
  TNode lhs = n[0];
  TNode rhs = n[1];

  if (lhs == rhs)
  {
    Node truth = d_nm->mkConst(true);
    return BagsRewriteResponse(truth, Rewrite::EQ_REFL);
  }

  if (lhs.isConst() && rhs.isConst())
  {
    // reached only when the two constants are different nodes
    Node falsity = d_nm->mkConst(false);
    return BagsRewriteResponse(falsity, Rewrite::EQ_CONST_FALSE);
  }

  // standard ordering: put the smaller node on the left
  if (lhs > rhs)
  {
    Node swapped = d_nm->mkNode(kind::EQUAL, rhs, lhs);
    return BagsRewriteResponse(swapped, Rewrite::EQ_SYM);
  }
  return BagsRewriteResponse(n, Rewrite::NONE);
}
511
512
35
/**
 * Post-rewrite rules for bag.map, whose children are the function n[0] and
 * the bag n[1].
 *
 * - constant bag: the function is applied (as an APPLY_UF node) to every
 *   element and a constant bag with the same multiplicities is rebuilt;
 * - (mkBag x c): the function is pushed onto the single element;
 * - (union_disjoint A B): the map is distributed over both operands. Each
 *   new BAG_MAP node is built with the function as its first child, so it has
 *   the same two-child shape as n.
 *
 * @param n a term of kind BAG_MAP
 * @return the rewritten term with the matching rule tag, or n itself (NONE)
 * when no rule applies
 */
BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
{
  Assert(n.getKind() == kind::BAG_MAP);
  if (n[1].isConst())
  {
    // (bag.map f emptybag) = emptybag
    // (bag.map f (bag "a" 3)) = (bag (f "a") 3)
    std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
    std::map<Node, Rational> mappedElements;
    for (const std::pair<const Node, Rational>& entry : elements)
    {
      Node mappedElement = d_nm->mkNode(APPLY_UF, n[0], entry.first);
      mappedElements[mappedElement] = entry.second;
    }
    // the resulting bag ranges over the function's range type
    TypeNode t = d_nm->mkBagType(n[0].getType().getRangeType());
    Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements);
    return BagsRewriteResponse(ret, Rewrite::MAP_CONST);
  }
  Kind k = n[1].getKind();
  switch (k)
  {
    case MK_BAG:
    {
      // (bag.map f (mkBag x c)) = (mkBag (f x) c)
      Node mappedElement = d_nm->mkNode(APPLY_UF, n[0], n[1][0]);
      Node ret =
          d_nm->mkBag(n[0].getType().getRangeType(), mappedElement, n[1][1]);
      return BagsRewriteResponse(ret, Rewrite::MAP_MK_BAG);
    }

    case UNION_DISJOINT:
    {
      // (bag.map f (union_disjoint A B)) =
      //         (union_disjoint (bag.map f A) (bag.map f B))
      Node a = d_nm->mkNode(BAG_MAP, n[0], n[1][0]);
      Node b = d_nm->mkNode(BAG_MAP, n[0], n[1][1]);
      Node ret = d_nm->mkNode(UNION_DISJOINT, a, b);
      return BagsRewriteResponse(ret, Rewrite::MAP_UNION_DISJOINT);
    }

    default: return BagsRewriteResponse(n, Rewrite::NONE);
  }
}
554
}  // namespace bags
555
}  // namespace theory
556
31137
}  // namespace cvc5