GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/bag_solver.cpp Lines: 103 111 92.8 %
Date: 2021-08-16 Branches: 159 403 39.5 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Mudathir Mohamed, Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Solver for the theory of bags.
14
 */
15
16
#include "theory/bags/bag_solver.h"
17
18
#include "theory/bags/inference_generator.h"
19
#include "theory/bags/inference_manager.h"
20
#include "theory/bags/normal_form.h"
21
#include "theory/bags/solver_state.h"
22
#include "theory/bags/term_registry.h"
23
#include "theory/uf/equality_engine_iterator.h"
24
#include "util/rational.h"
25
26
using namespace std;
27
using namespace cvc5::context;
28
using namespace cvc5::kind;
29
30
namespace cvc5 {
31
namespace theory {
32
namespace bags {
33
34
9853
// Wires the solver to the shared state, inference manager and term registry,
// and caches the constant nodes (0, 1, true, false) used by the checks.
BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
    : d_state(s), d_ig(&s, &im), d_im(im), d_termReg(tr)
{
  NodeManager* nm = NodeManager::currentNM();
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
  d_true = nm->mkConst(true);
  d_false = nm->mkConst(false);
}
42
43
9853
// Nothing to release explicitly; members clean up after themselves.
BagSolver::~BagSolver() = default;
44
45
15057
void BagSolver::postCheck()
46
{
47
15057
  d_state.initialize();
48
49
15057
  checkDisequalBagTerms();
50
51
  // At this point, all bag and count representatives should be in the solver
52
  // state.
53
16129
  for (const Node& bag : d_state.getBags())
54
  {
55
    // iterate through all bags terms in each equivalent class
56
    eq::EqClassIterator it =
57
1072
        eq::EqClassIterator(bag, d_state.getEqualityEngine());
58
4788
    while (!it.isFinished())
59
    {
60
3716
      Node n = (*it);
61
1858
      Kind k = n.getKind();
62
1858
      switch (k)
63
      {
64
127
        case kind::EMPTYBAG: checkEmpty(n); break;
65
256
        case kind::MK_BAG: checkMkBag(n); break;
66
186
        case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
67
106
        case kind::UNION_MAX: checkUnionMax(n); break;
68
74
        case kind::INTERSECTION_MIN: checkIntersectionMin(n); break;
69
40
        case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
70
        case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
71
30
        case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
72
1039
        default: break;
73
      }
74
1858
      it++;
75
    }
76
  }
77
78
  // add non negative constraints for all multiplicities
79
16129
  for (const Node& n : d_state.getBags())
80
  {
81
2537
    for (const Node& e : d_state.getElements(n))
82
    {
83
1465
      checkNonNegativeCountTerms(n, e);
84
    }
85
  }
86
15057
}
87
88
406
// Returns the union of the elements known for the binary bag term n
// ("downwards") and for each of its two arguments n[0] and n[1] ("upwards").
// These are the elements for which the operator's inference is instantiated.
//
// The union is built with std::set range-insert for all three sources rather
// than std::set_union into an inserter (which duplicated the work std::set
// already does and used an unhelpful begin() hint). Each source is merged
// before the next getElements call, so no reference returned by an earlier
// call is held across a later one.
set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
{
  const set<Node>& downwards = d_state.getElements(n);
  set<Node> elements(downwards.begin(), downwards.end());

  const set<Node>& upwards0 = d_state.getElements(n[0]);
  elements.insert(upwards0.begin(), upwards0.end());

  const set<Node>& upwards1 = d_state.getElements(n[1]);
  elements.insert(upwards1.begin(), upwards1.end());

  return elements;
}
103
104
127
void BagSolver::checkEmpty(const Node& n)
105
{
106
127
  Assert(n.getKind() == EMPTYBAG);
107
149
  for (const Node& e : d_state.getElements(n))
108
  {
109
44
    InferInfo i = d_ig.empty(n, e);
110
22
    d_im.lemmaTheoryInference(&i);
111
  }
112
127
}
113
114
186
void BagSolver::checkUnionDisjoint(const Node& n)
115
{
116
186
  Assert(n.getKind() == UNION_DISJOINT);
117
372
  std::set<Node> elements = getElementsForBinaryOperator(n);
118
478
  for (const Node& e : elements)
119
  {
120
584
    InferInfo i = d_ig.unionDisjoint(n, e);
121
292
    d_im.lemmaTheoryInference(&i);
122
  }
123
186
}
124
125
106
void BagSolver::checkUnionMax(const Node& n)
126
{
127
106
  Assert(n.getKind() == UNION_MAX);
128
212
  std::set<Node> elements = getElementsForBinaryOperator(n);
129
294
  for (const Node& e : elements)
130
  {
131
376
    InferInfo i = d_ig.unionMax(n, e);
132
188
    d_im.lemmaTheoryInference(&i);
133
  }
134
106
}
135
136
74
void BagSolver::checkIntersectionMin(const Node& n)
137
{
138
74
  Assert(n.getKind() == INTERSECTION_MIN);
139
148
  std::set<Node> elements = getElementsForBinaryOperator(n);
140
202
  for (const Node& e : elements)
141
  {
142
256
    InferInfo i = d_ig.intersection(n, e);
143
128
    d_im.lemmaTheoryInference(&i);
144
  }
145
74
}
146
147
40
void BagSolver::checkDifferenceSubtract(const Node& n)
148
{
149
40
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
150
80
  std::set<Node> elements = getElementsForBinaryOperator(n);
151
96
  for (const Node& e : elements)
152
  {
153
112
    InferInfo i = d_ig.differenceSubtract(n, e);
154
56
    d_im.lemmaTheoryInference(&i);
155
  }
156
40
}
157
158
256
void BagSolver::checkMkBag(const Node& n)
159
{
160
256
  Assert(n.getKind() == MK_BAG);
161
512
  Trace("bags::BagSolver::postCheck")
162
256
      << "BagSolver::checkMkBag Elements of " << n
163
256
      << " are: " << d_state.getElements(n) << std::endl;
164
596
  for (const Node& e : d_state.getElements(n))
165
  {
166
680
    InferInfo i = d_ig.mkBag(n, e);
167
340
    d_im.lemmaTheoryInference(&i);
168
  }
169
256
}
170
1465
void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
171
{
172
2930
  InferInfo i = d_ig.nonNegativeCount(bag, element);
173
1465
  d_im.lemmaTheoryInference(&i);
174
1465
}
175
176
void BagSolver::checkDifferenceRemove(const Node& n)
177
{
178
  Assert(n.getKind() == DIFFERENCE_REMOVE);
179
  std::set<Node> elements = getElementsForBinaryOperator(n);
180
  for (const Node& e : elements)
181
  {
182
    InferInfo i = d_ig.differenceRemove(n, e);
183
    d_im.lemmaTheoryInference(&i);
184
  }
185
}
186
187
30
// Emits the duplicate-removal inference for every element known either for the
// DUPLICATE_REMOVAL term n itself or for its argument n[0].
void BagSolver::checkDuplicateRemoval(Node n)
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);
  const set<Node>& downwards = d_state.getElements(n);
  const set<Node>& upwards = d_state.getElements(n[0]);

  // Union of both element sets.
  set<Node> candidates(downwards.begin(), downwards.end());
  candidates.insert(upwards.begin(), upwards.end());

  for (const Node& element : candidates)
  {
    InferInfo inference = d_ig.duplicateRemoval(n, element);
    d_im.lemmaTheoryInference(&inference);
  }
}
203
204
15057
void BagSolver::checkDisequalBagTerms()
205
{
206
15401
  for (const Node& n : d_state.getDisequalBagTerms())
207
  {
208
688
    InferInfo info = d_ig.bagDisequality(n);
209
344
    d_im.lemmaTheoryInference(&info);
210
  }
211
15057
}
212
213
}  // namespace bags
214
}  // namespace theory
215
29340
}  // namespace cvc5