GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/bag_solver.cpp Lines: 128 130 98.5 %
Date: 2021-11-07 Branches: 203 486 41.8 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Mudathir Mohamed, Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Solver for the theory of bags.
14
 */
15
16
#include "theory/bags/bag_solver.h"
17
18
#include "theory/bags/inference_generator.h"
19
#include "theory/bags/inference_manager.h"
20
#include "theory/bags/normal_form.h"
21
#include "theory/bags/solver_state.h"
22
#include "theory/bags/term_registry.h"
23
#include "theory/uf/equality_engine_iterator.h"
24
#include "util/rational.h"
25
26
using namespace std;
27
using namespace cvc5::context;
28
using namespace cvc5::kind;
29
30
namespace cvc5 {
31
namespace theory {
32
namespace bags {
33
34
15273
BagSolver::BagSolver(Env& env,
35
                     SolverState& s,
36
                     InferenceManager& im,
37
15273
                     TermRegistry& tr)
38
    : EnvObj(env),
39
      d_state(s),
40
      d_ig(&s, &im),
41
      d_im(im),
42
      d_termReg(tr),
43
15273
      d_mapCache(userContext())
44
{
45
15273
  d_zero = NodeManager::currentNM()->mkConst(Rational(0));
46
15273
  d_one = NodeManager::currentNM()->mkConst(Rational(1));
47
15273
  d_true = NodeManager::currentNM()->mkConst(true);
48
15273
  d_false = NodeManager::currentNM()->mkConst(false);
49
15273
}
50
51
15268
BagSolver::~BagSolver() {}
52
53
18192
void BagSolver::postCheck()
54
{
55
18192
  d_state.initialize();
56
57
18192
  checkDisequalBagTerms();
58
59
  // At this point, all bag and count representatives should be in the solver
60
  // state.
61
19703
  for (const Node& bag : d_state.getBags())
62
  {
63
    // iterate through all bags terms in each equivalent class
64
    eq::EqClassIterator it =
65
1511
        eq::EqClassIterator(bag, d_state.getEqualityEngine());
66
7369
    while (!it.isFinished())
67
    {
68
5858
      Node n = (*it);
69
2929
      Kind k = n.getKind();
70
2929
      switch (k)
71
      {
72
140
        case kind::EMPTYBAG: checkEmpty(n); break;
73
580
        case kind::MK_BAG: checkMkBag(n); break;
74
189
        case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
75
113
        case kind::UNION_MAX: checkUnionMax(n); break;
76
86
        case kind::INTERSECTION_MIN: checkIntersectionMin(n); break;
77
58
        case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
78
88
        case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
79
30
        case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
80
32
        case kind::BAG_MAP: checkMap(n); break;
81
1613
        default: break;
82
      }
83
2929
      it++;
84
    }
85
  }
86
87
  // add non negative constraints for all multiplicities
88
19703
  for (const Node& n : d_state.getBags())
89
  {
90
3706
    for (const Node& e : d_state.getElements(n))
91
    {
92
2195
      checkNonNegativeCountTerms(n, e);
93
    }
94
  }
95
18192
}
96
97
534
set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
98
{
99
534
  set<Node> elements;
100
534
  const set<Node>& downwards = d_state.getElements(n);
101
534
  const set<Node>& upwards0 = d_state.getElements(n[0]);
102
534
  const set<Node>& upwards1 = d_state.getElements(n[1]);
103
104
534
  set_union(downwards.begin(),
105
            downwards.end(),
106
            upwards0.begin(),
107
            upwards0.end(),
108
534
            inserter(elements, elements.begin()));
109
534
  elements.insert(upwards1.begin(), upwards1.end());
110
534
  return elements;
111
}
112
113
140
void BagSolver::checkEmpty(const Node& n)
114
{
115
140
  Assert(n.getKind() == EMPTYBAG);
116
166
  for (const Node& e : d_state.getElements(n))
117
  {
118
52
    InferInfo i = d_ig.empty(n, e);
119
26
    d_im.lemmaTheoryInference(&i);
120
  }
121
140
}
122
123
189
void BagSolver::checkUnionDisjoint(const Node& n)
124
{
125
189
  Assert(n.getKind() == UNION_DISJOINT);
126
378
  std::set<Node> elements = getElementsForBinaryOperator(n);
127
469
  for (const Node& e : elements)
128
  {
129
560
    InferInfo i = d_ig.unionDisjoint(n, e);
130
280
    d_im.lemmaTheoryInference(&i);
131
  }
132
189
}
133
134
113
void BagSolver::checkUnionMax(const Node& n)
135
{
136
113
  Assert(n.getKind() == UNION_MAX);
137
226
  std::set<Node> elements = getElementsForBinaryOperator(n);
138
297
  for (const Node& e : elements)
139
  {
140
368
    InferInfo i = d_ig.unionMax(n, e);
141
184
    d_im.lemmaTheoryInference(&i);
142
  }
143
113
}
144
145
86
void BagSolver::checkIntersectionMin(const Node& n)
146
{
147
86
  Assert(n.getKind() == INTERSECTION_MIN);
148
172
  std::set<Node> elements = getElementsForBinaryOperator(n);
149
232
  for (const Node& e : elements)
150
  {
151
292
    InferInfo i = d_ig.intersection(n, e);
152
146
    d_im.lemmaTheoryInference(&i);
153
  }
154
86
}
155
156
58
void BagSolver::checkDifferenceSubtract(const Node& n)
157
{
158
58
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
159
116
  std::set<Node> elements = getElementsForBinaryOperator(n);
160
138
  for (const Node& e : elements)
161
  {
162
160
    InferInfo i = d_ig.differenceSubtract(n, e);
163
80
    d_im.lemmaTheoryInference(&i);
164
  }
165
58
}
166
167
580
void BagSolver::checkMkBag(const Node& n)
168
{
169
580
  Assert(n.getKind() == MK_BAG);
170
1160
  Trace("bags::BagSolver::postCheck")
171
580
      << "BagSolver::checkMkBag Elements of " << n
172
580
      << " are: " << d_state.getElements(n) << std::endl;
173
1491
  for (const Node& e : d_state.getElements(n))
174
  {
175
1822
    InferInfo i = d_ig.mkBag(n, e);
176
911
    d_im.lemmaTheoryInference(&i);
177
  }
178
580
}
179
2195
void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
180
{
181
4390
  InferInfo i = d_ig.nonNegativeCount(bag, element);
182
2195
  d_im.lemmaTheoryInference(&i);
183
2195
}
184
185
88
void BagSolver::checkDifferenceRemove(const Node& n)
186
{
187
88
  Assert(n.getKind() == DIFFERENCE_REMOVE);
188
176
  std::set<Node> elements = getElementsForBinaryOperator(n);
189
306
  for (const Node& e : elements)
190
  {
191
436
    InferInfo i = d_ig.differenceRemove(n, e);
192
218
    d_im.lemmaTheoryInference(&i);
193
  }
194
88
}
195
196
30
void BagSolver::checkDuplicateRemoval(Node n)
197
{
198
30
  Assert(n.getKind() == DUPLICATE_REMOVAL);
199
60
  set<Node> elements;
200
30
  const set<Node>& downwards = d_state.getElements(n);
201
30
  const set<Node>& upwards = d_state.getElements(n[0]);
202
203
30
  elements.insert(downwards.begin(), downwards.end());
204
30
  elements.insert(upwards.begin(), upwards.end());
205
206
78
  for (const Node& e : elements)
207
  {
208
96
    InferInfo i = d_ig.duplicateRemoval(n, e);
209
48
    d_im.lemmaTheoryInference(&i);
210
  }
211
30
}
212
213
18192
void BagSolver::checkDisequalBagTerms()
214
{
215
18637
  for (const Node& n : d_state.getDisequalBagTerms())
216
  {
217
890
    InferInfo info = d_ig.bagDisequality(n);
218
445
    d_im.lemmaTheoryInference(&info);
219
  }
220
18192
}
221
222
32
void BagSolver::checkMap(Node n)
223
{
224
32
  Assert(n.getKind() == BAG_MAP);
225
32
  const set<Node>& downwards = d_state.getElements(n);
226
32
  const set<Node>& upwards = d_state.getElements(n[1]);
227
62
  for (const Node& y : downwards)
228
  {
229
30
    if (d_mapCache.count(n) && d_mapCache[n].get()->contains(y))
230
    {
231
26
      continue;
232
    }
233
8
    auto [downInference, uf, preImageSize] = d_ig.mapDownwards(n, y);
234
4
    d_im.lemmaTheoryInference(&downInference);
235
4
    for (const Node& x : upwards)
236
    {
237
      InferInfo upInference = d_ig.mapUpwards(n, uf, preImageSize, y, x);
238
      d_im.lemmaTheoryInference(&upInference);
239
    }
240
4
    if (!d_mapCache.count(n))
241
    {
242
      std::shared_ptr<context::CDHashSet<Node> > set =
243
8
          std::make_shared<context::CDHashSet<Node> >(userContext());
244
4
      d_mapCache.insert(n, set);
245
    }
246
4
    d_mapCache[n].get()->insert(y);
247
  }
248
32
}
249
250
}  // namespace bags
251
}  // namespace theory
252
31137
}  // namespace cvc5