GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/bags_rewriter.cpp Lines: 207 219 94.5 %
Date: 2021-03-22 Branches: 827 1797 46.0 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file bags_rewriter.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Mudathir Mohamed
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Bags theory rewriter.
13
 **/
14
15
#include "theory/bags/bags_rewriter.h"
16
17
#include "theory/bags/normal_form.h"
18
19
using namespace CVC4::kind;
20
21
namespace CVC4 {
22
namespace theory {
23
namespace bags {
24
25
3082
BagsRewriteResponse::BagsRewriteResponse()
26
3082
    : d_node(Node::null()), d_rewrite(Rewrite::NONE)
27
{
28
3082
}
29
30
3082
BagsRewriteResponse::BagsRewriteResponse(Node n, Rewrite rewrite)
31
3082
    : d_node(n), d_rewrite(rewrite)
32
{
33
3082
}
34
35
BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r)
36
    : d_node(r.d_node), d_rewrite(r.d_rewrite)
37
{
38
}
39
40
9057
BagsRewriter::BagsRewriter(IntegralHistogramStat<Rewrite>* statistics)
41
9057
    : d_statistics(statistics)
42
{
43
9057
  d_nm = NodeManager::currentNM();
44
9057
  d_zero = d_nm->mkConst(Rational(0));
45
9057
  d_one = d_nm->mkConst(Rational(1));
46
9057
}
47
48
1971
RewriteResponse BagsRewriter::postRewrite(TNode n)
49
{
50
3942
  BagsRewriteResponse response;
51
1971
  if (n.isConst())
52
  {
53
    // no need to rewrite n if it is already in a normal form
54
144
    response = BagsRewriteResponse(n, Rewrite::NONE);
55
  }
56
1827
  else if (n.getKind() == EQUAL)
57
  {
58
500
    response = postRewriteEqual(n);
59
  }
60
1327
  else if (NormalForm::areChildrenConstants(n))
61
  {
62
146
    Node value = NormalForm::evaluate(n);
63
73
    response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
64
  }
65
  else
66
  {
67
1254
    Kind k = n.getKind();
68
1254
    switch (k)
69
    {
70
103
      case MK_BAG: response = rewriteMakeBag(n); break;
71
894
      case BAG_COUNT: response = rewriteBagCount(n); break;
72
14
      case DUPLICATE_REMOVAL: response = rewriteDuplicateRemoval(n); break;
73
66
      case UNION_MAX: response = rewriteUnionMax(n); break;
74
67
      case UNION_DISJOINT: response = rewriteUnionDisjoint(n); break;
75
34
      case INTERSECTION_MIN: response = rewriteIntersectionMin(n); break;
76
46
      case DIFFERENCE_SUBTRACT: response = rewriteDifferenceSubtract(n); break;
77
18
      case DIFFERENCE_REMOVE: response = rewriteDifferenceRemove(n); break;
78
2
      case BAG_CHOOSE: response = rewriteChoose(n); break;
79
4
      case BAG_CARD: response = rewriteCard(n); break;
80
2
      case BAG_IS_SINGLETON: response = rewriteIsSingleton(n); break;
81
2
      case BAG_FROM_SET: response = rewriteFromSet(n); break;
82
2
      case BAG_TO_SET: response = rewriteToSet(n); break;
83
      default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
84
    }
85
  }
86
87
3942
  Trace("bags-rewrite") << "postRewrite " << n << " to " << response.d_node
88
1971
                        << " by " << response.d_rewrite << "." << std::endl;
89
90
1971
  if (d_statistics != nullptr)
91
  {
92
1839
    (*d_statistics) << response.d_rewrite;
93
  }
94
1971
  if (response.d_node != n)
95
  {
96
295
    return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
97
  }
98
1676
  return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
99
}
100
101
1111
RewriteResponse BagsRewriter::preRewrite(TNode n)
102
{
103
2222
  BagsRewriteResponse response;
104
1111
  Kind k = n.getKind();
105
1111
  switch (k)
106
  {
107
348
    case EQUAL: response = preRewriteEqual(n); break;
108
12
    case SUBBAG: response = rewriteSubBag(n); break;
109
751
    default: response = BagsRewriteResponse(n, Rewrite::NONE);
110
  }
111
112
2222
  Trace("bags-rewrite") << "preRewrite " << n << " to " << response.d_node
113
1111
                        << " by " << response.d_rewrite << "." << std::endl;
114
115
1111
  if (d_statistics != nullptr)
116
  {
117
1109
    (*d_statistics) << response.d_rewrite;
118
  }
119
1111
  if (response.d_node != n)
120
  {
121
79
    return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
122
  }
123
1032
  return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
124
}
125
126
348
BagsRewriteResponse BagsRewriter::preRewriteEqual(const TNode& n) const
127
{
128
348
  Assert(n.getKind() == EQUAL);
129
348
  if (n[0] == n[1])
130
  {
131
    // (= A A) = true where A is a bag
132
67
    return BagsRewriteResponse(d_nm->mkConst(true), Rewrite::IDENTICAL_NODES);
133
  }
134
281
  return BagsRewriteResponse(n, Rewrite::NONE);
135
}
136
137
12
BagsRewriteResponse BagsRewriter::rewriteSubBag(const TNode& n) const
138
{
139
12
  Assert(n.getKind() == SUBBAG);
140
141
  // (bag.is_included A B) = ((difference_subtract A B) == emptybag)
142
24
  Node emptybag = d_nm->mkConst(EmptyBag(n[0].getType()));
143
24
  Node subtract = d_nm->mkNode(DIFFERENCE_SUBTRACT, n[0], n[1]);
144
24
  Node equal = subtract.eqNode(emptybag);
145
24
  return BagsRewriteResponse(equal, Rewrite::SUB_BAG);
146
}
147
148
103
BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const
149
{
150
103
  Assert(n.getKind() == MK_BAG);
151
  // return emptybag for negative or zero multiplicity
152
103
  if (n[1].isConst() && n[1].getConst<Rational>().sgn() != 1)
153
  {
154
    // (mkBag x c) = emptybag where c <= 0
155
18
    Node emptybag = d_nm->mkConst(EmptyBag(n.getType()));
156
9
    return BagsRewriteResponse(emptybag, Rewrite::MK_BAG_COUNT_NEGATIVE);
157
  }
158
94
  return BagsRewriteResponse(n, Rewrite::NONE);
159
}
160
161
894
BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
162
{
163
894
  Assert(n.getKind() == BAG_COUNT);
164
894
  if (n[1].isConst() && n[1].getKind() == EMPTYBAG)
165
  {
166
    // (bag.count x emptybag) = 0
167
27
    return BagsRewriteResponse(d_zero, Rewrite::COUNT_EMPTY);
168
  }
169
867
  if (n[1].getKind() == MK_BAG && n[0] == n[1][0])
170
  {
171
    // (bag.count x (mkBag x c) = c
172
21
    return BagsRewriteResponse(n[1][1], Rewrite::COUNT_MK_BAG);
173
  }
174
846
  return BagsRewriteResponse(n, Rewrite::NONE);
175
}
176
177
14
BagsRewriteResponse BagsRewriter::rewriteDuplicateRemoval(const TNode& n) const
178
{
179
14
  Assert(n.getKind() == DUPLICATE_REMOVAL);
180
44
  if (n[0].getKind() == MK_BAG && n[0][1].isConst()
181
44
      && n[0][1].getConst<Rational>().sgn() == 1)
182
  {
183
    // (duplicate_removal (mkBag x n)) = (mkBag x 1)
184
    //  where n is a positive constant
185
4
    Node bag = d_nm->mkBag(n[0][0].getType(), n[0][0], d_one);
186
2
    return BagsRewriteResponse(bag, Rewrite::DUPLICATE_REMOVAL_MK_BAG);
187
  }
188
12
  return BagsRewriteResponse(n, Rewrite::NONE);
189
}
190
191
66
BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const
192
{
193
66
  Assert(n.getKind() == UNION_MAX);
194
66
  if (n[1].getKind() == EMPTYBAG || n[0] == n[1])
195
  {
196
    // (union_max A A) = A
197
    // (union_max A emptybag) = A
198
4
    return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_SAME_OR_EMPTY);
199
  }
200
62
  if (n[0].getKind() == EMPTYBAG)
201
  {
202
    // (union_max emptybag A) = A
203
2
    return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_EMPTY);
204
  }
205
206
236
  if ((n[1].getKind() == UNION_MAX || n[1].getKind() == UNION_DISJOINT)
207
188
      && (n[0] == n[1][0] || n[0] == n[1][1]))
208
  {
209
    // (union_max A (union_max A B)) = (union_max A B)
210
    // (union_max A (union_max B A)) = (union_max B A)
211
    // (union_max A (union_disjoint A B)) = (union_disjoint A B)
212
    // (union_max A (union_disjoint B A)) = (union_disjoint B A)
213
8
    return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_UNION_LEFT);
214
  }
215
216
204
  if ((n[0].getKind() == UNION_MAX || n[0].getKind() == UNION_DISJOINT)
217
164
      && (n[0][0] == n[1] || n[0][1] == n[1]))
218
  {
219
    // (union_max (union_max A B) A)) = (union_max A B)
220
    // (union_max (union_max B A) A)) = (union_max B A)
221
    // (union_max (union_disjoint A B) A)) = (union_disjoint A B)
222
    // (union_max (union_disjoint B A) A)) = (union_disjoint B A)
223
8
    return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_UNION_RIGHT);
224
  }
225
44
  return BagsRewriteResponse(n, Rewrite::NONE);
226
}
227
228
67
BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
229
{
230
67
  Assert(n.getKind() == UNION_DISJOINT);
231
67
  if (n[1].getKind() == EMPTYBAG)
232
  {
233
    // (union_disjoint A emptybag) = A
234
2
    return BagsRewriteResponse(n[0], Rewrite::UNION_DISJOINT_EMPTY_RIGHT);
235
  }
236
65
  if (n[0].getKind() == EMPTYBAG)
237
  {
238
    // (union_disjoint emptybag A) = A
239
3
    return BagsRewriteResponse(n[1], Rewrite::UNION_DISJOINT_EMPTY_LEFT);
240
  }
241
192
  if ((n[0].getKind() == UNION_MAX && n[1].getKind() == INTERSECTION_MIN)
242
242
      || (n[1].getKind() == UNION_MAX && n[0].getKind() == INTERSECTION_MIN))
243
244
  {
245
    // (union_disjoint (union_max A B) (intersection_min A B)) =
246
    //         (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
247
    // check if the operands of union_max and intersection_min are the same
248
8
    std::set<Node> left(n[0].begin(), n[0].end());
249
8
    std::set<Node> right(n[1].begin(), n[1].end());
250
6
    if (left == right)
251
    {
252
8
      Node rewritten = d_nm->mkNode(UNION_DISJOINT, n[0][0], n[0][1]);
253
4
      return BagsRewriteResponse(rewritten, Rewrite::UNION_DISJOINT_MAX_MIN);
254
    }
255
  }
256
58
  return BagsRewriteResponse(n, Rewrite::NONE);
257
}
258
259
34
BagsRewriteResponse BagsRewriter::rewriteIntersectionMin(const TNode& n) const
260
{
261
34
  Assert(n.getKind() == INTERSECTION_MIN);
262
34
  if (n[0].getKind() == EMPTYBAG)
263
  {
264
    // (intersection_min emptybag A) = emptybag
265
2
    return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_EMPTY_LEFT);
266
  }
267
32
  if (n[1].getKind() == EMPTYBAG)
268
  {
269
    // (intersection_min A emptybag) = emptybag
270
2
    return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_EMPTY_RIGHT);
271
  }
272
30
  if (n[0] == n[1])
273
  {
274
    // (intersection_min A A) = A
275
2
    return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SAME);
276
  }
277
28
  if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
278
  {
279
8
    if (n[0] == n[1][0] || n[0] == n[1][1])
280
    {
281
      // (intersection_min A (union_disjoint A B)) = A
282
      // (intersection_min A (union_disjoint B A)) = A
283
      // (intersection_min A (union_max A B)) = A
284
      // (intersection_min A (union_max B A)) = A
285
8
      return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SHARED_LEFT);
286
    }
287
  }
288
289
20
  if (n[0].getKind() == UNION_DISJOINT || n[0].getKind() == UNION_MAX)
290
  {
291
8
    if (n[1] == n[0][0] || n[1] == n[0][1])
292
    {
293
      // (intersection_min (union_disjoint A B) A) = A
294
      // (intersection_min (union_disjoint B A) A) = A
295
      // (intersection_min (union_max A B) A) = A
296
      // (intersection_min (union_max B A) A) = A
297
8
      return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_SHARED_RIGHT);
298
    }
299
  }
300
301
12
  return BagsRewriteResponse(n, Rewrite::NONE);
302
}
303
304
46
BagsRewriteResponse BagsRewriter::rewriteDifferenceSubtract(
305
    const TNode& n) const
306
{
307
46
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
308
46
  if (n[0].getKind() == EMPTYBAG || n[1].getKind() == EMPTYBAG)
309
  {
310
    // (difference_subtract A emptybag) = A
311
    // (difference_subtract emptybag A) = emptybag
312
4
    return BagsRewriteResponse(n[0], Rewrite::SUBTRACT_RETURN_LEFT);
313
  }
314
42
  if (n[0] == n[1])
315
  {
316
    // (difference_subtract A A) = emptybag
317
4
    Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
318
2
    return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_SAME);
319
  }
320
321
40
  if (n[0].getKind() == UNION_DISJOINT)
322
  {
323
4
    if (n[1] == n[0][0])
324
    {
325
      // (difference_subtract (union_disjoint A B) A) = B
326
      return BagsRewriteResponse(n[0][1],
327
2
                                 Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT);
328
    }
329
2
    if (n[1] == n[0][1])
330
    {
331
      // (difference_subtract (union_disjoint B A) A) = B
332
      return BagsRewriteResponse(n[0][0],
333
2
                                 Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT);
334
    }
335
  }
336
337
36
  if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
338
  {
339
8
    if (n[0] == n[1][0] || n[0] == n[1][1])
340
    {
341
      // (difference_subtract A (union_disjoint A B)) = emptybag
342
      // (difference_subtract A (union_disjoint B A)) = emptybag
343
      // (difference_subtract A (union_max A B)) = emptybag
344
      // (difference_subtract A (union_max B A)) = emptybag
345
16
      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
346
8
      return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_FROM_UNION);
347
    }
348
  }
349
350
28
  if (n[0].getKind() == INTERSECTION_MIN)
351
  {
352
4
    if (n[1] == n[0][0] || n[1] == n[0][1])
353
    {
354
      // (difference_subtract (intersection_min A B) A) = emptybag
355
      // (difference_subtract (intersection_min B A) A) = emptybag
356
8
      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
357
4
      return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_MIN);
358
    }
359
  }
360
361
24
  return BagsRewriteResponse(n, Rewrite::NONE);
362
}
363
364
18
BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const
365
{
366
18
  Assert(n.getKind() == DIFFERENCE_REMOVE);
367
368
18
  if (n[0].getKind() == EMPTYBAG || n[1].getKind() == EMPTYBAG)
369
  {
370
    // (difference_remove A emptybag) = A
371
    // (difference_remove emptybag B) = emptybag
372
4
    return BagsRewriteResponse(n[0], Rewrite::REMOVE_RETURN_LEFT);
373
  }
374
375
14
  if (n[0] == n[1])
376
  {
377
    // (difference_remove A A) = emptybag
378
4
    Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
379
2
    return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_SAME);
380
  }
381
382
12
  if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
383
  {
384
8
    if (n[0] == n[1][0] || n[0] == n[1][1])
385
    {
386
      // (difference_remove A (union_disjoint A B)) = emptybag
387
      // (difference_remove A (union_disjoint B A)) = emptybag
388
      // (difference_remove A (union_max A B)) = emptybag
389
      // (difference_remove A (union_max B A)) = emptybag
390
16
      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
391
8
      return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_FROM_UNION);
392
    }
393
  }
394
395
4
  if (n[0].getKind() == INTERSECTION_MIN)
396
  {
397
4
    if (n[1] == n[0][0] || n[1] == n[0][1])
398
    {
399
      // (difference_remove (intersection_min A B) A) = emptybag
400
      // (difference_remove (intersection_min B A) A) = emptybag
401
8
      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
402
4
      return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_MIN);
403
    }
404
  }
405
406
  return BagsRewriteResponse(n, Rewrite::NONE);
407
}
408
409
2
BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const
410
{
411
2
  Assert(n.getKind() == BAG_CHOOSE);
412
2
  if (n[0].getKind() == MK_BAG && n[0][1].isConst())
413
  {
414
    // (bag.choose (mkBag x c)) = x where c is a constant > 0
415
2
    return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_MK_BAG);
416
  }
417
  return BagsRewriteResponse(n, Rewrite::NONE);
418
}
419
420
4
BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const
421
{
422
4
  Assert(n.getKind() == BAG_CARD);
423
4
  if (n[0].getKind() == MK_BAG && n[0][1].isConst())
424
  {
425
    // (bag.card (mkBag x c)) = c where c is a constant > 0
426
2
    return BagsRewriteResponse(n[0][1], Rewrite::CARD_MK_BAG);
427
  }
428
429
2
  if (n[0].getKind() == UNION_DISJOINT)
430
  {
431
    // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
432
4
    Node A = d_nm->mkNode(BAG_CARD, n[0][0]);
433
4
    Node B = d_nm->mkNode(BAG_CARD, n[0][1]);
434
4
    Node plus = d_nm->mkNode(PLUS, A, B);
435
2
    return BagsRewriteResponse(plus, Rewrite::CARD_DISJOINT);
436
  }
437
438
  return BagsRewriteResponse(n, Rewrite::NONE);
439
}
440
441
2
BagsRewriteResponse BagsRewriter::rewriteIsSingleton(const TNode& n) const
442
{
443
2
  Assert(n.getKind() == BAG_IS_SINGLETON);
444
2
  if (n[0].getKind() == MK_BAG)
445
  {
446
    // (bag.is_singleton (mkBag x c)) = (c == 1)
447
4
    Node equal = n[0][1].eqNode(d_one);
448
2
    return BagsRewriteResponse(equal, Rewrite::IS_SINGLETON_MK_BAG);
449
  }
450
  return BagsRewriteResponse(n, Rewrite::NONE);
451
}
452
453
2
BagsRewriteResponse BagsRewriter::rewriteFromSet(const TNode& n) const
454
{
455
2
  Assert(n.getKind() == BAG_FROM_SET);
456
2
  if (n[0].getKind() == SINGLETON)
457
  {
458
    // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
459
4
    TypeNode type = n[0].getType().getSetElementType();
460
4
    Node bag = d_nm->mkBag(type, n[0][0], d_one);
461
2
    return BagsRewriteResponse(bag, Rewrite::FROM_SINGLETON);
462
  }
463
  return BagsRewriteResponse(n, Rewrite::NONE);
464
}
465
466
2
BagsRewriteResponse BagsRewriter::rewriteToSet(const TNode& n) const
467
{
468
2
  Assert(n.getKind() == BAG_TO_SET);
469
8
  if (n[0].getKind() == MK_BAG && n[0][1].isConst()
470
8
      && n[0][1].getConst<Rational>().sgn() == 1)
471
  {
472
    // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
473
    // where n is a positive constant and T is the type of the bag's elements
474
4
    Node set = d_nm->mkSingleton(n[0][0].getType(), n[0][0]);
475
2
    return BagsRewriteResponse(set, Rewrite::TO_SINGLETON);
476
  }
477
  return BagsRewriteResponse(n, Rewrite::NONE);
478
}
479
480
500
BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const
481
{
482
500
  Assert(n.getKind() == kind::EQUAL);
483
500
  if (n[0] == n[1])
484
  {
485
    Node ret = d_nm->mkConst(true);
486
    return BagsRewriteResponse(ret, Rewrite::EQ_REFL);
487
  }
488
489
500
  if (n[0].isConst() && n[1].isConst())
490
  {
491
26
    Node ret = d_nm->mkConst(false);
492
13
    return BagsRewriteResponse(ret, Rewrite::EQ_CONST_FALSE);
493
  }
494
495
  // standard ordering
496
487
  if (n[0] > n[1])
497
  {
498
90
    Node ret = d_nm->mkNode(kind::EQUAL, n[1], n[0]);
499
45
    return BagsRewriteResponse(ret, Rewrite::EQ_SYM);
500
  }
501
442
  return BagsRewriteResponse(n, Rewrite::NONE);
502
}
503
504
}  // namespace bags
505
}  // namespace theory
506
26676
}  // namespace CVC4