GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/bag_solver.cpp Lines: 103 111 92.8 %
Date: 2021-09-29 Branches: 159 403 39.5 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Mudathir Mohamed, Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Solver for the theory of bags.
14
 */
15
16
#include "theory/bags/bag_solver.h"
17
18
#include "theory/bags/inference_generator.h"
19
#include "theory/bags/inference_manager.h"
20
#include "theory/bags/normal_form.h"
21
#include "theory/bags/solver_state.h"
22
#include "theory/bags/term_registry.h"
23
#include "theory/uf/equality_engine_iterator.h"
24
#include "util/rational.h"
25
26
using namespace std;
27
using namespace cvc5::context;
28
using namespace cvc5::kind;
29
30
namespace cvc5 {
31
namespace theory {
32
namespace bags {
33
34
6271
// Binds the solver to the shared solver state, inference manager and term
// registry, and builds the constant nodes it keeps as members.
BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
    : d_state(s), d_ig(&s, &im), d_im(im), d_termReg(tr)
{
  NodeManager* nm = NodeManager::currentNM();
  // numeric constants 0 and 1
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
  // Boolean constants
  d_true = nm->mkConst(true);
  d_false = nm->mkConst(false);
}
42
43
6268
// No explicit cleanup is needed; members are destroyed automatically.
BagSolver::~BagSolver() = default;
44
45
10660
void BagSolver::postCheck()
46
{
47
10660
  d_state.initialize();
48
49
10660
  checkDisequalBagTerms();
50
51
  // At this point, all bag and count representatives should be in the solver
52
  // state.
53
11329
  for (const Node& bag : d_state.getBags())
54
  {
55
    // iterate through all bags terms in each equivalent class
56
    eq::EqClassIterator it =
57
669
        eq::EqClassIterator(bag, d_state.getEqualityEngine());
58
2955
    while (!it.isFinished())
59
    {
60
2286
      Node n = (*it);
61
1143
      Kind k = n.getKind();
62
1143
      switch (k)
63
      {
64
75
        case kind::EMPTYBAG: checkEmpty(n); break;
65
172
        case kind::MK_BAG: checkMkBag(n); break;
66
130
        case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
67
50
        case kind::UNION_MAX: checkUnionMax(n); break;
68
48
        case kind::INTERSECTION_MIN: checkIntersectionMin(n); break;
69
22
        case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
70
        case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
71
10
        case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
72
636
        default: break;
73
      }
74
1143
      it++;
75
    }
76
  }
77
78
  // add non negative constraints for all multiplicities
79
11329
  for (const Node& n : d_state.getBags())
80
  {
81
1608
    for (const Node& e : d_state.getElements(n))
82
    {
83
939
      checkNonNegativeCountTerms(n, e);
84
    }
85
  }
86
10660
}
87
88
250
// Collects the elements relevant to the binary bag term n.
//
// @param n a term whose children n[0] and n[1] are read (assumed to be a
//          binary bag operator such as UNION_DISJOINT)
// @return the union of the element sets the solver state records for n,
//         n[0] and n[1]
set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
{
  const set<Node>& downwards = d_state.getElements(n);
  const set<Node>& upwards0 = d_state.getElements(n[0]);
  const set<Node>& upwards1 = d_state.getElements(n[1]);

  // std::set already discards duplicates, so plain range inserts give the
  // union. This matches checkDuplicateRemoval and avoids relying on
  // <algorithm>/<iterator> (std::set_union, std::inserter) being pulled in
  // transitively.
  set<Node> elements;
  elements.insert(downwards.begin(), downwards.end());
  elements.insert(upwards0.begin(), upwards0.end());
  elements.insert(upwards1.begin(), upwards1.end());
  return elements;
}
103
104
75
void BagSolver::checkEmpty(const Node& n)
105
{
106
75
  Assert(n.getKind() == EMPTYBAG);
107
91
  for (const Node& e : d_state.getElements(n))
108
  {
109
32
    InferInfo i = d_ig.empty(n, e);
110
16
    d_im.lemmaTheoryInference(&i);
111
  }
112
75
}
113
114
130
void BagSolver::checkUnionDisjoint(const Node& n)
115
{
116
130
  Assert(n.getKind() == UNION_DISJOINT);
117
260
  std::set<Node> elements = getElementsForBinaryOperator(n);
118
322
  for (const Node& e : elements)
119
  {
120
384
    InferInfo i = d_ig.unionDisjoint(n, e);
121
192
    d_im.lemmaTheoryInference(&i);
122
  }
123
130
}
124
125
50
void BagSolver::checkUnionMax(const Node& n)
126
{
127
50
  Assert(n.getKind() == UNION_MAX);
128
100
  std::set<Node> elements = getElementsForBinaryOperator(n);
129
111
  for (const Node& e : elements)
130
  {
131
122
    InferInfo i = d_ig.unionMax(n, e);
132
61
    d_im.lemmaTheoryInference(&i);
133
  }
134
50
}
135
136
48
void BagSolver::checkIntersectionMin(const Node& n)
137
{
138
48
  Assert(n.getKind() == INTERSECTION_MIN);
139
96
  std::set<Node> elements = getElementsForBinaryOperator(n);
140
153
  for (const Node& e : elements)
141
  {
142
210
    InferInfo i = d_ig.intersection(n, e);
143
105
    d_im.lemmaTheoryInference(&i);
144
  }
145
48
}
146
147
22
void BagSolver::checkDifferenceSubtract(const Node& n)
148
{
149
22
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
150
44
  std::set<Node> elements = getElementsForBinaryOperator(n);
151
64
  for (const Node& e : elements)
152
  {
153
84
    InferInfo i = d_ig.differenceSubtract(n, e);
154
42
    d_im.lemmaTheoryInference(&i);
155
  }
156
22
}
157
158
172
void BagSolver::checkMkBag(const Node& n)
159
{
160
172
  Assert(n.getKind() == MK_BAG);
161
344
  Trace("bags::BagSolver::postCheck")
162
172
      << "BagSolver::checkMkBag Elements of " << n
163
172
      << " are: " << d_state.getElements(n) << std::endl;
164
383
  for (const Node& e : d_state.getElements(n))
165
  {
166
422
    InferInfo i = d_ig.mkBag(n, e);
167
211
    d_im.lemmaTheoryInference(&i);
168
  }
169
172
}
170
939
void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
171
{
172
1878
  InferInfo i = d_ig.nonNegativeCount(bag, element);
173
939
  d_im.lemmaTheoryInference(&i);
174
939
}
175
176
void BagSolver::checkDifferenceRemove(const Node& n)
177
{
178
  Assert(n.getKind() == DIFFERENCE_REMOVE);
179
  std::set<Node> elements = getElementsForBinaryOperator(n);
180
  for (const Node& e : elements)
181
  {
182
    InferInfo i = d_ig.differenceRemove(n, e);
183
    d_im.lemmaTheoryInference(&i);
184
  }
185
}
186
187
10
// For a DUPLICATE_REMOVAL term n, sends the duplicateRemoval inference for
// every element recorded either for n or for its argument n[0].
// NOTE(review): n is taken by value, unlike the sibling checks that take
// const Node&; changing that requires updating the header declaration too.
void BagSolver::checkDuplicateRemoval(Node n)
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);
  const set<Node>& downwards = d_state.getElements(n);
  const set<Node>& upwards = d_state.getElements(n[0]);

  // start from the elements of n, then add those of its argument
  set<Node> elements(downwards.begin(), downwards.end());
  elements.insert(upwards.begin(), upwards.end());

  for (const Node& element : elements)
  {
    InferInfo inference = d_ig.duplicateRemoval(n, element);
    d_im.lemmaTheoryInference(&inference);
  }
}
203
204
10660
void BagSolver::checkDisequalBagTerms()
205
{
206
10940
  for (const Node& n : d_state.getDisequalBagTerms())
207
  {
208
560
    InferInfo info = d_ig.bagDisequality(n);
209
280
    d_im.lemmaTheoryInference(&info);
210
  }
211
10660
}
212
213
}  // namespace bags
214
}  // namespace theory
215
22746
}  // namespace cvc5