GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/bags/bag_solver.cpp Lines: 103 111 92.8 %
Date: 2021-03-23 Branches: 159 405 39.3 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file bag_solver.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Mudathir Mohamed, Andrew Reynolds
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief solver for the theory of bags.
13
 **
14
 ** solver for the theory of bags.
15
 **/
16
17
#include "theory/bags/bag_solver.h"
18
19
#include "theory/bags/inference_generator.h"
20
#include "theory/bags/inference_manager.h"
21
#include "theory/bags/normal_form.h"
22
#include "theory/bags/solver_state.h"
23
#include "theory/bags/term_registry.h"
24
#include "theory/uf/equality_engine_iterator.h"
25
26
using namespace std;
27
using namespace CVC4::context;
28
using namespace CVC4::kind;
29
30
namespace CVC4 {
31
namespace theory {
32
namespace bags {
33
34
8997
BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
    : d_state(s), d_ig(&s, &im), d_im(im), d_termReg(tr)
{
  // Build the constant nodes 0, 1, true and false once and keep them as
  // members; all four come from the current node manager.
  NodeManager* nm = NodeManager::currentNM();
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
  d_true = nm->mkConst(true);
  d_false = nm->mkConst(false);
}
42
43
8994
// The solver owns no resources that need explicit release.
BagSolver::~BagSolver() = default;
44
45
11257
void BagSolver::postCheck()
46
{
47
11257
  d_state.initialize();
48
49
11257
  checkDisequalBagTerms();
50
51
  // At this point, all bag and count representatives should be in the solver
52
  // state.
53
12072
  for (const Node& bag : d_state.getBags())
54
  {
55
    // iterate through all bags terms in each equivalent class
56
    eq::EqClassIterator it =
57
815
        eq::EqClassIterator(bag, d_state.getEqualityEngine());
58
3753
    while (!it.isFinished())
59
    {
60
2938
      Node n = (*it);
61
1469
      Kind k = n.getKind();
62
1469
      switch (k)
63
      {
64
105
        case kind::EMPTYBAG: checkEmpty(n); break;
65
220
        case kind::MK_BAG: checkMkBag(n); break;
66
170
        case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
67
83
        case kind::UNION_MAX: checkUnionMax(n); break;
68
44
        case kind::INTERSECTION_MIN: checkIntersectionMin(n); break;
69
28
        case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
70
        case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
71
23
        case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
72
796
        default: break;
73
      }
74
1469
      it++;
75
    }
76
  }
77
78
  // add non negative constraints for all multiplicities
79
12072
  for (const Node& n : d_state.getBags())
80
  {
81
1884
    for (const Node& e : d_state.getElements(n))
82
    {
83
1069
      checkNonNegativeCountTerms(n, e);
84
    }
85
  }
86
11257
}
87
88
325
set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
{
  // Collect the elements the solver state has recorded for the term n itself
  // ("downwards") and for each of its two operands n[0] and n[1]
  // ("upwards"). The result is the union of the three sets.
  const set<Node>& downwards = d_state.getElements(n);
  const set<Node>& upwards0 = d_state.getElements(n[0]);
  const set<Node>& upwards1 = d_state.getElements(n[1]);

  // std::set deduplicates and orders on insertion, so plain range
  // construction/insertion gives the union directly. There is no need for
  // std::set_union with an inserter, which also relies on the inputs sharing
  // the set's ordering.
  set<Node> elements(downwards.begin(), downwards.end());
  elements.insert(upwards0.begin(), upwards0.end());
  elements.insert(upwards1.begin(), upwards1.end());
  return elements;
}
103
104
105
void BagSolver::checkEmpty(const Node& n)
105
{
106
105
  Assert(n.getKind() == EMPTYBAG);
107
123
  for (const Node& e : d_state.getElements(n))
108
  {
109
36
    InferInfo i = d_ig.empty(n, e);
110
18
    d_im.lemmaTheoryInference(&i);
111
  }
112
105
}
113
114
170
void BagSolver::checkUnionDisjoint(const Node& n)
115
{
116
170
  Assert(n.getKind() == UNION_DISJOINT);
117
340
  std::set<Node> elements = getElementsForBinaryOperator(n);
118
422
  for (const Node& e : elements)
119
  {
120
504
    InferInfo i = d_ig.unionDisjoint(n, e);
121
252
    d_im.lemmaTheoryInference(&i);
122
  }
123
170
}
124
125
83
void BagSolver::checkUnionMax(const Node& n)
126
{
127
83
  Assert(n.getKind() == UNION_MAX);
128
166
  std::set<Node> elements = getElementsForBinaryOperator(n);
129
203
  for (const Node& e : elements)
130
  {
131
240
    InferInfo i = d_ig.unionMax(n, e);
132
120
    d_im.lemmaTheoryInference(&i);
133
  }
134
83
}
135
136
44
void BagSolver::checkIntersectionMin(const Node& n)
137
{
138
44
  Assert(n.getKind() == INTERSECTION_MIN);
139
88
  std::set<Node> elements = getElementsForBinaryOperator(n);
140
132
  for (const Node& e : elements)
141
  {
142
176
    InferInfo i = d_ig.intersection(n, e);
143
88
    d_im.lemmaTheoryInference(&i);
144
  }
145
44
}
146
147
28
void BagSolver::checkDifferenceSubtract(const Node& n)
148
{
149
28
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
150
56
  std::set<Node> elements = getElementsForBinaryOperator(n);
151
76
  for (const Node& e : elements)
152
  {
153
96
    InferInfo i = d_ig.differenceSubtract(n, e);
154
48
    d_im.lemmaTheoryInference(&i);
155
  }
156
28
}
157
158
220
void BagSolver::checkMkBag(const Node& n)
159
{
160
220
  Assert(n.getKind() == MK_BAG);
161
440
  Trace("bags::BagSolver::postCheck")
162
220
      << "BagSolver::checkMkBag Elements of " << n
163
220
      << " are: " << d_state.getElements(n) << std::endl;
164
502
  for (const Node& e : d_state.getElements(n))
165
  {
166
564
    InferInfo i = d_ig.mkBag(n, e);
167
282
    d_im.lemmaTheoryInference(&i);
168
  }
169
220
}
170
1069
void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
171
{
172
2138
  InferInfo i = d_ig.nonNegativeCount(bag, element);
173
1069
  d_im.lemmaTheoryInference(&i);
174
1069
}
175
176
void BagSolver::checkDifferenceRemove(const Node& n)
177
{
178
  Assert(n.getKind() == DIFFERENCE_REMOVE);
179
  std::set<Node> elements = getElementsForBinaryOperator(n);
180
  for (const Node& e : elements)
181
  {
182
    InferInfo i = d_ig.differenceRemove(n, e);
183
    d_im.lemmaTheoryInference(&i);
184
  }
185
}
186
187
23
void BagSolver::checkDuplicateRemoval(Node n)
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);
  // Union of the elements recorded for the term itself and for its single
  // operand n[0]; each element gets a duplicateRemoval lemma.
  const set<Node>& downwards = d_state.getElements(n);
  const set<Node>& upwards = d_state.getElements(n[0]);
  set<Node> elements(downwards.begin(), downwards.end());
  elements.insert(upwards.begin(), upwards.end());

  for (const Node& element : elements)
  {
    InferInfo inference = d_ig.duplicateRemoval(n, element);
    d_im.lemmaTheoryInference(&inference);
  }
}
203
204
11257
void BagSolver::checkDisequalBagTerms()
205
{
206
11540
  for (const Node& n : d_state.getDisequalBagTerms())
207
  {
208
566
    InferInfo info = d_ig.bagDisequality(n);
209
283
    d_im.lemmaTheoryInference(&info);
210
  }
211
11257
}
212
213
}  // namespace bags
214
}  // namespace theory
215
26685
}  // namespace CVC4