GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/bound_inference.cpp Lines: 87 136 64.0 %
Date: 2021-09-18 Branches: 170 453 37.5 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Gereon Kremer
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Extract bounds on variables from theory atoms.
14
 */
15
16
#include "theory/arith/bound_inference.h"
17
18
#include "theory/arith/normal_form.h"
19
#include "theory/rewriter.h"
20
21
namespace cvc5 {
22
namespace theory {
23
namespace arith {
24
25
std::ostream& operator<<(std::ostream& os, const Bounds& b) {
26
  return os << (b.lower_strict ? '(' : '[') << b.lower_value << " .. "
27
            << b.upper_value << (b.upper_strict ? ')' : ']');
28
}
29
30
/** Forget every bound collected so far. */
void BoundInference::reset()
{
  d_bounds.clear();
}
31
32
280857
/**
 * Return a mutable reference to the bounds entry of lhs.
 * If lhs has no entry yet, a value-initialized Bounds is inserted first.
 */
Bounds& BoundInference::get_or_add(const Node& lhs)
{
  // std::map::operator[] inserts a value-initialized mapped value on a miss,
  // which matches the former find + emplace(lhs, Bounds()) sequence.
  return d_bounds[lhs];
}
41
18
/**
 * Return a copy of the bounds recorded for lhs.
 * Returns a default-constructed Bounds if nothing is recorded for lhs.
 */
Bounds BoundInference::get(const Node& lhs) const
{
  const auto found = d_bounds.find(lhs);
  return found == d_bounds.end() ? Bounds() : found->second;
}
50
51
4531
/** Read-only view of all collected bounds, keyed by the bounded term. */
const std::map<Node, Bounds>& BoundInference::get() const
{
  return d_bounds;
}
52
303835
bool BoundInference::add(const Node& n, bool onlyVariables)
53
{
54
607670
  Node tmp = Rewriter::rewrite(n);
55
303835
  if (tmp.getKind() == Kind::CONST_BOOLEAN)
56
  {
57
    return false;
58
  }
59
  // Parse the node as a comparison
60
607670
  auto comp = Comparison::parseNormalForm(tmp);
61
607670
  auto dec = comp.decompose(true);
62
303835
  if (onlyVariables && !std::get<0>(dec).isVariable())
63
  {
64
598
    return false;
65
  }
66
67
606474
  Node lhs = std::get<0>(dec).getNode();
68
303237
  Kind relation = std::get<1>(dec);
69
303237
  if (relation == Kind::DISTINCT) return false;
70
511604
  Node bound = std::get<2>(dec).getNode();
71
  // has the form  lhs  ~relation~  bound
72
73
255802
  if (lhs.getType().isInteger())
74
  {
75
308026
    Rational br = bound.getConst<Rational>();
76
154013
    auto* nm = NodeManager::currentNM();
77
154013
    switch (relation)
78
    {
79
      case Kind::LEQ: bound = nm->mkConst<Rational>(br.floor()); break;
80
53073
      case Kind::LT:
81
53073
        bound = nm->mkConst<Rational>((br - 1).ceiling());
82
53073
        relation = Kind::LEQ;
83
53073
        break;
84
      case Kind::GT:
85
        bound = nm->mkConst<Rational>((br + 1).floor());
86
        relation = Kind::GEQ;
87
        break;
88
90190
      case Kind::GEQ: bound = nm->mkConst<Rational>(br.ceiling()); break;
89
154013
      default:;
90
    }
91
308026
    Trace("bound-inf") << "Strengthened " << n << " to " << lhs << " "
92
154013
                       << relation << " " << bound << std::endl;
93
  }
94
95
255802
  switch (relation)
96
  {
97
72465
    case Kind::LEQ: update_upper_bound(n, lhs, bound, false); break;
98
21818
    case Kind::LT: update_upper_bound(n, lhs, bound, true); break;
99
25055
    case Kind::EQUAL:
100
25055
      update_lower_bound(n, lhs, bound, false);
101
25055
      update_upper_bound(n, lhs, bound, false);
102
25055
      break;
103
18567
    case Kind::GT: update_lower_bound(n, lhs, bound, true); break;
104
117897
    case Kind::GEQ: update_lower_bound(n, lhs, bound, false); break;
105
    default: Assert(false);
106
  }
107
255802
  return true;
108
}
109
110
void BoundInference::replaceByOrigins(std::vector<Node>& nodes) const
111
{
112
  std::vector<Node> toAdd;
113
  for (auto& n : nodes)
114
  {
115
    for (const auto& b : d_bounds)
116
    {
117
      if (n == b.second.lower_bound && n == b.second.upper_bound)
118
      {
119
        if (n != b.second.lower_origin && n != b.second.upper_origin)
120
        {
121
          Trace("bound-inf")
122
              << "Replace " << n << " by origins " << b.second.lower_origin
123
              << " and " << b.second.upper_origin << std::endl;
124
          n = b.second.lower_origin;
125
          toAdd.emplace_back(b.second.upper_origin);
126
        }
127
      }
128
      else if (n == b.second.lower_bound)
129
      {
130
        if (n != b.second.lower_origin)
131
        {
132
          Trace("bound-inf") << "Replace " << n << " by origin "
133
                             << b.second.lower_origin << std::endl;
134
          n = b.second.lower_origin;
135
        }
136
      }
137
      else if (n == b.second.upper_bound)
138
      {
139
        if (n != b.second.upper_origin)
140
        {
141
          Trace("bound-inf") << "Replace " << n << " by origin "
142
                             << b.second.upper_origin << std::endl;
143
          n = b.second.upper_origin;
144
        }
145
      }
146
    }
147
  }
148
  nodes.insert(nodes.end(), toAdd.begin(), toAdd.end());
149
}
150
151
161519
void BoundInference::update_lower_bound(const Node& origin,
152
                                        const Node& lhs,
153
                                        const Node& value,
154
                                        bool strict)
155
{
156
  // lhs > or >= value because of origin
157
323038
  Trace("bound-inf") << "\tNew bound " << lhs << (strict ? ">" : ">=") << value
158
161519
                     << " due to " << origin << std::endl;
159
161519
  Bounds& b = get_or_add(lhs);
160
323038
  if (b.lower_value.isNull()
161
161519
      || b.lower_value.getConst<Rational>() < value.getConst<Rational>())
162
  {
163
105217
    auto* nm = NodeManager::currentNM();
164
105217
    b.lower_value = value;
165
105217
    b.lower_strict = strict;
166
167
105217
    b.lower_origin = origin;
168
169
105217
    if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
170
    {
171
1175
      b.lower_bound = b.upper_bound =
172
2350
          Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
173
    }
174
    else
175
    {
176
208084
      b.lower_bound = Rewriter::rewrite(
177
312126
          nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value));
178
    }
179
  }
180
56302
  else if (strict && b.lower_value == value)
181
  {
182
3638
    auto* nm = NodeManager::currentNM();
183
3638
    b.lower_strict = strict;
184
3638
    b.lower_bound = Rewriter::rewrite(nm->mkNode(Kind::GT, lhs, value));
185
3638
    b.lower_origin = origin;
186
  }
187
161519
}
188
119338
void BoundInference::update_upper_bound(const Node& origin,
189
                                        const Node& lhs,
190
                                        const Node& value,
191
                                        bool strict)
192
{
193
  // lhs < or <= value because of origin
194
238676
  Trace("bound-inf") << "\tNew bound " << lhs << (strict ? "<" : "<=") << value
195
119338
                     << " due to " << origin << std::endl;
196
119338
  Bounds& b = get_or_add(lhs);
197
238676
  if (b.upper_value.isNull()
198
119338
      || b.upper_value.getConst<Rational>() > value.getConst<Rational>())
199
  {
200
86422
    auto* nm = NodeManager::currentNM();
201
86422
    b.upper_value = value;
202
86422
    b.upper_strict = strict;
203
86422
    b.upper_origin = origin;
204
86422
    if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
205
    {
206
25179
      b.lower_bound = b.upper_bound =
207
50358
          Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
208
    }
209
    else
210
    {
211
122486
      b.upper_bound = Rewriter::rewrite(
212
183729
          nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value));
213
    }
214
  }
215
32916
  else if (strict && b.upper_value == value)
216
  {
217
3285
    auto* nm = NodeManager::currentNM();
218
3285
    b.upper_strict = strict;
219
3285
    b.upper_bound = Rewriter::rewrite(nm->mkNode(Kind::LT, lhs, value));
220
3285
    b.upper_origin = origin;
221
  }
222
119338
}
223
224
std::ostream& operator<<(std::ostream& os, const BoundInference& bi)
225
{
226
  os << "Bounds:" << std::endl;
227
  for (const auto& b : bi.get())
228
  {
229
    os << "\t" << b.first << " -> " << b.second.lower_value << ".."
230
       << b.second.upper_value << std::endl;
231
  }
232
  return os;
233
}
234
235
std::map<Node, std::pair<Node,Node>> getBounds(const std::vector<Node>& assertions) {
236
  BoundInference bi;
237
  for (const auto& a: assertions) {
238
    bi.add(a);
239
  }
240
  std::map<Node, std::pair<Node,Node>> res;
241
  for (const auto& b : bi.get())
242
  {
243
    res.emplace(b.first,
244
                std::make_pair(b.second.lower_value, b.second.upper_value));
245
  }
246
  return res;
247
}
248
249
}  // namespace arith
250
}  // namespace theory
251
29574
}  // namespace cvc5