GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/bags_rewriter.cpp Lines: 207 219 94.5 %
Date: 2021-05-22 Branches: 827 1797 46.0 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Mudathir Mohamed, Gereon Kremer, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Bags theory rewriter.
14
 */
15
16
#include "theory/bags/bags_rewriter.h"
17
18
#include "theory/bags/normal_form.h"
19
#include "util/statistics_registry.h"
20
21
using namespace cvc5::kind;
22
23
namespace cvc5 {
24
namespace theory {
25
namespace bags {
26
27
3137
BagsRewriteResponse::BagsRewriteResponse()
28
3137
    : d_node(Node::null()), d_rewrite(Rewrite::NONE)
29
{
30
3137
}
31
32
3137
BagsRewriteResponse::BagsRewriteResponse(Node n, Rewrite rewrite)
33
3137
    : d_node(n), d_rewrite(rewrite)
34
{
35
3137
}
36
37
BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r)
38
    : d_node(r.d_node), d_rewrite(r.d_rewrite)
39
{
40
}
41
42
9521
BagsRewriter::BagsRewriter(HistogramStat<Rewrite>* statistics)
43
9521
    : d_statistics(statistics)
44
{
45
9521
  d_nm = NodeManager::currentNM();
46
9521
  d_zero = d_nm->mkConst(Rational(0));
47
9521
  d_one = d_nm->mkConst(Rational(1));
48
9521
}
49
50
2007
RewriteResponse BagsRewriter::postRewrite(TNode n)
51
{
52
4014
  BagsRewriteResponse response;
53
2007
  if (n.isConst())
54
  {
55
    // no need to rewrite n if it is already in a normal form
56
140
    response = BagsRewriteResponse(n, Rewrite::NONE);
57
  }
58
1867
  else if (n.getKind() == EQUAL)
59
  {
60
504
    response = postRewriteEqual(n);
61
  }
62
1363
  else if (NormalForm::areChildrenConstants(n))
63
  {
64
144
    Node value = NormalForm::evaluate(n);
65
72
    response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
66
  }
67
  else
68
  {
69
1291
    Kind k = n.getKind();
70
1291
    switch (k)
71
    {
72
103
      case MK_BAG: response = rewriteMakeBag(n); break;
73
931
      case BAG_COUNT: response = rewriteBagCount(n); break;
74
14
      case DUPLICATE_REMOVAL: response = rewriteDuplicateRemoval(n); break;
75
66
      case UNION_MAX: response = rewriteUnionMax(n); break;
76
67
      case UNION_DISJOINT: response = rewriteUnionDisjoint(n); break;
77
34
      case INTERSECTION_MIN: response = rewriteIntersectionMin(n); break;
78
46
      case DIFFERENCE_SUBTRACT: response = rewriteDifferenceSubtract(n); break;
79
18
      case DIFFERENCE_REMOVE: response = rewriteDifferenceRemove(n); break;
80
2
      case BAG_CHOOSE: response = rewriteChoose(n); break;
81
4
      case BAG_CARD: response = rewriteCard(n); break;
82
2
      case BAG_IS_SINGLETON: response = rewriteIsSingleton(n); break;
83
2
      case BAG_FROM_SET: response = rewriteFromSet(n); break;
84
2
      case BAG_TO_SET: response = rewriteToSet(n); break;
85
      default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
86
    }
87
  }
88
89
4014
  Trace("bags-rewrite") << "postRewrite " << n << " to " << response.d_node
90
2007
                        << " by " << response.d_rewrite << "." << std::endl;
91
92
2007
  if (d_statistics != nullptr)
93
  {
94
1875
    (*d_statistics) << response.d_rewrite;
95
  }
96
2007
  if (response.d_node != n)
97
  {
98
297
    return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
99
  }
100
1710
  return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
101
}
102
103
1130
RewriteResponse BagsRewriter::preRewrite(TNode n)
104
{
105
2260
  BagsRewriteResponse response;
106
1130
  Kind k = n.getKind();
107
1130
  switch (k)
108
  {
109
351
    case EQUAL: response = preRewriteEqual(n); break;
110
12
    case SUBBAG: response = rewriteSubBag(n); break;
111
767
    default: response = BagsRewriteResponse(n, Rewrite::NONE);
112
  }
113
114
2260
  Trace("bags-rewrite") << "preRewrite " << n << " to " << response.d_node
115
1130
                        << " by " << response.d_rewrite << "." << std::endl;
116
117
1130
  if (d_statistics != nullptr)
118
  {
119
1128
    (*d_statistics) << response.d_rewrite;
120
  }
121
1130
  if (response.d_node != n)
122
  {
123
79
    return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
124
  }
125
1051
  return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
126
}
127
128
351
BagsRewriteResponse BagsRewriter::preRewriteEqual(const TNode& n) const
129
{
130
351
  Assert(n.getKind() == EQUAL);
131
351
  if (n[0] == n[1])
132
  {
133
    // (= A A) = true where A is a bag
134
67
    return BagsRewriteResponse(d_nm->mkConst(true), Rewrite::IDENTICAL_NODES);
135
  }
136
284
  return BagsRewriteResponse(n, Rewrite::NONE);
137
}
138
139
12
BagsRewriteResponse BagsRewriter::rewriteSubBag(const TNode& n) const
140
{
141
12
  Assert(n.getKind() == SUBBAG);
142
143
  // (bag.is_included A B) = ((difference_subtract A B) == emptybag)
144
24
  Node emptybag = d_nm->mkConst(EmptyBag(n[0].getType()));
145
24
  Node subtract = d_nm->mkNode(DIFFERENCE_SUBTRACT, n[0], n[1]);
146
24
  Node equal = subtract.eqNode(emptybag);
147
24
  return BagsRewriteResponse(equal, Rewrite::SUB_BAG);
148
}
149
150
103
BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const
151
{
152
103
  Assert(n.getKind() == MK_BAG);
153
  // return emptybag for negative or zero multiplicity
154
103
  if (n[1].isConst() && n[1].getConst<Rational>().sgn() != 1)
155
  {
156
    // (mkBag x c) = emptybag where c <= 0
157
18
    Node emptybag = d_nm->mkConst(EmptyBag(n.getType()));
158
9
    return BagsRewriteResponse(emptybag, Rewrite::MK_BAG_COUNT_NEGATIVE);
159
  }
160
94
  return BagsRewriteResponse(n, Rewrite::NONE);
161
}
162
163
931
BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
164
{
165
931
  Assert(n.getKind() == BAG_COUNT);
166
931
  if (n[1].isConst() && n[1].getKind() == EMPTYBAG)
167
  {
168
    // (bag.count x emptybag) = 0
169
27
    return BagsRewriteResponse(d_zero, Rewrite::COUNT_EMPTY);
170
  }
171
904
  if (n[1].getKind() == MK_BAG && n[0] == n[1][0])
172
  {
173
    // (bag.count x (mkBag x c) = c
174
22
    return BagsRewriteResponse(n[1][1], Rewrite::COUNT_MK_BAG);
175
  }
176
882
  return BagsRewriteResponse(n, Rewrite::NONE);
177
}
178
179
14
BagsRewriteResponse BagsRewriter::rewriteDuplicateRemoval(const TNode& n) const
180
{
181
14
  Assert(n.getKind() == DUPLICATE_REMOVAL);
182
44
  if (n[0].getKind() == MK_BAG && n[0][1].isConst()
183
44
      && n[0][1].getConst<Rational>().sgn() == 1)
184
  {
185
    // (duplicate_removal (mkBag x n)) = (mkBag x 1)
186
    //  where n is a positive constant
187
4
    Node bag = d_nm->mkBag(n[0][0].getType(), n[0][0], d_one);
188
2
    return BagsRewriteResponse(bag, Rewrite::DUPLICATE_REMOVAL_MK_BAG);
189
  }
190
12
  return BagsRewriteResponse(n, Rewrite::NONE);
191
}
192
193
66
BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const
194
{
195
66
  Assert(n.getKind() == UNION_MAX);
196
66
  if (n[1].getKind() == EMPTYBAG || n[0] == n[1])
197
  {
198
    // (union_max A A) = A
199
    // (union_max A emptybag) = A
200
4
    return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_SAME_OR_EMPTY);
201
  }
202
62
  if (n[0].getKind() == EMPTYBAG)
203
  {
204
    // (union_max emptybag A) = A
205
2
    return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_EMPTY);
206
  }
207
208
236
  if ((n[1].getKind() == UNION_MAX || n[1].getKind() == UNION_DISJOINT)
209
188
      && (n[0] == n[1][0] || n[0] == n[1][1]))
210
  {
211
    // (union_max A (union_max A B)) = (union_max A B)
212
    // (union_max A (union_max B A)) = (union_max B A)
213
    // (union_max A (union_disjoint A B)) = (union_disjoint A B)
214
    // (union_max A (union_disjoint B A)) = (union_disjoint B A)
215
8
    return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_UNION_LEFT);
216
  }
217
218
204
  if ((n[0].getKind() == UNION_MAX || n[0].getKind() == UNION_DISJOINT)
219
164
      && (n[0][0] == n[1] || n[0][1] == n[1]))
220
  {
221
    // (union_max (union_max A B) A)) = (union_max A B)
222
    // (union_max (union_max B A) A)) = (union_max B A)
223
    // (union_max (union_disjoint A B) A)) = (union_disjoint A B)
224
    // (union_max (union_disjoint B A) A)) = (union_disjoint B A)
225
8
    return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_UNION_RIGHT);
226
  }
227
44
  return BagsRewriteResponse(n, Rewrite::NONE);
228
}
229
230
67
BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
231
{
232
67
  Assert(n.getKind() == UNION_DISJOINT);
233
67
  if (n[1].getKind() == EMPTYBAG)
234
  {
235
    // (union_disjoint A emptybag) = A
236
2
    return BagsRewriteResponse(n[0], Rewrite::UNION_DISJOINT_EMPTY_RIGHT);
237
  }
238
65
  if (n[0].getKind() == EMPTYBAG)
239
  {
240
    // (union_disjoint emptybag A) = A
241
3
    return BagsRewriteResponse(n[1], Rewrite::UNION_DISJOINT_EMPTY_LEFT);
242
  }
243
192
  if ((n[0].getKind() == UNION_MAX && n[1].getKind() == INTERSECTION_MIN)
244
242
      || (n[1].getKind() == UNION_MAX && n[0].getKind() == INTERSECTION_MIN))
245
246
  {
247
    // (union_disjoint (union_max A B) (intersection_min A B)) =
248
    //         (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
249
    // check if the operands of union_max and intersection_min are the same
250
8
    std::set<Node> left(n[0].begin(), n[0].end());
251
8
    std::set<Node> right(n[1].begin(), n[1].end());
252
6
    if (left == right)
253
    {
254
8
      Node rewritten = d_nm->mkNode(UNION_DISJOINT, n[0][0], n[0][1]);
255
4
      return BagsRewriteResponse(rewritten, Rewrite::UNION_DISJOINT_MAX_MIN);
256
    }
257
  }
258
58
  return BagsRewriteResponse(n, Rewrite::NONE);
259
}
260
261
34
BagsRewriteResponse BagsRewriter::rewriteIntersectionMin(const TNode& n) const
262
{
263
34
  Assert(n.getKind() == INTERSECTION_MIN);
264
34
  if (n[0].getKind() == EMPTYBAG)
265
  {
266
    // (intersection_min emptybag A) = emptybag
267
2
    return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_EMPTY_LEFT);
268
  }
269
32
  if (n[1].getKind() == EMPTYBAG)
270
  {
271
    // (intersection_min A emptybag) = emptybag
272
2
    return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_EMPTY_RIGHT);
273
  }
274
30
  if (n[0] == n[1])
275
  {
276
    // (intersection_min A A) = A
277
2
    return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SAME);
278
  }
279
28
  if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
280
  {
281
8
    if (n[0] == n[1][0] || n[0] == n[1][1])
282
    {
283
      // (intersection_min A (union_disjoint A B)) = A
284
      // (intersection_min A (union_disjoint B A)) = A
285
      // (intersection_min A (union_max A B)) = A
286
      // (intersection_min A (union_max B A)) = A
287
8
      return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SHARED_LEFT);
288
    }
289
  }
290
291
20
  if (n[0].getKind() == UNION_DISJOINT || n[0].getKind() == UNION_MAX)
292
  {
293
8
    if (n[1] == n[0][0] || n[1] == n[0][1])
294
    {
295
      // (intersection_min (union_disjoint A B) A) = A
296
      // (intersection_min (union_disjoint B A) A) = A
297
      // (intersection_min (union_max A B) A) = A
298
      // (intersection_min (union_max B A) A) = A
299
8
      return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_SHARED_RIGHT);
300
    }
301
  }
302
303
12
  return BagsRewriteResponse(n, Rewrite::NONE);
304
}
305
306
46
BagsRewriteResponse BagsRewriter::rewriteDifferenceSubtract(
307
    const TNode& n) const
308
{
309
46
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
310
46
  if (n[0].getKind() == EMPTYBAG || n[1].getKind() == EMPTYBAG)
311
  {
312
    // (difference_subtract A emptybag) = A
313
    // (difference_subtract emptybag A) = emptybag
314
4
    return BagsRewriteResponse(n[0], Rewrite::SUBTRACT_RETURN_LEFT);
315
  }
316
42
  if (n[0] == n[1])
317
  {
318
    // (difference_subtract A A) = emptybag
319
4
    Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
320
2
    return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_SAME);
321
  }
322
323
40
  if (n[0].getKind() == UNION_DISJOINT)
324
  {
325
4
    if (n[1] == n[0][0])
326
    {
327
      // (difference_subtract (union_disjoint A B) A) = B
328
      return BagsRewriteResponse(n[0][1],
329
2
                                 Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT);
330
    }
331
2
    if (n[1] == n[0][1])
332
    {
333
      // (difference_subtract (union_disjoint B A) A) = B
334
      return BagsRewriteResponse(n[0][0],
335
2
                                 Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT);
336
    }
337
  }
338
339
36
  if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
340
  {
341
8
    if (n[0] == n[1][0] || n[0] == n[1][1])
342
    {
343
      // (difference_subtract A (union_disjoint A B)) = emptybag
344
      // (difference_subtract A (union_disjoint B A)) = emptybag
345
      // (difference_subtract A (union_max A B)) = emptybag
346
      // (difference_subtract A (union_max B A)) = emptybag
347
16
      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
348
8
      return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_FROM_UNION);
349
    }
350
  }
351
352
28
  if (n[0].getKind() == INTERSECTION_MIN)
353
  {
354
4
    if (n[1] == n[0][0] || n[1] == n[0][1])
355
    {
356
      // (difference_subtract (intersection_min A B) A) = emptybag
357
      // (difference_subtract (intersection_min B A) A) = emptybag
358
8
      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
359
4
      return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_MIN);
360
    }
361
  }
362
363
24
  return BagsRewriteResponse(n, Rewrite::NONE);
364
}
365
366
18
BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const
367
{
368
18
  Assert(n.getKind() == DIFFERENCE_REMOVE);
369
370
18
  if (n[0].getKind() == EMPTYBAG || n[1].getKind() == EMPTYBAG)
371
  {
372
    // (difference_remove A emptybag) = A
373
    // (difference_remove emptybag B) = emptybag
374
4
    return BagsRewriteResponse(n[0], Rewrite::REMOVE_RETURN_LEFT);
375
  }
376
377
14
  if (n[0] == n[1])
378
  {
379
    // (difference_remove A A) = emptybag
380
4
    Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
381
2
    return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_SAME);
382
  }
383
384
12
  if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
385
  {
386
8
    if (n[0] == n[1][0] || n[0] == n[1][1])
387
    {
388
      // (difference_remove A (union_disjoint A B)) = emptybag
389
      // (difference_remove A (union_disjoint B A)) = emptybag
390
      // (difference_remove A (union_max A B)) = emptybag
391
      // (difference_remove A (union_max B A)) = emptybag
392
16
      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
393
8
      return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_FROM_UNION);
394
    }
395
  }
396
397
4
  if (n[0].getKind() == INTERSECTION_MIN)
398
  {
399
4
    if (n[1] == n[0][0] || n[1] == n[0][1])
400
    {
401
      // (difference_remove (intersection_min A B) A) = emptybag
402
      // (difference_remove (intersection_min B A) A) = emptybag
403
8
      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
404
4
      return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_MIN);
405
    }
406
  }
407
408
  return BagsRewriteResponse(n, Rewrite::NONE);
409
}
410
411
2
BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const
412
{
413
2
  Assert(n.getKind() == BAG_CHOOSE);
414
2
  if (n[0].getKind() == MK_BAG && n[0][1].isConst())
415
  {
416
    // (bag.choose (mkBag x c)) = x where c is a constant > 0
417
2
    return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_MK_BAG);
418
  }
419
  return BagsRewriteResponse(n, Rewrite::NONE);
420
}
421
422
4
BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const
423
{
424
4
  Assert(n.getKind() == BAG_CARD);
425
4
  if (n[0].getKind() == MK_BAG && n[0][1].isConst())
426
  {
427
    // (bag.card (mkBag x c)) = c where c is a constant > 0
428
2
    return BagsRewriteResponse(n[0][1], Rewrite::CARD_MK_BAG);
429
  }
430
431
2
  if (n[0].getKind() == UNION_DISJOINT)
432
  {
433
    // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
434
4
    Node A = d_nm->mkNode(BAG_CARD, n[0][0]);
435
4
    Node B = d_nm->mkNode(BAG_CARD, n[0][1]);
436
4
    Node plus = d_nm->mkNode(PLUS, A, B);
437
2
    return BagsRewriteResponse(plus, Rewrite::CARD_DISJOINT);
438
  }
439
440
  return BagsRewriteResponse(n, Rewrite::NONE);
441
}
442
443
2
BagsRewriteResponse BagsRewriter::rewriteIsSingleton(const TNode& n) const
444
{
445
2
  Assert(n.getKind() == BAG_IS_SINGLETON);
446
2
  if (n[0].getKind() == MK_BAG)
447
  {
448
    // (bag.is_singleton (mkBag x c)) = (c == 1)
449
4
    Node equal = n[0][1].eqNode(d_one);
450
2
    return BagsRewriteResponse(equal, Rewrite::IS_SINGLETON_MK_BAG);
451
  }
452
  return BagsRewriteResponse(n, Rewrite::NONE);
453
}
454
455
2
BagsRewriteResponse BagsRewriter::rewriteFromSet(const TNode& n) const
456
{
457
2
  Assert(n.getKind() == BAG_FROM_SET);
458
2
  if (n[0].getKind() == SINGLETON)
459
  {
460
    // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
461
4
    TypeNode type = n[0].getType().getSetElementType();
462
4
    Node bag = d_nm->mkBag(type, n[0][0], d_one);
463
2
    return BagsRewriteResponse(bag, Rewrite::FROM_SINGLETON);
464
  }
465
  return BagsRewriteResponse(n, Rewrite::NONE);
466
}
467
468
2
BagsRewriteResponse BagsRewriter::rewriteToSet(const TNode& n) const
469
{
470
2
  Assert(n.getKind() == BAG_TO_SET);
471
8
  if (n[0].getKind() == MK_BAG && n[0][1].isConst()
472
8
      && n[0][1].getConst<Rational>().sgn() == 1)
473
  {
474
    // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
475
    // where n is a positive constant and T is the type of the bag's elements
476
4
    Node set = d_nm->mkSingleton(n[0][0].getType(), n[0][0]);
477
2
    return BagsRewriteResponse(set, Rewrite::TO_SINGLETON);
478
  }
479
  return BagsRewriteResponse(n, Rewrite::NONE);
480
}
481
482
504
BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const
483
{
484
504
  Assert(n.getKind() == kind::EQUAL);
485
504
  if (n[0] == n[1])
486
  {
487
    Node ret = d_nm->mkConst(true);
488
    return BagsRewriteResponse(ret, Rewrite::EQ_REFL);
489
  }
490
491
504
  if (n[0].isConst() && n[1].isConst())
492
  {
493
26
    Node ret = d_nm->mkConst(false);
494
13
    return BagsRewriteResponse(ret, Rewrite::EQ_CONST_FALSE);
495
  }
496
497
  // standard ordering
498
491
  if (n[0] > n[1])
499
  {
500
94
    Node ret = d_nm->mkNode(kind::EQUAL, n[1], n[0]);
501
47
    return BagsRewriteResponse(ret, Rewrite::EQ_SYM);
502
  }
503
444
  return BagsRewriteResponse(n, Rewrite::NONE);
504
}
505
506
}  // namespace bags
507
}  // namespace theory
508
28191
}  // namespace cvc5