GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/bag_solver.cpp Lines: 103 111 92.8 %
Date: 2021-05-22 Branches: 159 405 39.3 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Mudathir Mohamed, Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Solver for the theory of bags.
14
 */
15
16
#include "theory/bags/bag_solver.h"
17
18
#include "theory/bags/inference_generator.h"
19
#include "theory/bags/inference_manager.h"
20
#include "theory/bags/normal_form.h"
21
#include "theory/bags/solver_state.h"
22
#include "theory/bags/term_registry.h"
23
#include "theory/uf/equality_engine_iterator.h"
24
25
using namespace std;
26
using namespace cvc5::context;
27
using namespace cvc5::kind;
28
29
namespace cvc5 {
30
namespace theory {
31
namespace bags {
32
33
9459
BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
    : d_state(s), d_ig(&s, &im), d_im(im), d_termReg(tr)
{
  // Cache the constant nodes 0, 1, true and false as members.
  NodeManager* nm = NodeManager::currentNM();
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
  d_true = nm->mkConst(true);
  d_false = nm->mkConst(false);
}
41
42
9459
// Nothing to release explicitly; members clean themselves up.
BagSolver::~BagSolver() = default;
43
44
11992
void BagSolver::postCheck()
45
{
46
11992
  d_state.initialize();
47
48
11992
  checkDisequalBagTerms();
49
50
  // At this point, all bag and count representatives should be in the solver
51
  // state.
52
12982
  for (const Node& bag : d_state.getBags())
53
  {
54
    // iterate through all bags terms in each equivalent class
55
    eq::EqClassIterator it =
56
990
        eq::EqClassIterator(bag, d_state.getEqualityEngine());
57
4478
    while (!it.isFinished())
58
    {
59
3488
      Node n = (*it);
60
1744
      Kind k = n.getKind();
61
1744
      switch (k)
62
      {
63
107
        case kind::EMPTYBAG: checkEmpty(n); break;
64
252
        case kind::MK_BAG: checkMkBag(n); break;
65
184
        case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
66
102
        case kind::UNION_MAX: checkUnionMax(n); break;
67
58
        case kind::INTERSECTION_MIN: checkIntersectionMin(n); break;
68
40
        case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
69
        case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
70
28
        case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
71
973
        default: break;
72
      }
73
1744
      it++;
74
    }
75
  }
76
77
  // add non negative constraints for all multiplicities
78
12982
  for (const Node& n : d_state.getBags())
79
  {
80
2325
    for (const Node& e : d_state.getElements(n))
81
    {
82
1335
      checkNonNegativeCountTerms(n, e);
83
    }
84
  }
85
11992
}
86
87
384
set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
{
  // Elements registered for the term itself and for each of its two
  // arguments; queried in this order: n, n[0], n[1].
  const set<Node>& own = d_state.getElements(n);
  const set<Node>& lhs = d_state.getElements(n[0]);
  const set<Node>& rhs = d_state.getElements(n[1]);

  // std::set discards duplicates, so successive inserts yield the union.
  set<Node> merged(own.begin(), own.end());
  merged.insert(lhs.begin(), lhs.end());
  merged.insert(rhs.begin(), rhs.end());
  return merged;
}
102
103
107
void BagSolver::checkEmpty(const Node& n)
104
{
105
107
  Assert(n.getKind() == EMPTYBAG);
106
129
  for (const Node& e : d_state.getElements(n))
107
  {
108
44
    InferInfo i = d_ig.empty(n, e);
109
22
    d_im.lemmaTheoryInference(&i);
110
  }
111
107
}
112
113
184
void BagSolver::checkUnionDisjoint(const Node& n)
114
{
115
184
  Assert(n.getKind() == UNION_DISJOINT);
116
368
  std::set<Node> elements = getElementsForBinaryOperator(n);
117
472
  for (const Node& e : elements)
118
  {
119
576
    InferInfo i = d_ig.unionDisjoint(n, e);
120
288
    d_im.lemmaTheoryInference(&i);
121
  }
122
184
}
123
124
102
void BagSolver::checkUnionMax(const Node& n)
125
{
126
102
  Assert(n.getKind() == UNION_MAX);
127
204
  std::set<Node> elements = getElementsForBinaryOperator(n);
128
280
  for (const Node& e : elements)
129
  {
130
356
    InferInfo i = d_ig.unionMax(n, e);
131
178
    d_im.lemmaTheoryInference(&i);
132
  }
133
102
}
134
135
58
void BagSolver::checkIntersectionMin(const Node& n)
136
{
137
58
  Assert(n.getKind() == INTERSECTION_MIN);
138
116
  std::set<Node> elements = getElementsForBinaryOperator(n);
139
154
  for (const Node& e : elements)
140
  {
141
192
    InferInfo i = d_ig.intersection(n, e);
142
96
    d_im.lemmaTheoryInference(&i);
143
  }
144
58
}
145
146
40
void BagSolver::checkDifferenceSubtract(const Node& n)
147
{
148
40
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
149
80
  std::set<Node> elements = getElementsForBinaryOperator(n);
150
96
  for (const Node& e : elements)
151
  {
152
112
    InferInfo i = d_ig.differenceSubtract(n, e);
153
56
    d_im.lemmaTheoryInference(&i);
154
  }
155
40
}
156
157
252
void BagSolver::checkMkBag(const Node& n)
158
{
159
252
  Assert(n.getKind() == MK_BAG);
160
504
  Trace("bags::BagSolver::postCheck")
161
252
      << "BagSolver::checkMkBag Elements of " << n
162
252
      << " are: " << d_state.getElements(n) << std::endl;
163
586
  for (const Node& e : d_state.getElements(n))
164
  {
165
668
    InferInfo i = d_ig.mkBag(n, e);
166
334
    d_im.lemmaTheoryInference(&i);
167
  }
168
252
}
169
1335
void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
170
{
171
2670
  InferInfo i = d_ig.nonNegativeCount(bag, element);
172
1335
  d_im.lemmaTheoryInference(&i);
173
1335
}
174
175
void BagSolver::checkDifferenceRemove(const Node& n)
176
{
177
  Assert(n.getKind() == DIFFERENCE_REMOVE);
178
  std::set<Node> elements = getElementsForBinaryOperator(n);
179
  for (const Node& e : elements)
180
  {
181
    InferInfo i = d_ig.differenceRemove(n, e);
182
    d_im.lemmaTheoryInference(&i);
183
  }
184
}
185
186
28
void BagSolver::checkDuplicateRemoval(Node n)
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);
  // Elements of the term itself, then of its single argument.
  const set<Node>& own = d_state.getElements(n);
  const set<Node>& arg = d_state.getElements(n[0]);

  // Duplicates collapse automatically in the std::set.
  set<Node> candidates(own.begin(), own.end());
  candidates.insert(arg.begin(), arg.end());

  for (const Node& element : candidates)
  {
    InferInfo info = d_ig.duplicateRemoval(n, element);
    d_im.lemmaTheoryInference(&info);
  }
}
202
203
11992
void BagSolver::checkDisequalBagTerms()
204
{
205
12262
  for (const Node& n : d_state.getDisequalBagTerms())
206
  {
207
540
    InferInfo info = d_ig.bagDisequality(n);
208
270
    d_im.lemmaTheoryInference(&info);
209
  }
210
11992
}
211
212
}  // namespace bags
213
}  // namespace theory
214
28191
}  // namespace cvc5