GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/bound_inference.cpp Lines: 82 136 60.3 %
Date: 2021-03-22 Branches: 161 465 34.6 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file bound_inference.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Gereon Kremer
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Inference of lower and upper bounds on terms from arithmetic comparisons
13
 **/
14
15
#include "theory/arith/bound_inference.h"
16
17
#include "theory/arith/normal_form.h"
18
#include "theory/rewriter.h"
19
20
namespace CVC4 {
21
namespace theory {
22
namespace arith {
23
24
std::ostream& operator<<(std::ostream& os, const Bounds& b) {
25
  return os << (b.lower_strict ? '(' : '[') << b.lower_value << " .. "
26
            << b.upper_value << (b.upper_strict ? ')' : ']');
27
}
28
29
/** Forgets every bound that has been collected so far. */
void BoundInference::reset()
{
  d_bounds.clear();
}
30
31
286359
/**
 * Returns a mutable reference to the bounds stored for lhs. If nothing is
 * stored for lhs yet, a default-constructed Bounds entry is inserted first.
 */
Bounds& BoundInference::get_or_add(const Node& lhs)
{
  auto pos = d_bounds.find(lhs);
  if (pos != d_bounds.end())
  {
    // Already known: hand out the existing entry.
    return pos->second;
  }
  // Unknown so far: create an empty entry and hand that out.
  return d_bounds.emplace(lhs, Bounds()).first->second;
}
40
/**
 * Returns a copy of the bounds stored for lhs, or a default-constructed
 * Bounds object if lhs has no entry. The map itself is never modified.
 */
Bounds BoundInference::get(const Node& lhs) const
{
  auto pos = d_bounds.find(lhs);
  return pos == d_bounds.end() ? Bounds{} : pos->second;
}
49
50
2926
/** Gives read-only access to the complete map from terms to their bounds. */
const std::map<Node, Bounds>& BoundInference::get() const
{
  return d_bounds;
}
51
304720
bool BoundInference::add(const Node& n, bool onlyVariables)
52
{
53
609440
  Node tmp = Rewriter::rewrite(n);
54
304720
  if (tmp.getKind() == Kind::CONST_BOOLEAN)
55
  {
56
8
    return false;
57
  }
58
  // Parse the node as a comparison
59
609424
  auto comp = Comparison::parseNormalForm(tmp);
60
609424
  auto dec = comp.decompose(true);
61
304712
  if (onlyVariables && !std::get<0>(dec).isVariable())
62
  {
63
    return false;
64
  }
65
66
609424
  Node lhs = std::get<0>(dec).getNode();
67
304712
  Kind relation = std::get<1>(dec);
68
304712
  if (relation == Kind::DISTINCT) return false;
69
513616
  Node bound = std::get<2>(dec).getNode();
70
  // has the form  lhs  ~relation~  bound
71
72
256808
  if (lhs.getType().isInteger())
73
  {
74
315786
    Rational br = bound.getConst<Rational>();
75
157893
    auto* nm = NodeManager::currentNM();
76
157893
    switch (relation)
77
    {
78
      case Kind::LEQ: bound = nm->mkConst<Rational>(br.floor()); break;
79
53469
      case Kind::LT:
80
53469
        bound = nm->mkConst<Rational>((br - 1).ceiling());
81
53469
        relation = Kind::LEQ;
82
53469
        break;
83
      case Kind::GT:
84
        bound = nm->mkConst<Rational>((br + 1).floor());
85
        relation = Kind::GEQ;
86
        break;
87
86774
      case Kind::GEQ: bound = nm->mkConst<Rational>(br.ceiling()); break;
88
157893
      default:;
89
    }
90
315786
    Trace("bound-inf") << "Strengthened " << n << " to " << lhs << " "
91
157893
                       << relation << " " << bound << std::endl;
92
  }
93
94
256808
  switch (relation)
95
  {
96
71260
    case Kind::LEQ: update_upper_bound(n, lhs, bound, false); break;
97
17883
    case Kind::LT: update_upper_bound(n, lhs, bound, true); break;
98
29551
    case Kind::EQUAL:
99
29551
      update_lower_bound(n, lhs, bound, false);
100
29551
      update_upper_bound(n, lhs, bound, false);
101
29551
      break;
102
22493
    case Kind::GT: update_lower_bound(n, lhs, bound, true); break;
103
115621
    case Kind::GEQ: update_lower_bound(n, lhs, bound, false); break;
104
    default: Assert(false);
105
  }
106
256808
  return true;
107
}
108
109
void BoundInference::replaceByOrigins(std::vector<Node>& nodes) const
110
{
111
  std::vector<Node> toAdd;
112
  for (auto& n : nodes)
113
  {
114
    for (const auto& b : d_bounds)
115
    {
116
      if (n == b.second.lower_bound && n == b.second.upper_bound)
117
      {
118
        if (n != b.second.lower_origin && n != b.second.upper_origin)
119
        {
120
          Trace("bound-inf")
121
              << "Replace " << n << " by origins " << b.second.lower_origin
122
              << " and " << b.second.upper_origin << std::endl;
123
          n = b.second.lower_origin;
124
          toAdd.emplace_back(b.second.upper_origin);
125
        }
126
      }
127
      else if (n == b.second.lower_bound)
128
      {
129
        if (n != b.second.lower_origin)
130
        {
131
          Trace("bound-inf") << "Replace " << n << " by origin "
132
                             << b.second.lower_origin << std::endl;
133
          n = b.second.lower_origin;
134
        }
135
      }
136
      else if (n == b.second.upper_bound)
137
      {
138
        if (n != b.second.upper_origin)
139
        {
140
          Trace("bound-inf") << "Replace " << n << " by origin "
141
                             << b.second.upper_origin << std::endl;
142
          n = b.second.upper_origin;
143
        }
144
      }
145
    }
146
  }
147
  nodes.insert(nodes.end(), toAdd.begin(), toAdd.end());
148
}
149
150
167665
void BoundInference::update_lower_bound(const Node& origin,
151
                                        const Node& lhs,
152
                                        const Node& value,
153
                                        bool strict)
154
{
155
  // lhs > or >= value because of origin
156
335330
  Trace("bound-inf") << "\tNew bound " << lhs << (strict ? ">" : ">=") << value
157
167665
                     << " due to " << origin << std::endl;
158
167665
  Bounds& b = get_or_add(lhs);
159
335330
  if (b.lower_value.isNull()
160
167665
      || b.lower_value.getConst<Rational>() < value.getConst<Rational>())
161
  {
162
107045
    auto* nm = NodeManager::currentNM();
163
107045
    b.lower_value = value;
164
107045
    b.lower_strict = strict;
165
166
107045
    b.lower_origin = origin;
167
168
107045
    if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
169
    {
170
1385
      b.lower_bound = b.upper_bound =
171
2770
          Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
172
    }
173
    else
174
    {
175
211320
      b.lower_bound = Rewriter::rewrite(
176
316980
          nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value));
177
    }
178
  }
179
60620
  else if (strict && b.lower_value == value)
180
  {
181
4817
    auto* nm = NodeManager::currentNM();
182
4817
    b.lower_strict = strict;
183
4817
    b.lower_bound = Rewriter::rewrite(nm->mkNode(Kind::GT, lhs, value));
184
4817
    b.lower_origin = origin;
185
  }
186
167665
}
187
118694
void BoundInference::update_upper_bound(const Node& origin,
188
                                        const Node& lhs,
189
                                        const Node& value,
190
                                        bool strict)
191
{
192
  // lhs < or <= value because of origin
193
237388
  Trace("bound-inf") << "\tNew bound " << lhs << (strict ? "<" : "<=") << value
194
118694
                     << " due to " << origin << std::endl;
195
118694
  Bounds& b = get_or_add(lhs);
196
237388
  if (b.upper_value.isNull()
197
118694
      || b.upper_value.getConst<Rational>() > value.getConst<Rational>())
198
  {
199
83897
    auto* nm = NodeManager::currentNM();
200
83897
    b.upper_value = value;
201
83897
    b.upper_strict = strict;
202
83897
    b.upper_origin = origin;
203
83897
    if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
204
    {
205
26528
      b.lower_bound = b.upper_bound =
206
53056
          Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
207
    }
208
    else
209
    {
210
114738
      b.upper_bound = Rewriter::rewrite(
211
172107
          nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value));
212
    }
213
  }
214
34797
  else if (strict && b.upper_value == value)
215
  {
216
2916
    auto* nm = NodeManager::currentNM();
217
2916
    b.upper_strict = strict;
218
2916
    b.upper_bound = Rewriter::rewrite(nm->mkNode(Kind::LT, lhs, value));
219
2916
    b.upper_origin = origin;
220
  }
221
118694
}
222
223
std::ostream& operator<<(std::ostream& os, const BoundInference& bi)
224
{
225
  os << "Bounds:" << std::endl;
226
  for (const auto& b : bi.get())
227
  {
228
    os << "\t" << b.first << " -> " << b.second.lower_value << ".."
229
       << b.second.upper_value << std::endl;
230
  }
231
  return os;
232
}
233
234
std::map<Node, std::pair<Node,Node>> getBounds(const std::vector<Node>& assertions) {
235
  BoundInference bi;
236
  for (const auto& a: assertions) {
237
    bi.add(a);
238
  }
239
  std::map<Node, std::pair<Node,Node>> res;
240
  for (const auto& b : bi.get())
241
  {
242
    res.emplace(b.first,
243
                std::make_pair(b.second.lower_value, b.second.upper_value));
244
  }
245
  return res;
246
}
247
248
}  // namespace arith
249
}  // namespace theory
250
26676
}  // namespace CVC4