GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/bound_inference.cpp Lines: 88 136 64.7 %
Date: 2021-08-20 Branches: 171 453 37.7 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Gereon Kremer
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Extract bounds on variables from theory atoms.
14
 */
15
16
#include "theory/arith/bound_inference.h"
17
18
#include "theory/arith/normal_form.h"
19
#include "theory/rewriter.h"
20
21
namespace cvc5 {
22
namespace theory {
23
namespace arith {
24
25
std::ostream& operator<<(std::ostream& os, const Bounds& b) {
26
  return os << (b.lower_strict ? '(' : '[') << b.lower_value << " .. "
27
            << b.upper_value << (b.upper_strict ? ')' : ']');
28
}
29
30
/** Forget every bound collected so far. */
void BoundInference::reset()
{
  d_bounds.clear();
}
31
32
266292
/**
 * Access the mutable Bounds entry for `lhs`.
 *
 * If no entry exists yet, a value-initialized Bounds (same as `Bounds()`) is
 * inserted first.
 *
 * @param lhs the term whose bounds are requested
 * @return reference to the entry stored in d_bounds
 */
Bounds& BoundInference::get_or_add(const Node& lhs)
{
  // std::map::operator[] inserts a value-initialized mapped value when the
  // key is absent and returns a reference to the (new or existing) entry.
  return d_bounds[lhs];
}
41
18
/**
 * Look up the bounds recorded for `lhs` without modifying the map.
 *
 * @param lhs the term whose bounds are requested
 * @return a copy of the stored Bounds, or a default-constructed Bounds if
 *         nothing has been recorded for `lhs`
 */
Bounds BoundInference::get(const Node& lhs) const
{
  const auto found = d_bounds.find(lhs);
  if (found != d_bounds.end())
  {
    return found->second;
  }
  return Bounds();
}
50
51
3053
/** Read-only view of every recorded bound, keyed by its left-hand side term. */
const std::map<Node, Bounds>& BoundInference::get() const
{
  return d_bounds;
}
52
280027
bool BoundInference::add(const Node& n, bool onlyVariables)
53
{
54
560054
  Node tmp = Rewriter::rewrite(n);
55
280027
  if (tmp.getKind() == Kind::CONST_BOOLEAN)
56
  {
57
10
    return false;
58
  }
59
  // Parse the node as a comparison
60
560034
  auto comp = Comparison::parseNormalForm(tmp);
61
560034
  auto dec = comp.decompose(true);
62
280017
  if (onlyVariables && !std::get<0>(dec).isVariable())
63
  {
64
366
    return false;
65
  }
66
67
559302
  Node lhs = std::get<0>(dec).getNode();
68
279651
  Kind relation = std::get<1>(dec);
69
279651
  if (relation == Kind::DISTINCT) return false;
70
476928
  Node bound = std::get<2>(dec).getNode();
71
  // has the form  lhs  ~relation~  bound
72
73
238464
  if (lhs.getType().isInteger())
74
  {
75
294964
    Rational br = bound.getConst<Rational>();
76
147482
    auto* nm = NodeManager::currentNM();
77
147482
    switch (relation)
78
    {
79
      case Kind::LEQ: bound = nm->mkConst<Rational>(br.floor()); break;
80
47785
      case Kind::LT:
81
47785
        bound = nm->mkConst<Rational>((br - 1).ceiling());
82
47785
        relation = Kind::LEQ;
83
47785
        break;
84
      case Kind::GT:
85
        bound = nm->mkConst<Rational>((br + 1).floor());
86
        relation = Kind::GEQ;
87
        break;
88
84183
      case Kind::GEQ: bound = nm->mkConst<Rational>(br.ceiling()); break;
89
147482
      default:;
90
    }
91
294964
    Trace("bound-inf") << "Strengthened " << n << " to " << lhs << " "
92
147482
                       << relation << " " << bound << std::endl;
93
  }
94
95
238464
  switch (relation)
96
  {
97
65069
    case Kind::LEQ: update_upper_bound(n, lhs, bound, false); break;
98
17861
    case Kind::LT: update_upper_bound(n, lhs, bound, true); break;
99
27828
    case Kind::EQUAL:
100
27828
      update_lower_bound(n, lhs, bound, false);
101
27828
      update_upper_bound(n, lhs, bound, false);
102
27828
      break;
103
18703
    case Kind::GT: update_lower_bound(n, lhs, bound, true); break;
104
109003
    case Kind::GEQ: update_lower_bound(n, lhs, bound, false); break;
105
    default: Assert(false);
106
  }
107
238464
  return true;
108
}
109
110
void BoundInference::replaceByOrigins(std::vector<Node>& nodes) const
111
{
112
  std::vector<Node> toAdd;
113
  for (auto& n : nodes)
114
  {
115
    for (const auto& b : d_bounds)
116
    {
117
      if (n == b.second.lower_bound && n == b.second.upper_bound)
118
      {
119
        if (n != b.second.lower_origin && n != b.second.upper_origin)
120
        {
121
          Trace("bound-inf")
122
              << "Replace " << n << " by origins " << b.second.lower_origin
123
              << " and " << b.second.upper_origin << std::endl;
124
          n = b.second.lower_origin;
125
          toAdd.emplace_back(b.second.upper_origin);
126
        }
127
      }
128
      else if (n == b.second.lower_bound)
129
      {
130
        if (n != b.second.lower_origin)
131
        {
132
          Trace("bound-inf") << "Replace " << n << " by origin "
133
                             << b.second.lower_origin << std::endl;
134
          n = b.second.lower_origin;
135
        }
136
      }
137
      else if (n == b.second.upper_bound)
138
      {
139
        if (n != b.second.upper_origin)
140
        {
141
          Trace("bound-inf") << "Replace " << n << " by origin "
142
                             << b.second.upper_origin << std::endl;
143
          n = b.second.upper_origin;
144
        }
145
      }
146
    }
147
  }
148
  nodes.insert(nodes.end(), toAdd.begin(), toAdd.end());
149
}
150
151
155534
void BoundInference::update_lower_bound(const Node& origin,
152
                                        const Node& lhs,
153
                                        const Node& value,
154
                                        bool strict)
155
{
156
  // lhs > or >= value because of origin
157
311068
  Trace("bound-inf") << "\tNew bound " << lhs << (strict ? ">" : ">=") << value
158
155534
                     << " due to " << origin << std::endl;
159
155534
  Bounds& b = get_or_add(lhs);
160
311068
  if (b.lower_value.isNull()
161
155534
      || b.lower_value.getConst<Rational>() < value.getConst<Rational>())
162
  {
163
102232
    auto* nm = NodeManager::currentNM();
164
102232
    b.lower_value = value;
165
102232
    b.lower_strict = strict;
166
167
102232
    b.lower_origin = origin;
168
169
102232
    if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
170
    {
171
1169
      b.lower_bound = b.upper_bound =
172
2338
          Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
173
    }
174
    else
175
    {
176
202126
      b.lower_bound = Rewriter::rewrite(
177
303189
          nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value));
178
    }
179
  }
180
53302
  else if (strict && b.lower_value == value)
181
  {
182
3601
    auto* nm = NodeManager::currentNM();
183
3601
    b.lower_strict = strict;
184
3601
    b.lower_bound = Rewriter::rewrite(nm->mkNode(Kind::GT, lhs, value));
185
3601
    b.lower_origin = origin;
186
  }
187
155534
}
188
110758
void BoundInference::update_upper_bound(const Node& origin,
189
                                        const Node& lhs,
190
                                        const Node& value,
191
                                        bool strict)
192
{
193
  // lhs < or <= value because of origin
194
221516
  Trace("bound-inf") << "\tNew bound " << lhs << (strict ? "<" : "<=") << value
195
110758
                     << " due to " << origin << std::endl;
196
110758
  Bounds& b = get_or_add(lhs);
197
221516
  if (b.upper_value.isNull()
198
110758
      || b.upper_value.getConst<Rational>() > value.getConst<Rational>())
199
  {
200
80276
    auto* nm = NodeManager::currentNM();
201
80276
    b.upper_value = value;
202
80276
    b.upper_strict = strict;
203
80276
    b.upper_origin = origin;
204
80276
    if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
205
    {
206
25833
      b.lower_bound = b.upper_bound =
207
51666
          Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
208
    }
209
    else
210
    {
211
108886
      b.upper_bound = Rewriter::rewrite(
212
163329
          nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value));
213
    }
214
  }
215
30482
  else if (strict && b.upper_value == value)
216
  {
217
3095
    auto* nm = NodeManager::currentNM();
218
3095
    b.upper_strict = strict;
219
3095
    b.upper_bound = Rewriter::rewrite(nm->mkNode(Kind::LT, lhs, value));
220
3095
    b.upper_origin = origin;
221
  }
222
110758
}
223
224
std::ostream& operator<<(std::ostream& os, const BoundInference& bi)
225
{
226
  os << "Bounds:" << std::endl;
227
  for (const auto& b : bi.get())
228
  {
229
    os << "\t" << b.first << " -> " << b.second.lower_value << ".."
230
       << b.second.upper_value << std::endl;
231
  }
232
  return os;
233
}
234
235
std::map<Node, std::pair<Node,Node>> getBounds(const std::vector<Node>& assertions) {
236
  BoundInference bi;
237
  for (const auto& a: assertions) {
238
    bi.add(a);
239
  }
240
  std::map<Node, std::pair<Node,Node>> res;
241
  for (const auto& b : bi.get())
242
  {
243
    res.emplace(b.first,
244
                std::make_pair(b.second.lower_value, b.second.upper_value));
245
  }
246
  return res;
247
}
248
249
}  // namespace arith
250
}  // namespace theory
251
29358
}  // namespace cvc5