1 |
|
/****************************************************************************** |
2 |
|
* Top contributors (to current version): |
3 |
|
* Gereon Kremer |
4 |
|
* |
5 |
|
* This file is part of the cvc5 project. |
6 |
|
* |
7 |
|
* Copyright (c) 2009-2021 by the authors listed in the file AUTHORS |
8 |
|
* in the top-level source directory and their institutional affiliations. |
9 |
|
* All rights reserved. See the file COPYING in the top-level source |
10 |
|
* directory for licensing information. |
11 |
|
* **************************************************************************** |
12 |
|
* |
13 |
|
* Extract bounds on variables from theory atoms. |
14 |
|
*/ |
15 |
|
|
16 |
|
#include "theory/arith/bound_inference.h" |
17 |
|
|
18 |
|
#include "theory/arith/normal_form.h" |
19 |
|
#include "theory/rewriter.h" |
20 |
|
|
21 |
|
namespace cvc5 { |
22 |
|
namespace theory { |
23 |
|
namespace arith { |
24 |
|
|
25 |
|
std::ostream& operator<<(std::ostream& os, const Bounds& b) { |
26 |
|
return os << (b.lower_strict ? '(' : '[') << b.lower_value << " .. " |
27 |
|
<< b.upper_value << (b.upper_strict ? ')' : ']'); |
28 |
|
} |
29 |
|
|
30 |
|
/** Forget every bound collected so far. */
void BoundInference::reset()
{
  d_bounds.erase(d_bounds.begin(), d_bounds.end());
}
31 |
|
|
32 |
277216 |
/**
 * Return a mutable reference to the bounds entry of lhs.
 *
 * If lhs has no entry yet, a default-constructed Bounds is inserted first.
 * Its value fields are null, which the update functions treat as
 * "no bound in this direction yet".
 */
Bounds& BoundInference::get_or_add(const Node& lhs)
{
  // operator[] does the find-or-default-insert in a single tree lookup instead
  // of a find() followed by a separate emplace() on a miss.
  return d_bounds[lhs];
}
41 |
18 |
/**
 * Return a copy of the bounds recorded for lhs, or a default-constructed
 * Bounds if nothing is known about lhs. The map is never modified.
 */
Bounds BoundInference::get(const Node& lhs) const
{
  const auto found = d_bounds.find(lhs);
  return found != d_bounds.end() ? found->second : Bounds();
}
50 |
|
|
51 |
3081 |
/** Read-only access to all collected bounds, keyed by the bounded term. */
const std::map<Node, Bounds>& BoundInference::get() const
{
  return d_bounds;
}
52 |
293348 |
bool BoundInference::add(const Node& n, bool onlyVariables) |
53 |
|
{ |
54 |
586696 |
Node tmp = Rewriter::rewrite(n); |
55 |
293348 |
if (tmp.getKind() == Kind::CONST_BOOLEAN) |
56 |
|
{ |
57 |
10 |
return false; |
58 |
|
} |
59 |
|
// Parse the node as a comparison |
60 |
586676 |
auto comp = Comparison::parseNormalForm(tmp); |
61 |
586676 |
auto dec = comp.decompose(true); |
62 |
293338 |
if (onlyVariables && !std::get<0>(dec).isVariable()) |
63 |
|
{ |
64 |
366 |
return false; |
65 |
|
} |
66 |
|
|
67 |
585944 |
Node lhs = std::get<0>(dec).getNode(); |
68 |
292972 |
Kind relation = std::get<1>(dec); |
69 |
292972 |
if (relation == Kind::DISTINCT) return false; |
70 |
498216 |
Node bound = std::get<2>(dec).getNode(); |
71 |
|
// has the form lhs ~relation~ bound |
72 |
|
|
73 |
249108 |
if (lhs.getType().isInteger()) |
74 |
|
{ |
75 |
314592 |
Rational br = bound.getConst<Rational>(); |
76 |
157296 |
auto* nm = NodeManager::currentNM(); |
77 |
157296 |
switch (relation) |
78 |
|
{ |
79 |
|
case Kind::LEQ: bound = nm->mkConst<Rational>(br.floor()); break; |
80 |
52548 |
case Kind::LT: |
81 |
52548 |
bound = nm->mkConst<Rational>((br - 1).ceiling()); |
82 |
52548 |
relation = Kind::LEQ; |
83 |
52548 |
break; |
84 |
|
case Kind::GT: |
85 |
|
bound = nm->mkConst<Rational>((br + 1).floor()); |
86 |
|
relation = Kind::GEQ; |
87 |
|
break; |
88 |
89054 |
case Kind::GEQ: bound = nm->mkConst<Rational>(br.ceiling()); break; |
89 |
157296 |
default:; |
90 |
|
} |
91 |
314592 |
Trace("bound-inf") << "Strengthened " << n << " to " << lhs << " " |
92 |
157296 |
<< relation << " " << bound << std::endl; |
93 |
|
} |
94 |
|
|
95 |
249108 |
switch (relation) |
96 |
|
{ |
97 |
69832 |
case Kind::LEQ: update_upper_bound(n, lhs, bound, false); break; |
98 |
18185 |
case Kind::LT: update_upper_bound(n, lhs, bound, true); break; |
99 |
28108 |
case Kind::EQUAL: |
100 |
28108 |
update_lower_bound(n, lhs, bound, false); |
101 |
28108 |
update_upper_bound(n, lhs, bound, false); |
102 |
28108 |
break; |
103 |
18703 |
case Kind::GT: update_lower_bound(n, lhs, bound, true); break; |
104 |
114280 |
case Kind::GEQ: update_lower_bound(n, lhs, bound, false); break; |
105 |
|
default: Assert(false); |
106 |
|
} |
107 |
249108 |
return true; |
108 |
|
} |
109 |
|
|
110 |
|
void BoundInference::replaceByOrigins(std::vector<Node>& nodes) const |
111 |
|
{ |
112 |
|
std::vector<Node> toAdd; |
113 |
|
for (auto& n : nodes) |
114 |
|
{ |
115 |
|
for (const auto& b : d_bounds) |
116 |
|
{ |
117 |
|
if (n == b.second.lower_bound && n == b.second.upper_bound) |
118 |
|
{ |
119 |
|
if (n != b.second.lower_origin && n != b.second.upper_origin) |
120 |
|
{ |
121 |
|
Trace("bound-inf") |
122 |
|
<< "Replace " << n << " by origins " << b.second.lower_origin |
123 |
|
<< " and " << b.second.upper_origin << std::endl; |
124 |
|
n = b.second.lower_origin; |
125 |
|
toAdd.emplace_back(b.second.upper_origin); |
126 |
|
} |
127 |
|
} |
128 |
|
else if (n == b.second.lower_bound) |
129 |
|
{ |
130 |
|
if (n != b.second.lower_origin) |
131 |
|
{ |
132 |
|
Trace("bound-inf") << "Replace " << n << " by origin " |
133 |
|
<< b.second.lower_origin << std::endl; |
134 |
|
n = b.second.lower_origin; |
135 |
|
} |
136 |
|
} |
137 |
|
else if (n == b.second.upper_bound) |
138 |
|
{ |
139 |
|
if (n != b.second.upper_origin) |
140 |
|
{ |
141 |
|
Trace("bound-inf") << "Replace " << n << " by origin " |
142 |
|
<< b.second.upper_origin << std::endl; |
143 |
|
n = b.second.upper_origin; |
144 |
|
} |
145 |
|
} |
146 |
|
} |
147 |
|
} |
148 |
|
nodes.insert(nodes.end(), toAdd.begin(), toAdd.end()); |
149 |
|
} |
150 |
|
|
151 |
161091 |
void BoundInference::update_lower_bound(const Node& origin, |
152 |
|
const Node& lhs, |
153 |
|
const Node& value, |
154 |
|
bool strict) |
155 |
|
{ |
156 |
|
// lhs > or >= value because of origin |
157 |
322182 |
Trace("bound-inf") << "\tNew bound " << lhs << (strict ? ">" : ">=") << value |
158 |
161091 |
<< " due to " << origin << std::endl; |
159 |
161091 |
Bounds& b = get_or_add(lhs); |
160 |
322182 |
if (b.lower_value.isNull() |
161 |
161091 |
|| b.lower_value.getConst<Rational>() < value.getConst<Rational>()) |
162 |
|
{ |
163 |
105562 |
auto* nm = NodeManager::currentNM(); |
164 |
105562 |
b.lower_value = value; |
165 |
105562 |
b.lower_strict = strict; |
166 |
|
|
167 |
105562 |
b.lower_origin = origin; |
168 |
|
|
169 |
105562 |
if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value) |
170 |
|
{ |
171 |
1084 |
b.lower_bound = b.upper_bound = |
172 |
2168 |
Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value)); |
173 |
|
} |
174 |
|
else |
175 |
|
{ |
176 |
208956 |
b.lower_bound = Rewriter::rewrite( |
177 |
313434 |
nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value)); |
178 |
|
} |
179 |
|
} |
180 |
55529 |
else if (strict && b.lower_value == value) |
181 |
|
{ |
182 |
3601 |
auto* nm = NodeManager::currentNM(); |
183 |
3601 |
b.lower_strict = strict; |
184 |
3601 |
b.lower_bound = Rewriter::rewrite(nm->mkNode(Kind::GT, lhs, value)); |
185 |
3601 |
b.lower_origin = origin; |
186 |
|
} |
187 |
161091 |
} |
188 |
116125 |
void BoundInference::update_upper_bound(const Node& origin, |
189 |
|
const Node& lhs, |
190 |
|
const Node& value, |
191 |
|
bool strict) |
192 |
|
{ |
193 |
|
// lhs < or <= value because of origin |
194 |
232250 |
Trace("bound-inf") << "\tNew bound " << lhs << (strict ? "<" : "<=") << value |
195 |
116125 |
<< " due to " << origin << std::endl; |
196 |
116125 |
Bounds& b = get_or_add(lhs); |
197 |
232250 |
if (b.upper_value.isNull() |
198 |
116125 |
|| b.upper_value.getConst<Rational>() > value.getConst<Rational>()) |
199 |
|
{ |
200 |
83580 |
auto* nm = NodeManager::currentNM(); |
201 |
83580 |
b.upper_value = value; |
202 |
83580 |
b.upper_strict = strict; |
203 |
83580 |
b.upper_origin = origin; |
204 |
83580 |
if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value) |
205 |
|
{ |
206 |
26294 |
b.lower_bound = b.upper_bound = |
207 |
52588 |
Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value)); |
208 |
|
} |
209 |
|
else |
210 |
|
{ |
211 |
114572 |
b.upper_bound = Rewriter::rewrite( |
212 |
171858 |
nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value)); |
213 |
|
} |
214 |
|
} |
215 |
32545 |
else if (strict && b.upper_value == value) |
216 |
|
{ |
217 |
3095 |
auto* nm = NodeManager::currentNM(); |
218 |
3095 |
b.upper_strict = strict; |
219 |
3095 |
b.upper_bound = Rewriter::rewrite(nm->mkNode(Kind::LT, lhs, value)); |
220 |
3095 |
b.upper_origin = origin; |
221 |
|
} |
222 |
116125 |
} |
223 |
|
|
224 |
|
std::ostream& operator<<(std::ostream& os, const BoundInference& bi) |
225 |
|
{ |
226 |
|
os << "Bounds:" << std::endl; |
227 |
|
for (const auto& b : bi.get()) |
228 |
|
{ |
229 |
|
os << "\t" << b.first << " -> " << b.second.lower_value << ".." |
230 |
|
<< b.second.upper_value << std::endl; |
231 |
|
} |
232 |
|
return os; |
233 |
|
} |
234 |
|
|
235 |
|
std::map<Node, std::pair<Node,Node>> getBounds(const std::vector<Node>& assertions) { |
236 |
|
BoundInference bi; |
237 |
|
for (const auto& a: assertions) { |
238 |
|
bi.add(a); |
239 |
|
} |
240 |
|
std::map<Node, std::pair<Node,Node>> res; |
241 |
|
for (const auto& b : bi.get()) |
242 |
|
{ |
243 |
|
res.emplace(b.first, |
244 |
|
std::make_pair(b.second.lower_value, b.second.upper_value)); |
245 |
|
} |
246 |
|
return res; |
247 |
|
} |
248 |
|
|
249 |
|
} // namespace arith |
250 |
|
} // namespace theory |
251 |
29340 |
} // namespace cvc5 |