GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/bags_rewriter.cpp Lines: 207 219 94.5 %
Date: 2021-08-06 Branches: 827 1797 46.0 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Mudathir Mohamed, Gereon Kremer, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Bags theory rewriter.
14
 */
15
16
#include "theory/bags/bags_rewriter.h"
17
18
#include "expr/emptybag.h"
19
#include "theory/bags/normal_form.h"
20
#include "util/rational.h"
21
#include "util/statistics_registry.h"
22
23
using namespace cvc5::kind;
24
25
namespace cvc5 {
26
namespace theory {
27
namespace bags {
28
29
3121
/**
 * Builds the "nothing happened" response: a null node tagged with
 * Rewrite::NONE.
 */
BagsRewriteResponse::BagsRewriteResponse()
    : BagsRewriteResponse(Node::null(), Rewrite::NONE)
{
}
33
34
3121
BagsRewriteResponse::BagsRewriteResponse(Node n, Rewrite rewrite)
35
3121
    : d_node(n), d_rewrite(rewrite)
36
{
37
3121
}
38
39
BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r)
40
    : d_node(r.d_node), d_rewrite(r.d_rewrite)
41
{
42
}
43
44
9915
/**
 * Creates a bags rewriter.
 *
 * @param statistics optional histogram that counts every rewrite applied by
 *        preRewrite/postRewrite; may be nullptr, in which case nothing is
 *        recorded.
 */
BagsRewriter::BagsRewriter(HistogramStat<Rewrite>* statistics)
    : d_statistics(statistics)
{
  // Cache the node manager and the integer constants 0 and 1, which several
  // rewrites below return or embed.
  NodeManager* nm = NodeManager::currentNM();
  d_nm = nm;
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
}
51
52
1998
/**
 * Post-rewrite entry point for bag terms.
 *
 * Constants are left untouched, equalities are handled by postRewriteEqual,
 * terms whose children are all constants are evaluated with
 * NormalForm::evaluate, and every other supported bag operator is dispatched
 * to its kind-specific rewrite. The outcome is traced on "bags-rewrite" and,
 * if a statistics histogram was supplied, counted.
 *
 * @param n the term to rewrite
 * @return REWRITE_AGAIN_FULL with the new term when it differs from n,
 *         otherwise REWRITE_DONE with n itself
 */
RewriteResponse BagsRewriter::postRewrite(TNode n)
{
  BagsRewriteResponse result;
  if (n.isConst())
  {
    // a constant is already in normal form
    result = BagsRewriteResponse(n, Rewrite::NONE);
  }
  else if (n.getKind() == EQUAL)
  {
    result = postRewriteEqual(n);
  }
  else if (NormalForm::areChildrenConstants(n))
  {
    Node evaluated = NormalForm::evaluate(n);
    result = BagsRewriteResponse(evaluated, Rewrite::CONSTANT_EVALUATION);
  }
  else
  {
    switch (n.getKind())
    {
      case MK_BAG: result = rewriteMakeBag(n); break;
      case BAG_COUNT: result = rewriteBagCount(n); break;
      case DUPLICATE_REMOVAL: result = rewriteDuplicateRemoval(n); break;
      case UNION_MAX: result = rewriteUnionMax(n); break;
      case UNION_DISJOINT: result = rewriteUnionDisjoint(n); break;
      case INTERSECTION_MIN: result = rewriteIntersectionMin(n); break;
      case DIFFERENCE_SUBTRACT: result = rewriteDifferenceSubtract(n); break;
      case DIFFERENCE_REMOVE: result = rewriteDifferenceRemove(n); break;
      case BAG_CHOOSE: result = rewriteChoose(n); break;
      case BAG_CARD: result = rewriteCard(n); break;
      case BAG_IS_SINGLETON: result = rewriteIsSingleton(n); break;
      case BAG_FROM_SET: result = rewriteFromSet(n); break;
      case BAG_TO_SET: result = rewriteToSet(n); break;
      default: result = BagsRewriteResponse(n, Rewrite::NONE); break;
    }
  }

  Trace("bags-rewrite") << "postRewrite " << n << " to " << result.d_node
                        << " by " << result.d_rewrite << "." << std::endl;

  if (d_statistics != nullptr)
  {
    (*d_statistics) << result.d_rewrite;
  }

  if (result.d_node == n)
  {
    return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
  }
  // the term changed: ask the rewriter to process the new term again
  return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, result.d_node);
}
104
105
1123
/**
 * Pre-rewrite entry point for bag terms.
 *
 * Only equalities (preRewriteEqual) and bag.is_included (rewriteSubBag) are
 * rewritten here; every other term is returned unchanged. The outcome is
 * traced on "bags-rewrite" and, if a statistics histogram was supplied,
 * counted.
 *
 * @param n the term to rewrite
 * @return REWRITE_AGAIN_FULL with the new term when it differs from n,
 *         otherwise REWRITE_DONE with n itself
 */
RewriteResponse BagsRewriter::preRewrite(TNode n)
{
  BagsRewriteResponse result;
  const Kind nodeKind = n.getKind();
  if (nodeKind == EQUAL)
  {
    result = preRewriteEqual(n);
  }
  else if (nodeKind == SUBBAG)
  {
    result = rewriteSubBag(n);
  }
  else
  {
    result = BagsRewriteResponse(n, Rewrite::NONE);
  }

  Trace("bags-rewrite") << "preRewrite " << n << " to " << result.d_node
                        << " by " << result.d_rewrite << "." << std::endl;

  if (d_statistics != nullptr)
  {
    (*d_statistics) << result.d_rewrite;
  }

  if (result.d_node == n)
  {
    return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
  }
  // the term changed: ask the rewriter to process the new term again
  return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, result.d_node);
}
129
130
347
/**
 * Pre-rewrite for bag equalities: (= A A) becomes true.
 *
 * @param n an EQUAL node
 * @return true tagged IDENTICAL_NODES when both sides are the same node,
 *         otherwise n tagged NONE
 */
BagsRewriteResponse BagsRewriter::preRewriteEqual(const TNode& n) const
{
  Assert(n.getKind() == EQUAL);
  const bool sameSides = (n[0] == n[1]);
  // (= A A) = true where A is a bag
  return sameSides ? BagsRewriteResponse(d_nm->mkConst(true),
                                         Rewrite::IDENTICAL_NODES)
                   : BagsRewriteResponse(n, Rewrite::NONE);
}
140
141
12
/**
 * Eliminates bag inclusion:
 *   (bag.is_included A B) = ((difference_subtract A B) == emptybag)
 *
 * @param n a SUBBAG node
 * @return the equivalent equality, tagged SUB_BAG
 */
BagsRewriteResponse BagsRewriter::rewriteSubBag(const TNode& n) const
{
  Assert(n.getKind() == SUBBAG);
  TNode lhs = n[0];
  TNode rhs = n[1];
  Node empty = d_nm->mkConst(EmptyBag(lhs.getType()));
  Node diff = d_nm->mkNode(DIFFERENCE_SUBTRACT, lhs, rhs);
  Node isEmpty = diff.eqNode(empty);
  return BagsRewriteResponse(isEmpty, Rewrite::SUB_BAG);
}
151
152
103
/**
 * Rewrites (mkBag x c) to emptybag when c is a constant <= 0.
 *
 * @param n an MK_BAG node
 * @return the empty bag of n's type tagged MK_BAG_COUNT_NEGATIVE, or n tagged
 *         NONE when the multiplicity is non-constant or positive
 */
BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const
{
  Assert(n.getKind() == MK_BAG);
  TNode multiplicity = n[1];
  const bool keepAsIs =
      !multiplicity.isConst() || multiplicity.getConst<Rational>().sgn() == 1;
  if (keepAsIs)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (mkBag x c) = emptybag where c <= 0
  Node emptybag = d_nm->mkConst(EmptyBag(n.getType()));
  return BagsRewriteResponse(emptybag, Rewrite::MK_BAG_COUNT_NEGATIVE);
}
164
165
928
/**
 * Simplifies bag.count applications.
 *
 *   (bag.count x emptybag)    = 0
 *   (bag.count x (mkBag x c)) = (ite (>= c 1) c 0)
 *
 * @param n a BAG_COUNT node
 * @return the simplified term with its rewrite tag, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
{
  Assert(n.getKind() == BAG_COUNT);
  if (n[1].isConst() && n[1].getKind() == EMPTYBAG)
  {
    // (bag.count x emptybag) = 0
    return BagsRewriteResponse(d_zero, Rewrite::COUNT_EMPTY);
  }
  if (n[1].getKind() == MK_BAG && n[0] == n[1][0])
  {
    // (bag.count x (mkBag x c)) = (ite (>= c 1) c 0)
    // A mkBag whose multiplicity is <= 0 denotes the empty bag (see
    // rewriteMakeBag), so returning c directly would produce a negative
    // count whenever a non-constant c is negative.
    Node c = n[1][1];
    Node positive = d_nm->mkNode(GEQ, c, d_one);
    Node count = d_nm->mkNode(ITE, positive, c, d_zero);
    return BagsRewriteResponse(count, Rewrite::COUNT_MK_BAG);
  }
  return BagsRewriteResponse(n, Rewrite::NONE);
}
180
181
14
/**
 * Rewrites (duplicate_removal (mkBag x c)) to (mkBag x 1) when c is a
 * positive constant.
 *
 * @param n a DUPLICATE_REMOVAL node
 * @return the single-occurrence bag tagged DUPLICATE_REMOVAL_MK_BAG, or n
 *         tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteDuplicateRemoval(const TNode& n) const
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);
  TNode bag = n[0];
  const bool positiveMkBag = bag.getKind() == MK_BAG && bag[1].isConst()
                             && bag[1].getConst<Rational>().sgn() == 1;
  if (!positiveMkBag)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (duplicate_removal (mkBag x c)) = (mkBag x 1) for a positive constant c
  TNode element = bag[0];
  Node single = d_nm->mkBag(element.getType(), element, d_one);
  return BagsRewriteResponse(single, Rewrite::DUPLICATE_REMOVAL_MK_BAG);
}
194
195
64
/**
 * Simplifies union_max when one operand subsumes the other.
 *
 * @param n a UNION_MAX node
 * @return the surviving operand with its rewrite tag, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const
{
  Assert(n.getKind() == UNION_MAX);
  TNode a = n[0];
  TNode b = n[1];
  const Kind kindA = a.getKind();
  const Kind kindB = b.getKind();

  // (union_max A A) = A
  // (union_max A emptybag) = A
  if (kindB == EMPTYBAG || a == b)
  {
    return BagsRewriteResponse(a, Rewrite::UNION_MAX_SAME_OR_EMPTY);
  }
  // (union_max emptybag A) = A
  if (kindA == EMPTYBAG)
  {
    return BagsRewriteResponse(b, Rewrite::UNION_MAX_EMPTY);
  }
  // (union_max A (op A B)) = (op A B)
  // (union_max A (op B A)) = (op B A)
  // for op in {union_max, union_disjoint}
  const bool bIsUnion = kindB == UNION_MAX || kindB == UNION_DISJOINT;
  if (bIsUnion && (a == b[0] || a == b[1]))
  {
    return BagsRewriteResponse(b, Rewrite::UNION_MAX_UNION_LEFT);
  }
  // (union_max (op A B) A) = (op A B)
  // (union_max (op B A) A) = (op B A)
  // for op in {union_max, union_disjoint}
  const bool aIsUnion = kindA == UNION_MAX || kindA == UNION_DISJOINT;
  if (aIsUnion && (a[0] == b || a[1] == b))
  {
    return BagsRewriteResponse(a, Rewrite::UNION_MAX_UNION_RIGHT);
  }
  return BagsRewriteResponse(n, Rewrite::NONE);
}
231
232
67
/**
 * Simplifies union_disjoint with an empty operand, and merges
 * (union_max A B) with (intersection_min A B) in either order.
 *
 * @param n a UNION_DISJOINT node
 * @return the simplified term with its rewrite tag, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
{
  Assert(n.getKind() == UNION_DISJOINT);
  TNode lhs = n[0];
  TNode rhs = n[1];

  // (union_disjoint A emptybag) = A
  if (rhs.getKind() == EMPTYBAG)
  {
    return BagsRewriteResponse(lhs, Rewrite::UNION_DISJOINT_EMPTY_RIGHT);
  }
  // (union_disjoint emptybag A) = A
  if (lhs.getKind() == EMPTYBAG)
  {
    return BagsRewriteResponse(rhs, Rewrite::UNION_DISJOINT_EMPTY_LEFT);
  }

  // (union_disjoint (union_max A B) (intersection_min A B)) =
  //         (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
  const bool maxThenMin =
      lhs.getKind() == UNION_MAX && rhs.getKind() == INTERSECTION_MIN;
  const bool minThenMax =
      rhs.getKind() == UNION_MAX && lhs.getKind() == INTERSECTION_MIN;
  if (maxThenMin || minThenMax)
  {
    // the rule applies only when both sides combine the same operands
    const std::set<Node> lhsOperands(lhs.begin(), lhs.end());
    const std::set<Node> rhsOperands(rhs.begin(), rhs.end());
    if (lhsOperands == rhsOperands)
    {
      Node merged = d_nm->mkNode(UNION_DISJOINT, lhs[0], lhs[1]);
      return BagsRewriteResponse(merged, Rewrite::UNION_DISJOINT_MAX_MIN);
    }
  }
  return BagsRewriteResponse(n, Rewrite::NONE);
}
262
263
34
/**
 * Simplifies intersection_min with an empty operand, identical operands, or
 * an operand that is a union containing the other operand.
 *
 * @param n an INTERSECTION_MIN node
 * @return the simplified term with its rewrite tag, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteIntersectionMin(const TNode& n) const
{
  Assert(n.getKind() == INTERSECTION_MIN);
  TNode a = n[0];
  TNode b = n[1];

  // (intersection_min emptybag A) = emptybag
  if (a.getKind() == EMPTYBAG)
  {
    return BagsRewriteResponse(a, Rewrite::INTERSECTION_EMPTY_LEFT);
  }
  // (intersection_min A emptybag) = emptybag
  if (b.getKind() == EMPTYBAG)
  {
    return BagsRewriteResponse(b, Rewrite::INTERSECTION_EMPTY_RIGHT);
  }
  // (intersection_min A A) = A
  if (a == b)
  {
    return BagsRewriteResponse(a, Rewrite::INTERSECTION_SAME);
  }

  // (intersection_min A (op A B)) = A
  // (intersection_min A (op B A)) = A
  // for op in {union_disjoint, union_max}
  const Kind kindB = b.getKind();
  if ((kindB == UNION_DISJOINT || kindB == UNION_MAX)
      && (a == b[0] || a == b[1]))
  {
    return BagsRewriteResponse(a, Rewrite::INTERSECTION_SHARED_LEFT);
  }

  // (intersection_min (op A B) A) = A
  // (intersection_min (op B A) A) = A
  // for op in {union_disjoint, union_max}
  const Kind kindA = a.getKind();
  if ((kindA == UNION_DISJOINT || kindA == UNION_MAX)
      && (b == a[0] || b == a[1]))
  {
    return BagsRewriteResponse(b, Rewrite::INTERSECTION_SHARED_RIGHT);
  }

  return BagsRewriteResponse(n, Rewrite::NONE);
}
307
308
46
/**
 * Simplifies difference_subtract (multiplicity-wise subtraction) when the
 * operands are empty, identical, or structurally related through a union or
 * intersection.
 *
 * @param n a DIFFERENCE_SUBTRACT node
 * @return the simplified term with its rewrite tag, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteDifferenceSubtract(
    const TNode& n) const
{
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
  TNode a = n[0];
  TNode b = n[1];
  // builds the empty bag of n's type, only when a rule actually needs it
  auto emptyOfType = [&]() { return d_nm->mkConst(EmptyBag(n.getType())); };

  if (a.getKind() == EMPTYBAG || b.getKind() == EMPTYBAG)
  {
    // (difference_subtract A emptybag) = A
    // (difference_subtract emptybag A) = emptybag
    return BagsRewriteResponse(a, Rewrite::SUBTRACT_RETURN_LEFT);
  }
  if (a == b)
  {
    // (difference_subtract A A) = emptybag
    return BagsRewriteResponse(emptyOfType(), Rewrite::SUBTRACT_SAME);
  }

  const bool aIsDisjointUnion = a.getKind() == UNION_DISJOINT;
  if (aIsDisjointUnion && b == a[0])
  {
    // (difference_subtract (union_disjoint A B) A) = B
    return BagsRewriteResponse(a[1], Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT);
  }
  if (aIsDisjointUnion && b == a[1])
  {
    // (difference_subtract (union_disjoint B A) A) = B
    return BagsRewriteResponse(a[0], Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT);
  }

  const Kind kindB = b.getKind();
  if ((kindB == UNION_DISJOINT || kindB == UNION_MAX)
      && (a == b[0] || a == b[1]))
  {
    // (difference_subtract A (union_disjoint A B)) = emptybag
    // (difference_subtract A (union_disjoint B A)) = emptybag
    // (difference_subtract A (union_max A B)) = emptybag
    // (difference_subtract A (union_max B A)) = emptybag
    return BagsRewriteResponse(emptyOfType(), Rewrite::SUBTRACT_FROM_UNION);
  }

  if (a.getKind() == INTERSECTION_MIN && (b == a[0] || b == a[1]))
  {
    // (difference_subtract (intersection_min A B) A) = emptybag
    // (difference_subtract (intersection_min B A) A) = emptybag
    return BagsRewriteResponse(emptyOfType(), Rewrite::SUBTRACT_MIN);
  }

  return BagsRewriteResponse(n, Rewrite::NONE);
}
367
368
18
/**
 * Simplifies difference_remove when the operands are empty, identical, or
 * structurally related through a union or intersection.
 *
 * @param n a DIFFERENCE_REMOVE node
 * @return the simplified term with its rewrite tag, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const
{
  Assert(n.getKind() == DIFFERENCE_REMOVE);
  TNode a = n[0];
  TNode b = n[1];
  // builds the empty bag of n's type, only when a rule actually needs it
  auto emptyOfType = [&]() { return d_nm->mkConst(EmptyBag(n.getType())); };

  if (a.getKind() == EMPTYBAG || b.getKind() == EMPTYBAG)
  {
    // (difference_remove A emptybag) = A
    // (difference_remove emptybag B) = emptybag
    return BagsRewriteResponse(a, Rewrite::REMOVE_RETURN_LEFT);
  }
  if (a == b)
  {
    // (difference_remove A A) = emptybag
    return BagsRewriteResponse(emptyOfType(), Rewrite::REMOVE_SAME);
  }

  const Kind kindB = b.getKind();
  if ((kindB == UNION_DISJOINT || kindB == UNION_MAX)
      && (a == b[0] || a == b[1]))
  {
    // (difference_remove A (union_disjoint A B)) = emptybag
    // (difference_remove A (union_disjoint B A)) = emptybag
    // (difference_remove A (union_max A B)) = emptybag
    // (difference_remove A (union_max B A)) = emptybag
    return BagsRewriteResponse(emptyOfType(), Rewrite::REMOVE_FROM_UNION);
  }

  if (a.getKind() == INTERSECTION_MIN && (b == a[0] || b == a[1]))
  {
    // (difference_remove (intersection_min A B) A) = emptybag
    // (difference_remove (intersection_min B A) A) = emptybag
    return BagsRewriteResponse(emptyOfType(), Rewrite::REMOVE_MIN);
  }

  return BagsRewriteResponse(n, Rewrite::NONE);
}
412
413
2
/**
 * Rewrites (bag.choose (mkBag x c)) to x when c is a positive constant.
 *
 * The positivity check is required: a mkBag with multiplicity <= 0 denotes
 * the empty bag (see rewriteMakeBag), from which x cannot be chosen. This
 * mirrors the checks in rewriteDuplicateRemoval and rewriteToSet.
 *
 * @param n a BAG_CHOOSE node
 * @return x tagged CHOOSE_MK_BAG, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const
{
  Assert(n.getKind() == BAG_CHOOSE);
  if (n[0].getKind() == MK_BAG && n[0][1].isConst()
      && n[0][1].getConst<Rational>().sgn() == 1)
  {
    // (bag.choose (mkBag x c)) = x where c is a constant > 0
    return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_MK_BAG);
  }
  return BagsRewriteResponse(n, Rewrite::NONE);
}
423
424
4
/**
 * Simplifies bag.card applications.
 *
 *   (bag.card (mkBag x c)) = c  where c is a constant > 0
 *   (bag.card (union_disjoint A B)) = (+ (bag.card A) (bag.card B))
 *
 * The first rule checks positivity because a mkBag whose multiplicity is
 * <= 0 denotes the empty bag (see rewriteMakeBag), whose cardinality is not
 * c. This mirrors the checks in rewriteDuplicateRemoval and rewriteToSet.
 *
 * @param n a BAG_CARD node
 * @return the simplified term with its rewrite tag, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const
{
  Assert(n.getKind() == BAG_CARD);
  if (n[0].getKind() == MK_BAG && n[0][1].isConst()
      && n[0][1].getConst<Rational>().sgn() == 1)
  {
    // (bag.card (mkBag x c)) = c where c is a constant > 0
    return BagsRewriteResponse(n[0][1], Rewrite::CARD_MK_BAG);
  }

  if (n[0].getKind() == UNION_DISJOINT)
  {
    // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
    Node A = d_nm->mkNode(BAG_CARD, n[0][0]);
    Node B = d_nm->mkNode(BAG_CARD, n[0][1]);
    Node plus = d_nm->mkNode(PLUS, A, B);
    return BagsRewriteResponse(plus, Rewrite::CARD_DISJOINT);
  }

  return BagsRewriteResponse(n, Rewrite::NONE);
}
444
445
2
/**
 * Rewrites (bag.is_singleton (mkBag x c)) to (c == 1).
 *
 * @param n a BAG_IS_SINGLETON node
 * @return the equality tagged IS_SINGLETON_MK_BAG, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteIsSingleton(const TNode& n) const
{
  Assert(n.getKind() == BAG_IS_SINGLETON);
  TNode bag = n[0];
  if (bag.getKind() != MK_BAG)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (bag.is_singleton (mkBag x c)) = (c == 1)
  Node multiplicityIsOne = bag[1].eqNode(d_one);
  return BagsRewriteResponse(multiplicityIsOne, Rewrite::IS_SINGLETON_MK_BAG);
}
456
457
2
/**
 * Rewrites (bag.from_set (singleton x)) to (mkBag x 1).
 *
 * @param n a BAG_FROM_SET node
 * @return the one-element bag tagged FROM_SINGLETON, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteFromSet(const TNode& n) const
{
  Assert(n.getKind() == BAG_FROM_SET);
  TNode arg = n[0];
  if (arg.getKind() != SINGLETON)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
  TypeNode elementType = arg.getType().getSetElementType();
  Node bag = d_nm->mkBag(elementType, arg[0], d_one);
  return BagsRewriteResponse(bag, Rewrite::FROM_SINGLETON);
}
469
470
2
/**
 * Rewrites (bag.to_set (mkBag x c)) to (singleton x) when c is a positive
 * constant.
 *
 * @param n a BAG_TO_SET node
 * @return the singleton set tagged TO_SINGLETON, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::rewriteToSet(const TNode& n) const
{
  Assert(n.getKind() == BAG_TO_SET);
  TNode bag = n[0];
  const bool positiveMkBag = bag.getKind() == MK_BAG && bag[1].isConst()
                             && bag[1].getConst<Rational>().sgn() == 1;
  if (!positiveMkBag)
  {
    return BagsRewriteResponse(n, Rewrite::NONE);
  }
  // (bag.to_set (mkBag x c)) = (singleton (singleton_op T) x)
  // where c is a positive constant and T is the type of the bag's elements
  TNode element = bag[0];
  Node singleton = d_nm->mkSingleton(element.getType(), element);
  return BagsRewriteResponse(singleton, Rewrite::TO_SINGLETON);
}
483
484
500
/**
 * Post-rewrite for bag equalities.
 *
 * Identical sides give true. Two constant sides, which the first check shows
 * are not identical, give false. Otherwise the sides are put in the standard
 * order, with the smaller node on the left.
 *
 * @param n an EQUAL node
 * @return the simplified term with its rewrite tag, or n tagged NONE
 */
BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const
{
  Assert(n.getKind() == kind::EQUAL);
  TNode lhs = n[0];
  TNode rhs = n[1];

  if (lhs == rhs)
  {
    return BagsRewriteResponse(d_nm->mkConst(true), Rewrite::EQ_REFL);
  }

  if (lhs.isConst() && rhs.isConst())
  {
    return BagsRewriteResponse(d_nm->mkConst(false), Rewrite::EQ_CONST_FALSE);
  }

  // standard ordering
  if (lhs > rhs)
  {
    Node swapped = d_nm->mkNode(kind::EQUAL, rhs, lhs);
    return BagsRewriteResponse(swapped, Rewrite::EQ_SYM);
  }

  return BagsRewriteResponse(n, Rewrite::NONE);
}
507
508
}  // namespace bags
509
}  // namespace theory
510
29322
}  // namespace cvc5