1 |
|
/****************************************************************************** |
2 |
|
* Top contributors (to current version): |
3 |
|
* Andrew Reynolds, Gereon Kremer, Tim King |
4 |
|
* |
5 |
|
* This file is part of the cvc5 project. |
6 |
|
* |
7 |
|
* Copyright (c) 2009-2021 by the authors listed in the file AUTHORS |
8 |
|
* in the top-level source directory and their institutional affiliations. |
9 |
|
* All rights reserved. See the file COPYING in the top-level source |
10 |
|
* directory for licensing information. |
11 |
|
* **************************************************************************** |
12 |
|
* |
13 |
|
* Model object for the non-linear extension class. |
14 |
|
*/ |
15 |
|
|
16 |
|
#include "theory/arith/nl/nl_model.h" |
17 |
|
|
18 |
|
#include "expr/node_algorithm.h" |
19 |
|
#include "options/arith_options.h" |
20 |
|
#include "options/smt_options.h" |
21 |
|
#include "options/theory_options.h" |
22 |
|
#include "theory/arith/arith_msum.h" |
23 |
|
#include "theory/arith/arith_utilities.h" |
24 |
|
#include "theory/arith/nl/nl_lemma_utils.h" |
25 |
|
#include "theory/theory_model.h" |
26 |
|
#include "theory/rewriter.h" |
27 |
|
|
28 |
|
using namespace cvc5::kind; |
29 |
|
|
30 |
|
namespace cvc5 { |
31 |
|
namespace theory { |
32 |
|
namespace arith { |
33 |
|
namespace nl { |
34 |
|
|
35 |
9696 |
/**
 * Constructs the model object. The "used approximation" flag starts out
 * cleared, and the Boolean / rational constant nodes kept as members
 * (true, false, 0, 1, 2) are created once here.
 */
NlModel::NlModel() : d_used_approx(false)
{
  NodeManager* const nm = NodeManager::currentNM();
  // Boolean constants
  d_true = nm->mkConst(true);
  d_false = nm->mkConst(false);
  // rational constants
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
  d_two = nm->mkConst(Rational(2));
}
43 |
|
|
44 |
11695 |
/** Destructor; the body is intentionally empty. */
NlModel::~NlModel()
{
}
45 |
|
|
46 |
4881 |
/**
 * Re-targets this object at a new model.
 *
 * @param m the theory model to store in d_model
 * @param arithModel the map copied into d_arithVal
 *
 * Both memo tables of computed model values (concrete and abstract) are
 * emptied, since their entries were computed for the previous model.
 */
void NlModel::reset(TheoryModel* m, const std::map<Node, Node>& arithModel)
{
  // drop all memoized values first
  d_concreteModelCache.clear();
  d_abstractModelCache.clear();
  // then install the new model and the new arithmetic values
  d_model = m;
  d_arithVal = arithModel;
}
53 |
|
|
54 |
4897 |
void NlModel::resetCheck() |
55 |
|
{ |
56 |
4897 |
d_used_approx = false; |
57 |
4897 |
d_check_model_solved.clear(); |
58 |
4897 |
d_check_model_bounds.clear(); |
59 |
4897 |
d_check_model_witnesses.clear(); |
60 |
4897 |
d_substitutions.clear(); |
61 |
4897 |
} |
62 |
|
|
63 |
841242 |
/**
 * Convenience wrapper: computes the value of n with computeModelValue in
 * concrete mode (isConcrete = true).
 */
Node NlModel::computeConcreteModelValue(TNode n)
{
  const bool isConcrete = true;
  return computeModelValue(n, isConcrete);
}
67 |
|
|
68 |
440886 |
/**
 * Convenience wrapper: computes the value of n with computeModelValue in
 * abstract mode (isConcrete = false).
 */
Node NlModel::computeAbstractModelValue(TNode n)
{
  const bool isConcrete = false;
  return computeModelValue(n, isConcrete);
}
72 |
|
|
73 |
3917345 |
/**
 * Computes (and memoizes) the value of n, either in concrete mode ("M" in
 * the trace output) or in abstract mode ("M_A").
 *
 * @param n the term to evaluate
 * @param isConcrete selects the mode and, with it, the memo table
 * @return the computed value; not necessarily a constant (e.g. PI is
 *         returned as itself)
 *
 * Evaluation rules, in order:
 *  - constants evaluate to themselves;
 *  - in abstract mode, a term for which hasLinearModelValue succeeds takes
 *    the value that call produced;
 *  - terms with children whose kind belongs to arithmetic, Booleans or the
 *    builtin theory are rebuilt from the values of their children and
 *    rewritten; other terms with children are looked up directly;
 *  - PI evaluates to itself, other childless terms are looked up directly.
 */
Node NlModel::computeModelValue(TNode n, bool isConcrete)
{
  // one memo table per evaluation mode
  auto& memo = isConcrete ? d_concreteModelCache : d_abstractModelCache;
  const auto cached = memo.find(n);
  if (cached != memo.end())
  {
    return cached->second;
  }
  Trace("nl-ext-mv-debug") << "computeModelValue " << n
                           << ", isConcrete=" << isConcrete << std::endl;
  Node value;
  if (n.isConst())
  {
    value = n;
  }
  else if (!isConcrete && hasLinearModelValue(n, value))
  {
    // use model value for abstraction; presumably the call stored it in
    // value, since nothing else is assigned on this path
  }
  else if (n.getNumChildren() > 0)
  {
    // otherwise, compute true value
    const TheoryId owner = theory::kindToTheoryId(n.getKind());
    const bool evaluateHere = owner == THEORY_ARITH || owner == THEORY_BOOL
                              || owner == THEORY_BUILTIN;
    if (evaluateHere)
    {
      // rebuild n over the values of its children, then simplify
      std::vector<Node> args;
      if (n.getMetaKind() == metakind::PARAMETERIZED)
      {
        args.emplace_back(n.getOperator());
      }
      const size_t arity = n.getNumChildren();
      for (size_t idx = 0; idx < arity; ++idx)
      {
        args.emplace_back(computeModelValue(n[idx], isConcrete));
      }
      value = Rewriter::rewrite(
          NodeManager::currentNM()->mkNode(n.getKind(), args));
    }
    else
    {
      // we directly look up terms not belonging to arithmetic
      value = getValueInternal(n);
    }
  }
  else if (n.getKind() == PI)
  {
    // we are interested in the exact value of PI, which cannot be computed.
    // hence, we return PI itself when asked for the concrete value.
    value = n;
  }
  else
  {
    value = getValueInternal(n);
  }
  Trace("nl-ext-mv-debug") << "computed " << (isConcrete ? "M" : "M_A") << "["
                           << n << "] = " << value << std::endl;
  memo[n] = value;
  return value;
}
133 |
|
|
134 |
160177 |
/**
 * Three-way comparison of the model values of i and j.
 *
 * @param i, j the terms to compare
 * @param isConcrete evaluation mode handed to computeModelValue
 * @param isAbsolute handed to compareValue when both values are constants
 * @return 0 if i and j are the same term; the result of compareValue if
 *         both values are constants; 1 if only the value of i is constant;
 *         -1 if only the value of j is constant; 0 if neither is.
 */
int NlModel::compare(TNode i, TNode j, bool isConcrete, bool isAbsolute)
{
  if (i == j)
  {
    return 0;
  }
  Node vi = computeModelValue(i, isConcrete);
  Node vj = computeModelValue(j, isConcrete);
  const bool iIsConst = vi.isConst();
  const bool jIsConst = vj.isConst();
  if (iIsConst && jIsConst)
  {
    return compareValue(vi, vj, isAbsolute);
  }
  if (iIsConst)
  {
    return 1;
  }
  return jIsConst ? -1 : 0;
}
152 |
|
|
153 |
175652 |
/**
 * Three-way comparison of two constants, read as rationals.
 *
 * @param i, j constant nodes (asserted)
 * @param isAbsolute if true, the absolute values are compared instead
 * @return -1, 0 or 1 when the (absolute) value of i is respectively less
 *         than, equal to, or greater than that of j
 */
int NlModel::compareValue(TNode i, TNode j, bool isAbsolute) const
{
  Assert(i.isConst() && j.isConst());
  if (i == j)
  {
    return 0;
  }
  const Rational& ri = i.getConst<Rational>();
  const Rational& rj = j.getConst<Rational>();
  if (isAbsolute)
  {
    // distinct constants may still agree in absolute value
    const Rational magI = ri.abs();
    const Rational magJ = rj.abs();
    if (magI == magJ)
    {
      return 0;
    }
    return magI < magJ ? -1 : 1;
  }
  // plain comparison of two distinct constants
  return ri < rj ? -1 : 1;
}
172 |
|
|
173 |
273 |
bool NlModel::checkModel(const std::vector<Node>& assertions, |
174 |
|
unsigned d, |
175 |
|
std::vector<NlLemma>& lemmas) |
176 |
|
{ |
177 |
273 |
Trace("nl-ext-cm-debug") << " solve for equalities..." << std::endl; |
178 |
5517 |
for (const Node& atom : assertions) |
179 |
|
{ |
180 |
|
// see if it corresponds to a univariate polynomial equation of degree two |
181 |
5244 |
if (atom.getKind() == EQUAL) |
182 |
|
{ |
183 |
846 |
if (!solveEqualitySimple(atom, d, lemmas)) |
184 |
|
{ |
185 |
|
// no chance we will satisfy this equality |
186 |
518 |
Trace("nl-ext-cm") << "...check-model : failed to solve equality : " |
187 |
259 |
<< atom << std::endl; |
188 |
|
} |
189 |
|
} |
190 |
|
} |
191 |
|
|
192 |
|
// all remaining variables are constrained to their exact model values |
193 |
546 |
Trace("nl-ext-cm-debug") << " set exact bounds for remaining variables..." |
194 |
273 |
<< std::endl; |
195 |
546 |
std::unordered_set<TNode> visited; |
196 |
546 |
std::vector<TNode> visit; |
197 |
546 |
TNode cur; |
198 |
5517 |
for (const Node& a : assertions) |
199 |
|
{ |
200 |
5244 |
visit.push_back(a); |
201 |
19893 |
do |
202 |
|
{ |
203 |
25137 |
cur = visit.back(); |
204 |
25137 |
visit.pop_back(); |
205 |
25137 |
if (visited.find(cur) == visited.end()) |
206 |
|
{ |
207 |
13850 |
visited.insert(cur); |
208 |
13850 |
if (cur.getType().isReal() && !cur.isConst()) |
209 |
|
{ |
210 |
4029 |
Kind k = cur.getKind(); |
211 |
6946 |
if (k != MULT && k != PLUS && k != NONLINEAR_MULT |
212 |
4960 |
&& !isTranscendentalKind(k)) |
213 |
|
{ |
214 |
|
// if we have not set an approximate bound for it |
215 |
638 |
if (!hasAssignment(cur)) |
216 |
|
{ |
217 |
|
// set its exact model value in the substitution |
218 |
654 |
Node curv = computeConcreteModelValue(cur); |
219 |
327 |
if (Trace.isOn("nl-ext-cm")) |
220 |
|
{ |
221 |
|
Trace("nl-ext-cm") |
222 |
|
<< "check-model-bound : exact : " << cur << " = "; |
223 |
|
printRationalApprox("nl-ext-cm", curv); |
224 |
|
Trace("nl-ext-cm") << std::endl; |
225 |
|
} |
226 |
327 |
bool ret = addSubstitution(cur, curv); |
227 |
327 |
AlwaysAssert(ret); |
228 |
|
} |
229 |
|
} |
230 |
|
} |
231 |
33743 |
for (const Node& cn : cur) |
232 |
|
{ |
233 |
19893 |
visit.push_back(cn); |
234 |
|
} |
235 |
|
} |
236 |
25137 |
} while (!visit.empty()); |
237 |
|
} |
238 |
|
|
239 |
273 |
Trace("nl-ext-cm-debug") << " check assertions..." << std::endl; |
240 |
546 |
std::vector<Node> check_assertions; |
241 |
5517 |
for (const Node& a : assertions) |
242 |
|
{ |
243 |
5244 |
if (d_check_model_solved.find(a) == d_check_model_solved.end()) |
244 |
|
{ |
245 |
9314 |
Node av = a; |
246 |
|
// apply the substitution to a |
247 |
4657 |
if (!d_substitutions.empty()) |
248 |
|
{ |
249 |
4103 |
av = Rewriter::rewrite(arithSubstitute(av, d_substitutions)); |
250 |
|
} |
251 |
|
// simple check literal |
252 |
4657 |
if (!simpleCheckModelLit(av)) |
253 |
|
{ |
254 |
826 |
Trace("nl-ext-cm") << "...check-model : assertion failed : " << a |
255 |
413 |
<< std::endl; |
256 |
413 |
check_assertions.push_back(av); |
257 |
826 |
Trace("nl-ext-cm-debug") |
258 |
413 |
<< "...check-model : failed assertion, value : " << av << std::endl; |
259 |
|
} |
260 |
|
} |
261 |
|
} |
262 |
|
|
263 |
273 |
if (!check_assertions.empty()) |
264 |
|
{ |
265 |
205 |
Trace("nl-ext-cm") << "...simple check failed." << std::endl; |
266 |
|
// TODO (#1450) check model for general case |
267 |
205 |
return false; |
268 |
|
} |
269 |
68 |
Trace("nl-ext-cm") << "...simple check succeeded!" << std::endl; |
270 |
68 |
return true; |
271 |
|
} |
272 |
|
|
273 |
687 |
/**
 * Records the substitution v -> s for the check-model procedure.
 *
 * @param v the term being substituted
 * @param s its replacement
 * @return true if the substitution was added; false if v already has a
 *         substitution, or if an approximate bound was already recorded
 *         for v (see the notes below)
 *
 * On success, v -> s is first applied to the right-hand side of every
 * existing substitution (re-rewriting those that change), and only then is
 * v -> s itself appended.
 */
bool NlModel::addSubstitution(TNode v, TNode s)
{
  // should not substitute the same variable twice
  Trace("nl-ext-model") << "* check model substitution : " << v << " -> " << s
                        << std::endl;
  // should not set exact bound more than once
  if (d_substitutions.contains(v))
  {
    Trace("nl-ext-model") << "...ERROR: already has value." << std::endl;
    // this should never happen since substitutions should be applied eagerly
    Assert(false);
    return false;
  }
  // if we previously had an approximate bound, the exact bound should be in its
  // range
  std::map<Node, std::pair<Node, Node>>::iterator itb =
      d_check_model_bounds.find(v);
  if (itb != d_check_model_bounds.end())
  {
    if (!s.isConst())
    {
      // The comparison below reads s as a rational constant, which is only
      // meaningful for a constant node; refuse the substitution instead of
      // calling getConst on a non-constant term.
      Trace("nl-ext-model")
          << "...ERROR: already has bound, and the value is not constant."
          << std::endl;
      return false;
    }
    // NOTE(review): for a bound with lower <= upper (which addBound
    // asserts), this disjunction holds for every constant s, so any term
    // that already has a bound is rejected here, not only out-of-range
    // values as the message suggests. The condition is deliberately left
    // unchanged: accepting an exact value for an approximated term changes
    // what the model check may conclude and needs the owner's confirmation.
    if (s.getConst<Rational>() >= itb->second.first.getConst<Rational>()
        || s.getConst<Rational>() <= itb->second.second.getConst<Rational>())
    {
      Trace("nl-ext-model")
          << "...ERROR: already has bound which is out of range." << std::endl;
      return false;
    }
  }
  Assert(d_check_model_witnesses.find(v) == d_check_model_witnesses.end())
      << "We tried to add a substitution where we already had a witness term."
      << std::endl;
  // apply the new substitution to the range of the existing ones
  Subs tmp;
  tmp.add(v, s);
  for (auto& sub : d_substitutions.d_subs)
  {
    Node ms = arithSubstitute(sub, tmp);
    if (ms != sub)
    {
      sub = Rewriter::rewrite(ms);
    }
  }
  d_substitutions.add(v, s);
  return true;
}
316 |
|
|
317 |
384 |
/**
 * Records the approximate bound l <= v <= u for the check-model procedure.
 *
 * @param v the term being bounded
 * @param l the lower bound
 * @param u the upper bound
 * @return the result of addSubstitution(v, l) when l and u are the same
 *         node; false if v already has a substitution; true otherwise
 *
 * In the non-degenerate case, l and u are asserted to be constants with
 * l <= u, and any bound previously stored for v is overwritten.
 */
bool NlModel::addBound(TNode v, TNode l, TNode u)
{
  Trace("nl-ext-model") << "* check model bound : " << v << " -> [" << l << " "
                        << u << "]" << std::endl;
  if (l == u)
  {
    // bound is exact, can add as substitution
    return addSubstitution(v, l);
  }
  // should not set a bound for a value that is exact
  if (d_substitutions.contains(v))
  {
    Trace("nl-ext-model")
        << "...ERROR: setting bound for variable that already has exact value."
        << std::endl;
    Assert(false);
    return false;
  }
  Assert(l.isConst());
  Assert(u.isConst());
  Assert(l.getConst<Rational>() <= u.getConst<Rational>());
  const std::pair<Node, Node> range(l, u);
  d_check_model_bounds[v] = range;
  if (Trace.isOn("nl-ext-cm"))
  {
    Trace("nl-ext-cm") << "check-model-bound : approximate : ";
    printRationalApprox("nl-ext-cm", l);
    Trace("nl-ext-cm") << " <= " << v << " <= ";
    printRationalApprox("nl-ext-cm", u);
    Trace("nl-ext-cm") << std::endl;
  }
  return true;
}
349 |
|
|
350 |
9 |
/**
 * Records w as the witness term for v in the check-model procedure.
 *
 * @param v the term being given a witness
 * @param w the witness term
 * @return false if v already has a substitution, true otherwise
 *
 * The entry is added with emplace, so a witness already stored for v is
 * left as it is.
 */
bool NlModel::addWitness(TNode v, TNode w)
{
  Trace("nl-ext-model") << "* check model witness : " << v << " -> " << w
                        << std::endl;
  // should not set a witness for a value that is already set
  const bool alreadyExact = d_substitutions.contains(v);
  if (alreadyExact)
  {
    Trace("nl-ext-model") << "...ERROR: setting witness for variable that "
                             "already has a constant value."
                          << std::endl;
    Assert(false);
    return false;
  }
  d_check_model_witnesses.emplace(v, w);
  return true;
}
366 |
|
|
367 |
269 |
/** Sets the flag reported by usedApproximate(). */
void NlModel::setUsedApproximate()
{
  d_used_approx = true;
}
368 |
|
|
369 |
20 |
/**
 * Returns whether setUsedApproximate() has been called since construction
 * or since the last resetCheck().
 */
bool NlModel::usedApproximate() const
{
  return d_used_approx;
}
370 |
|
|
371 |
949 |
/**
 * Tries to discharge the equality eq by extending the check-model
 * substitution.
 *
 * @param eq the equality to solve
 * @param d only forwarded to the recursive call
 * @param lemmas only forwarded to the recursive call
 * @return true if eq was marked as solved in d_check_model_solved
 *
 * After applying the current substitution, the equality is viewed as a
 * monomial sum a*var*var + b*var + c = 0:
 *  - if it simplified to a constant, it is solved iff that constant is true;
 *  - if other monomials occur, we first try to isolate a variable that has
 *    no assignment yet and substitute it by the solved form; failing that,
 *    we fix one such variable to its concrete model value and recurse;
 *  - if it is linear in a single non-integer variable, that variable is set
 *    to -c/b;
 *  - a proper quadratic (a non-zero) is not solved here.
 */
bool NlModel::solveEqualitySimple(Node eq,
                                  unsigned d,
                                  std::vector<NlLemma>& lemmas)
{
  Node seq = eq;
  if (!d_substitutions.empty())
  {
    seq = arithSubstitute(eq, d_substitutions);
    seq = Rewriter::rewrite(seq);
    if (seq.isConst())
    {
      if (!seq.getConst<bool>())
      {
        return false;
      }
      // already true
      d_check_model_solved[eq] = Node::null();
      return true;
    }
  }
  Trace("nl-ext-cms") << "simple solve equality " << seq << "..." << std::endl;
  Assert(seq.getKind() == EQUAL);
  std::map<Node, Node> msum;
  if (!ArithMSum::getMonomialSumLit(seq, msum))
  {
    Trace("nl-ext-cms") << "...fail, could not determine monomial sum."
                        << std::endl;
    return false;
  }
  bool is_valid = true;
  // the variable we will solve a quadratic equation for
  Node var;
  // coefficients of var*var, of var, and the constant term
  Node a = d_zero;
  Node b = d_zero;
  Node c = d_zero;
  // the list of variables that occur as a monomial in msum, and whose value
  // is so far unconstrained in the model.
  std::unordered_set<Node> unc_vars;
  // the list of variables that occur as a factor in a monomial, and whose
  // value is so far unconstrained in the model.
  std::unordered_set<Node> unc_vars_factor;
  for (const std::pair<const Node, Node>& entry : msum)
  {
    Node mono = entry.first;
    Node coeff = entry.second.isNull() ? d_one : entry.second;
    if (mono.isNull())
    {
      // the constant term
      c = coeff;
      continue;
    }
    if (mono.getKind() == NONLINEAR_MULT)
    {
      const bool isSquareOfVar =
          mono.getNumChildren() == 2 && mono[0].isVar() && mono[0] == mono[1]
          && (var.isNull() || var == mono[0]);
      if (isSquareOfVar)
      {
        // may solve quadratic
        a = coeff;
        var = mono[0];
        continue;
      }
      is_valid = false;
      Trace("nl-ext-cms-debug")
          << "...invalid due to non-linear monomial " << mono << std::endl;
      // may wish to set an exact bound for a factor and repeat
      for (const Node& factor : mono)
      {
        unc_vars_factor.insert(factor);
      }
      continue;
    }
    const bool isSolveVar = mono.isVar() && (var.isNull() || var == mono);
    if (isSolveVar)
    {
      // set the variable to solve for
      b = coeff;
      var = mono;
      continue;
    }
    Trace("nl-ext-cms-debug")
        << "...invalid due to factor " << mono << std::endl;
    // cannot solve multivariate
    if (is_valid)
    {
      is_valid = false;
      // if b is non-zero, then var is also an unconstrained variable
      if (b != d_zero)
      {
        unc_vars.insert(var);
        unc_vars_factor.insert(var);
      }
    }
    // if mono is unconstrained, we may turn this equality into a substitution
    unc_vars.insert(mono);
    unc_vars_factor.insert(mono);
  }

  if (!is_valid)
  {
    // see if we can solve for a variable?
    for (const Node& uv : unc_vars)
    {
      Trace("nl-ext-cm-debug") << "check subs var : " << uv << std::endl;
      // cannot already have a bound
      if (!uv.isVar() || hasAssignment(uv))
      {
        continue;
      }
      Node slv;
      Node veqc;
      if (ArithMSum::isolate(uv, msum, veqc, slv, EQUAL) == 0)
      {
        continue;
      }
      Assert(!slv.isNull());
      // Currently do not support substitution-with-coefficients.
      // We also ensure types are correct here, which avoids substituting
      // a term of non-integer type for a variable of integer type.
      const bool usable = veqc.isNull() && !expr::hasSubterm(slv, uv)
                          && slv.getType().isSubtypeOf(uv.getType());
      if (!usable)
      {
        continue;
      }
      Trace("nl-ext-cm")
          << "check-model-subs : " << uv << " -> " << slv << std::endl;
      bool added = addSubstitution(uv, slv);
      if (added)
      {
        Trace("nl-ext-cms") << "...success, model substitution " << uv
                            << " -> " << slv << std::endl;
        d_check_model_solved[eq] = uv;
      }
      return added;
    }
    // see if we can assign a variable to a constant
    for (const Node& uvf : unc_vars_factor)
    {
      Trace("nl-ext-cm-debug") << "check set var : " << uvf << std::endl;
      // cannot already have a bound
      if (!uvf.isVar() || hasAssignment(uvf))
      {
        continue;
      }
      Node uvfv = computeConcreteModelValue(uvf);
      if (Trace.isOn("nl-ext-cm"))
      {
        Trace("nl-ext-cm") << "check-model-bound : exact : " << uvf << " = ";
        printRationalApprox("nl-ext-cm", uvfv);
        Trace("nl-ext-cm") << std::endl;
      }
      bool added = addSubstitution(uvf, uvfv);
      // recurse
      return added ? solveEqualitySimple(eq, d, lemmas) : false;
    }
    Trace("nl-ext-cms") << "...fail due to constrained invalid terms."
                        << std::endl;
    return false;
  }
  if (var.isNull() || var.getType().isInteger())
  {
    // cannot solve quadratic equations for integer variables
    Trace("nl-ext-cms") << "...fail due to variable to solve for." << std::endl;
    return false;
  }

  if (a != d_zero)
  {
    // a proper quadratic is not handled here
    return false;
  }
  // we are linear, it is simple
  if (b == d_zero)
  {
    Trace("nl-ext-cms") << "...fail due to zero a/b." << std::endl;
    Assert(false);
    return false;
  }
  NodeManager* nm = NodeManager::currentNM();
  Node val = nm->mkConst(-c.getConst<Rational>() / b.getConst<Rational>());
  if (Trace.isOn("nl-ext-cm"))
  {
    Trace("nl-ext-cm") << "check-model-bound : exact : " << var << " = ";
    printRationalApprox("nl-ext-cm", val);
    Trace("nl-ext-cm") << std::endl;
  }
  bool added = addSubstitution(var, val);
  if (added)
  {
    Trace("nl-ext-cms") << "...success, solved linear." << std::endl;
    d_check_model_solved[eq] = var;
  }
  return added;
}
558 |
|
|
559 |
5099 |
/**
 * Checks a single (already substituted and rewritten) literal against the
 * recorded approximate bounds.
 *
 * @param lit the literal to check
 * @return true only if the literal could be established
 *
 * Handled shapes:
 *  - a Boolean constant is its own answer;
 *  - an (dis)equality is split into the two GEQ literals t1 >= t2 and
 *    t2 >= t1 (negated for a disequality), which are checked recursively;
 *  - a (negated) GEQ atom is first tried with simpleCheckModelMsum; if that
 *    fails, every bounded variable x that occurs as a*x*x (+ b*x) and in no
 *    other monomial is replaced by the apex of that parabola or by one of
 *    the endpoints of its bound, and the resulting literal is checked
 *    recursively;
 *  - anything else fails.
 */
bool NlModel::simpleCheckModelLit(Node lit)
{
  Trace("nl-ext-cms") << "*** Simple check-model lit for " << lit << "..."
                      << std::endl;
  if (lit.isConst())
  {
    Trace("nl-ext-cms") << " return constant." << std::endl;
    return lit.getConst<bool>();
  }
  NodeManager* nm = NodeManager::currentNM();
  const bool negated = lit.getKind() == kind::NOT;
  const bool pol = !negated;
  Node atom = negated ? lit[0] : lit;

  if (atom.getKind() == EQUAL)
  {
    // x = a is ( x >= a ^ x <= a )
    for (unsigned side = 0; side < 2; ++side)
    {
      Node half = nm->mkNode(GEQ, atom[side], atom[1 - side]);
      if (negated)
      {
        half = half.negate();
      }
      half = Rewriter::rewrite(half);
      bool holds = simpleCheckModelLit(half);
      if (holds != pol)
      {
        // false != true -> one conjunct of equality is false, we fail
        // true != false -> one disjunct of disequality is true, we succeed
        return holds;
      }
    }
    // both checks passed and polarity is true, or both checks failed and
    // polarity is false
    return pol;
  }
  if (atom.getKind() != GEQ)
  {
    Trace("nl-ext-cms") << " failed due to unknown literal." << std::endl;
    return false;
  }
  // get the monomial sum
  std::map<Node, Node> msum;
  if (!ArithMSum::getMonomialSumLit(atom, msum))
  {
    Trace("nl-ext-cms") << " failed due to get msum." << std::endl;
    return false;
  }
  // simple interval analysis
  if (simpleCheckModelMsum(msum, pol))
  {
    return true;
  }
  // can also try reasoning about univariate quadratic equations
  Trace("nl-ext-cms-debug")
      << "* Try univariate quadratic analysis..." << std::endl;
  // monomials that are neither a variable nor the square of a variable
  std::vector<Node> otherMonomials;
  // variables occurring linearly or squared
  std::unordered_set<Node> candidates;
  // coefficient of x*x, per variable x
  std::map<Node, Node> quadCoeff;
  // coefficient of x, per variable x
  std::map<Node, Node> linCoeff;
  // get coefficients...
  for (const std::pair<const Node, Node>& entry : msum)
  {
    Node mono = entry.first;
    if (mono.isNull())
    {
      // the constant term plays no role here
      continue;
    }
    if (mono.isVar())
    {
      linCoeff[mono] = entry.second.isNull() ? d_one : entry.second;
      candidates.insert(mono);
    }
    else if (mono.getKind() == NONLINEAR_MULT && mono.getNumChildren() == 2
             && mono[0] == mono[1] && mono[0].isVar())
    {
      quadCoeff[mono[0]] = entry.second.isNull() ? d_one : entry.second;
      candidates.insert(mono[0]);
    }
    else
    {
      otherMonomials.push_back(mono);
    }
  }
  // solve the valid variables...
  Node invalid_vsum = otherMonomials.empty()
                          ? d_zero
                          : (otherMonomials.size() == 1
                                 ? otherMonomials[0]
                                 : nm->mkNode(PLUS, otherMonomials));
  // substitution to try
  Subs qsub;
  for (const Node& v : candidates)
  {
    // is it a valid variable?
    const auto range = d_check_model_bounds.find(v);
    if (expr::hasSubterm(invalid_vsum, v)
        || range == d_check_model_bounds.end())
    {
      continue;
    }
    const auto quadIt = quadCoeff.find(v);
    if (quadIt == quadCoeff.end())
    {
      // v does not occur squared
      continue;
    }
    Node a = quadIt->second;
    Assert(a.isConst());
    int asgn = a.getConst<Rational>().sgn();
    Assert(asgn != 0);
    // t is the part of the sum that mentions v: a*v*v (+ b*v)
    Node t = nm->mkNode(MULT, a, v, v);
    Node b = d_zero;
    const auto linIt = linCoeff.find(v);
    if (linIt != linCoeff.end())
    {
      b = linIt->second;
      t = nm->mkNode(PLUS, t, nm->mkNode(MULT, b, v));
    }
    t = Rewriter::rewrite(t);
    Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic "
                              << t << "..." << std::endl;
    Trace("nl-ext-cms-debug") << " a = " << a << std::endl;
    Trace("nl-ext-cms-debug") << " b = " << b << std::endl;
    // find maximal/minimal value on the interval
    Node apex = nm->mkNode(
        DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a));
    apex = Rewriter::rewrite(apex);
    Assert(apex.isConst());
    // the lower (index 0) and upper (index 1) bound of v
    Node boundn[2] = {range->second.first, range->second.second};
    // for lower, upper, whether we are greater than the apex
    bool cmp[2];
    for (unsigned r = 0; r < 2; ++r)
    {
      Node cmpn = Rewriter::rewrite(nm->mkNode(GT, boundn[r], apex));
      Assert(cmpn.isConst());
      cmp[r] = cmpn.getConst<bool>();
    }
    Trace("nl-ext-cms-debug") << " apex " << apex << std::endl;
    Trace("nl-ext-cms-debug")
        << " lower " << boundn[0] << ", cmp: " << cmp[0] << std::endl;
    Trace("nl-ext-cms-debug")
        << " upper " << boundn[1] << ", cmp: " << cmp[1] << std::endl;
    Assert(boundn[0].getConst<Rational>()
           <= boundn[1].getConst<Rational>());
    // the value chosen for v; the slot is reserved now and filled in below
    Node s;
    qsub.add(v, Node());
    if (cmp[0] == cmp[1])
    {
      // both to one side of the apex
      // we figure out which bound to use (lower or upper) based on
      // three factors:
      // (1) whether a's sign is positive,
      // (2) whether we are greater than the apex of the parabola,
      // (3) the polarity of the constraint, i.e. >= or <=.
      // there are 8 cases of these factors, which we test here.
      unsigned bindex_use = (((asgn == 1) == cmp[0]) == pol) ? 0 : 1;
      Trace("nl-ext-cms-debug")
          << " ...set to " << (bindex_use == 1 ? "upper" : "lower")
          << std::endl;
      s = boundn[bindex_use];
    }
    else
    {
      // the interval straddles the apex
      Assert(!cmp[0] && cmp[1]);
      // does the sign match the bound?
      if ((asgn == 1) == pol)
      {
        // the apex is the max/min value
        s = apex;
        Trace("nl-ext-cms-debug") << " ...set to apex." << std::endl;
      }
      else
      {
        // it is one of the endpoints, plug in and compare
        Node tcmpn[2];
        for (unsigned r = 0; r < 2; ++r)
        {
          qsub.d_subs.back() = boundn[r];
          Node ts = arithSubstitute(t, qsub);
          tcmpn[r] = Rewriter::rewrite(ts);
        }
        Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]);
        Trace("nl-ext-cms-debug")
            << " ...both sides of apex, compare " << tcmp << std::endl;
        tcmp = Rewriter::rewrite(tcmp);
        Assert(tcmp.isConst());
        unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 1 : 0;
        Trace("nl-ext-cms-debug")
            << " ...set to " << (bindex_use == 1 ? "upper" : "lower")
            << std::endl;
        s = boundn[bindex_use];
      }
    }
    Assert(!s.isNull());
    qsub.d_subs.back() = s;
    Trace("nl-ext-cms") << "* set bound based on quadratic : " << v
                        << " -> " << s << std::endl;
  }
  if (qsub.empty())
  {
    // nothing to try
    return false;
  }
  Node slit = arithSubstitute(lit, qsub);
  slit = Rewriter::rewrite(slit);
  return simpleCheckModelLit(slit);
}
763 |
|
|
764 |
1728 |
bool NlModel::simpleCheckModelMsum(const std::map<Node, Node>& msum, bool pol) |
765 |
|
{ |
766 |
1728 |
Trace("nl-ext-cms-debug") << "* Try simple interval analysis..." << std::endl; |
767 |
1728 |
NodeManager* nm = NodeManager::currentNM(); |
768 |
|
// map from transcendental functions to whether they were set to lower |
769 |
|
// bound |
770 |
1728 |
bool simpleSuccess = true; |
771 |
3456 |
std::map<Node, bool> set_bound; |
772 |
3456 |
std::vector<Node> sum_bound; |
773 |
5229 |
for (const std::pair<const Node, Node>& m : msum) |
774 |
|
{ |
775 |
7002 |
Node v = m.first; |
776 |
3501 |
if (v.isNull()) |
777 |
|
{ |
778 |
1519 |
sum_bound.push_back(m.second.isNull() ? d_one : m.second); |
779 |
|
} |
780 |
|
else |
781 |
|
{ |
782 |
1982 |
Trace("nl-ext-cms-debug") << "- monomial : " << v << std::endl; |
783 |
|
// --- whether we should set a lower bound for this monomial |
784 |
|
bool set_lower = |
785 |
1982 |
(m.second.isNull() || m.second.getConst<Rational>().sgn() == 1) |
786 |
1982 |
== pol; |
787 |
3964 |
Trace("nl-ext-cms-debug") |
788 |
1982 |
<< "set bound to " << (set_lower ? "lower" : "upper") << std::endl; |
789 |
|
|
790 |
|
// --- Collect variables and factors in v |
791 |
3964 |
std::vector<Node> vars; |
792 |
3964 |
std::vector<unsigned> factors; |
793 |
1982 |
if (v.getKind() == NONLINEAR_MULT) |
794 |
|
{ |
795 |
|
unsigned last_start = 0; |
796 |
|
for (unsigned i = 0, nchildren = v.getNumChildren(); i < nchildren; i++) |
797 |
|
{ |
798 |
|
// are we at the end? |
799 |
|
if (i + 1 == nchildren || v[i + 1] != v[i]) |
800 |
|
{ |
801 |
|
unsigned vfact = 1 + (i - last_start); |
802 |
|
last_start = (i + 1); |
803 |
|
vars.push_back(v[i]); |
804 |
|
factors.push_back(vfact); |
805 |
|
} |
806 |
|
} |
807 |
|
} |
808 |
|
else |
809 |
|
{ |
810 |
1982 |
vars.push_back(v); |
811 |
1982 |
factors.push_back(1); |
812 |
|
} |
813 |
|
|
814 |
|
// --- Get the lower and upper bounds and sign information. |
815 |
|
// Whether we have an (odd) number of negative factors in vars, apart |
816 |
|
// from the variable at choose_index. |
817 |
1982 |
bool has_neg_factor = false; |
818 |
1982 |
int choose_index = -1; |
819 |
3964 |
std::vector<Node> ls; |
820 |
3964 |
std::vector<Node> us; |
821 |
|
// the relevant sign information for variables with odd exponents: |
822 |
|
// 1: both signs of the interval of this variable are positive, |
823 |
|
// -1: both signs of the interval of this variable are negative. |
824 |
3964 |
std::vector<int> signs; |
825 |
1982 |
Trace("nl-ext-cms-debug") << "get sign information..." << std::endl; |
826 |
3964 |
for (unsigned i = 0, size = vars.size(); i < size; i++) |
827 |
|
{ |
828 |
3964 |
Node vc = vars[i]; |
829 |
1982 |
unsigned vcfact = factors[i]; |
830 |
1982 |
if (Trace.isOn("nl-ext-cms-debug")) |
831 |
|
{ |
832 |
|
Trace("nl-ext-cms-debug") << "-- " << vc; |
833 |
|
if (vcfact > 1) |
834 |
|
{ |
835 |
|
Trace("nl-ext-cms-debug") << "^" << vcfact; |
836 |
|
} |
837 |
|
Trace("nl-ext-cms-debug") << " "; |
838 |
|
} |
839 |
|
std::map<Node, std::pair<Node, Node>>::iterator bit = |
840 |
1982 |
d_check_model_bounds.find(vc); |
841 |
|
// if there is a model bound for this term |
842 |
1982 |
if (bit != d_check_model_bounds.end()) |
843 |
|
{ |
844 |
3964 |
Node l = bit->second.first; |
845 |
3964 |
Node u = bit->second.second; |
846 |
1982 |
ls.push_back(l); |
847 |
1982 |
us.push_back(u); |
848 |
1982 |
int vsign = 0; |
849 |
1982 |
if (vcfact % 2 == 1) |
850 |
|
{ |
851 |
1982 |
vsign = 1; |
852 |
1982 |
int lsgn = l.getConst<Rational>().sgn(); |
853 |
1982 |
int usgn = u.getConst<Rational>().sgn(); |
854 |
3964 |
Trace("nl-ext-cms-debug") |
855 |
1982 |
<< "bound_sign(" << lsgn << "," << usgn << ") "; |
856 |
1982 |
if (lsgn == -1) |
857 |
|
{ |
858 |
336 |
if (usgn < 1) |
859 |
|
{ |
860 |
|
// must have a negative factor |
861 |
336 |
has_neg_factor = !has_neg_factor; |
862 |
336 |
vsign = -1; |
863 |
|
} |
864 |
|
else if (choose_index == -1) |
865 |
|
{ |
866 |
|
// set the choose index to this |
867 |
|
choose_index = i; |
868 |
|
vsign = 0; |
869 |
|
} |
870 |
|
else |
871 |
|
{ |
872 |
|
// ambiguous, can't determine the bound |
873 |
|
Trace("nl-ext-cms") |
874 |
|
<< " failed due to ambiguious monomial." << std::endl; |
875 |
|
return false; |
876 |
|
} |
877 |
|
} |
878 |
|
} |
879 |
1982 |
Trace("nl-ext-cms-debug") << " -> " << vsign << std::endl; |
880 |
1982 |
signs.push_back(vsign); |
881 |
|
} |
882 |
|
else |
883 |
|
{ |
884 |
|
Assert(d_check_model_witnesses.find(vc) |
885 |
|
== d_check_model_witnesses.end()) |
886 |
|
<< "No variable should be assigned a witness term if we get " |
887 |
|
"here. " |
888 |
|
<< vc << " is, though." << std::endl; |
889 |
|
Trace("nl-ext-cms-debug") << std::endl; |
890 |
|
Trace("nl-ext-cms") |
891 |
|
<< " failed due to unknown bound for " << vc << std::endl; |
892 |
|
// should either assign a model bound or eliminate the variable |
893 |
|
// via substitution |
894 |
|
Assert(false); |
895 |
|
return false; |
896 |
|
} |
897 |
|
} |
898 |
|
// whether we will try to minimize/maximize (-1/1) the absolute value |
899 |
1982 |
int setAbs = (set_lower == has_neg_factor) ? 1 : -1; |
900 |
3964 |
Trace("nl-ext-cms-debug") |
901 |
1982 |
<< "set absolute value to " << (setAbs == 1 ? "maximal" : "minimal") |
902 |
1982 |
<< std::endl; |
903 |
|
|
904 |
3964 |
std::vector<Node> vbs; |
905 |
1982 |
Trace("nl-ext-cms-debug") << "set bounds..." << std::endl; |
906 |
3964 |
for (unsigned i = 0, size = vars.size(); i < size; i++) |
907 |
|
{ |
908 |
3964 |
Node vc = vars[i]; |
909 |
1982 |
unsigned vcfact = factors[i]; |
910 |
3964 |
Node l = ls[i]; |
911 |
3964 |
Node u = us[i]; |
912 |
|
bool vc_set_lower; |
913 |
1982 |
int vcsign = signs[i]; |
914 |
3964 |
Trace("nl-ext-cms-debug") |
915 |
1982 |
<< "Bounds for " << vc << " : " << l << ", " << u |
916 |
1982 |
<< ", sign : " << vcsign << ", factor : " << vcfact << std::endl; |
917 |
1982 |
if (l == u) |
918 |
|
{ |
919 |
|
// by convention, always say it is lower if they are the same |
920 |
|
vc_set_lower = true; |
921 |
|
Trace("nl-ext-cms-debug") |
922 |
|
<< "..." << vc << " equal bound, set to lower" << std::endl; |
923 |
|
} |
924 |
|
else |
925 |
|
{ |
926 |
1982 |
if (vcfact % 2 == 0) |
927 |
|
{ |
928 |
|
// minimize or maximize its absolute value |
929 |
|
Rational la = l.getConst<Rational>().abs(); |
930 |
|
Rational ua = u.getConst<Rational>().abs(); |
931 |
|
if (la == ua) |
932 |
|
{ |
933 |
|
// by convention, always say it is lower if abs are the same |
934 |
|
vc_set_lower = true; |
935 |
|
Trace("nl-ext-cms-debug") |
936 |
|
<< "..." << vc << " equal abs, set to lower" << std::endl; |
937 |
|
} |
938 |
|
else |
939 |
|
{ |
940 |
|
vc_set_lower = (la > ua) == (setAbs == 1); |
941 |
|
} |
942 |
|
} |
943 |
1982 |
else if (signs[i] == 0) |
944 |
|
{ |
945 |
|
// we choose this index to match the overall set_lower |
946 |
|
vc_set_lower = set_lower; |
947 |
|
} |
948 |
|
else |
949 |
|
{ |
950 |
1982 |
vc_set_lower = (signs[i] != setAbs); |
951 |
|
} |
952 |
3964 |
Trace("nl-ext-cms-debug") |
953 |
1982 |
<< "..." << vc << " set to " << (vc_set_lower ? "lower" : "upper") |
954 |
1982 |
<< std::endl; |
955 |
|
} |
956 |
|
// check whether this is a conflicting bound |
957 |
1982 |
std::map<Node, bool>::iterator itsb = set_bound.find(vc); |
958 |
1982 |
if (itsb == set_bound.end()) |
959 |
|
{ |
960 |
1982 |
set_bound[vc] = vc_set_lower; |
961 |
|
} |
962 |
|
else if (itsb->second != vc_set_lower) |
963 |
|
{ |
964 |
|
Trace("nl-ext-cms") |
965 |
|
<< " failed due to conflicting bound for " << vc << std::endl; |
966 |
|
return false; |
967 |
|
} |
968 |
|
// must over/under approximate based on vc_set_lower, computed above |
969 |
3964 |
Node vb = vc_set_lower ? l : u; |
970 |
3964 |
for (unsigned i2 = 0; i2 < vcfact; i2++) |
971 |
|
{ |
972 |
1982 |
vbs.push_back(vb); |
973 |
|
} |
974 |
|
} |
975 |
1982 |
if (!simpleSuccess) |
976 |
|
{ |
977 |
|
break; |
978 |
|
} |
979 |
3964 |
Node vbound = vbs.size() == 1 ? vbs[0] : nm->mkNode(MULT, vbs); |
980 |
1982 |
sum_bound.push_back(ArithMSum::mkCoeffTerm(m.second, vbound)); |
981 |
|
} |
982 |
|
} |
983 |
|
// if the exact bound was computed via simple analysis above |
984 |
|
// make the bound |
985 |
3456 |
Node bound; |
986 |
1728 |
if (sum_bound.size() > 1) |
987 |
|
{ |
988 |
1613 |
bound = nm->mkNode(kind::PLUS, sum_bound); |
989 |
|
} |
990 |
115 |
else if (sum_bound.size() == 1) |
991 |
|
{ |
992 |
115 |
bound = sum_bound[0]; |
993 |
|
} |
994 |
|
else |
995 |
|
{ |
996 |
|
bound = d_zero; |
997 |
|
} |
998 |
|
// make the comparison |
999 |
3456 |
Node comp = nm->mkNode(kind::GEQ, bound, d_zero); |
1000 |
1728 |
if (!pol) |
1001 |
|
{ |
1002 |
1071 |
comp = comp.negate(); |
1003 |
|
} |
1004 |
1728 |
Trace("nl-ext-cms") << " comparison is : " << comp << std::endl; |
1005 |
1728 |
comp = Rewriter::rewrite(comp); |
1006 |
1728 |
Assert(comp.isConst()); |
1007 |
1728 |
Trace("nl-ext-cms") << " returned : " << comp << std::endl; |
1008 |
1728 |
return comp == d_true; |
1009 |
|
} |
1010 |
|
|
1011 |
113556 |
void NlModel::printModelValue(const char* c, Node n, unsigned prec) const |
1012 |
|
{ |
1013 |
113556 |
if (Trace.isOn(c)) |
1014 |
|
{ |
1015 |
|
Trace(c) << " " << n << " -> "; |
1016 |
|
const Node& aval = d_abstractModelCache.at(n); |
1017 |
|
if (aval.isConst()) |
1018 |
|
{ |
1019 |
|
printRationalApprox(c, aval, prec); |
1020 |
|
} |
1021 |
|
else |
1022 |
|
{ |
1023 |
|
Trace(c) << "?"; |
1024 |
|
} |
1025 |
|
Trace(c) << " [actual: "; |
1026 |
|
const Node& cval = d_concreteModelCache.at(n); |
1027 |
|
if (cval.isConst()) |
1028 |
|
{ |
1029 |
|
printRationalApprox(c, cval, prec); |
1030 |
|
} |
1031 |
|
else |
1032 |
|
{ |
1033 |
|
Trace(c) << "?"; |
1034 |
|
} |
1035 |
|
Trace(c) << " ]" << std::endl; |
1036 |
|
} |
1037 |
113556 |
} |
1038 |
|
|
1039 |
582 |
void NlModel::getModelValueRepair( |
1040 |
|
std::map<Node, Node>& arithModel, |
1041 |
|
std::map<Node, std::pair<Node, Node>>& approximations, |
1042 |
|
std::map<Node, Node>& witnesses, |
1043 |
|
bool witnessToValue) |
1044 |
|
{ |
1045 |
582 |
Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl; |
1046 |
|
// If we extended the model with entries x -> 0 for unconstrained values, |
1047 |
|
// we first update the map to the extended one. |
1048 |
582 |
if (d_arithVal.size() > arithModel.size()) |
1049 |
|
{ |
1050 |
6 |
arithModel = d_arithVal; |
1051 |
|
} |
1052 |
|
// Record the approximations we used. This code calls the |
1053 |
|
// recordApproximation method of the model, which overrides the model |
1054 |
|
// values for variables that we solved for, using techniques specific to |
1055 |
|
// this class. |
1056 |
582 |
NodeManager* nm = NodeManager::currentNM(); |
1057 |
21 |
for (const std::pair<const Node, std::pair<Node, Node>>& cb : |
1058 |
582 |
d_check_model_bounds) |
1059 |
|
{ |
1060 |
42 |
Node l = cb.second.first; |
1061 |
42 |
Node u = cb.second.second; |
1062 |
42 |
Node pred; |
1063 |
42 |
Node v = cb.first; |
1064 |
21 |
if (l != u) |
1065 |
|
{ |
1066 |
21 |
pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v)); |
1067 |
21 |
Trace("nl-model") << v << " approximated as " << pred << std::endl; |
1068 |
42 |
Node witness; |
1069 |
21 |
if (witnessToValue) |
1070 |
|
{ |
1071 |
|
// witness is the midpoint |
1072 |
|
witness = nm->mkNode( |
1073 |
|
MULT, nm->mkConst(Rational(1, 2)), nm->mkNode(PLUS, l, u)); |
1074 |
|
witness = Rewriter::rewrite(witness); |
1075 |
|
Trace("nl-model") << v << " witness is " << witness << std::endl; |
1076 |
|
} |
1077 |
21 |
approximations[v] = std::pair<Node, Node>(pred, witness); |
1078 |
|
} |
1079 |
|
else |
1080 |
|
{ |
1081 |
|
// overwrite |
1082 |
|
arithModel[v] = l; |
1083 |
|
Trace("nl-model") << v << " exact approximation is " << l << std::endl; |
1084 |
|
} |
1085 |
|
} |
1086 |
591 |
for (const auto& vw : d_check_model_witnesses) |
1087 |
|
{ |
1088 |
9 |
Trace("nl-model") << vw.first << " witness is " << vw.second << std::endl; |
1089 |
9 |
witnesses.emplace(vw.first, vw.second); |
1090 |
|
} |
1091 |
|
// Also record the exact values we used. An exact value can be seen as a |
1092 |
|
// special kind approximation of the form (witness x. x = exact_value). |
1093 |
|
// Notice that the above term gets rewritten such that the choice function |
1094 |
|
// is eliminated. |
1095 |
853 |
for (size_t i = 0; i < d_substitutions.size(); ++i) |
1096 |
|
{ |
1097 |
|
// overwrite |
1098 |
271 |
arithModel[d_substitutions.d_vars[i]] = d_substitutions.d_subs[i]; |
1099 |
542 |
Trace("nl-model") << d_substitutions.d_vars[i] << " solved is " |
1100 |
271 |
<< d_substitutions.d_subs[i] << std::endl; |
1101 |
|
} |
1102 |
|
|
1103 |
|
// multiplication terms should not be given values; their values are |
1104 |
|
// implied by the monomials that they consist of |
1105 |
1164 |
std::vector<Node> amErase; |
1106 |
12692 |
for (const std::pair<const Node, Node>& am : arithModel) |
1107 |
|
{ |
1108 |
12110 |
if (am.first.getKind() == NONLINEAR_MULT) |
1109 |
|
{ |
1110 |
2105 |
amErase.push_back(am.first); |
1111 |
|
} |
1112 |
|
} |
1113 |
2687 |
for (const Node& ae : amErase) |
1114 |
|
{ |
1115 |
2105 |
arithModel.erase(ae); |
1116 |
|
} |
1117 |
582 |
} |
1118 |
|
|
1119 |
61649 |
// Returns the value of `n`: `n` itself if it is a constant, otherwise the
// value stored for it in d_arithVal (which is asserted to be a constant).
// If `n` has no entry, it is assigned zero in d_arithVal and zero is
// returned.
Node NlModel::getValueInternal(TNode n)
{
  if (n.isConst())
  {
    return n;
  }
  const auto entry = d_arithVal.find(n);
  if (entry == d_arithVal.end())
  {
    // It is unconstrained in the model, so we return 0. The choice is also
    // stored in the mapping from the linear solver, which ensures that if the
    // nonlinear solver assumes n = 0, this assumption is recorded in the
    // overall model.
    d_arithVal[n] = d_zero;
    return d_zero;
  }
  AlwaysAssert(entry->second.isConst());
  return entry->second;
}
1137 |
|
|
1138 |
859 |
// Returns true if `v` has a recorded check-model bound, a recorded
// check-model witness, or is contained in the current substitutions.
// The three sources are consulted in that order, stopping at the first hit.
bool NlModel::hasAssignment(Node v) const
{
  return d_check_model_bounds.find(v) != d_check_model_bounds.end()
         || d_check_model_witnesses.find(v) != d_check_model_witnesses.end()
         || d_substitutions.contains(v);
}
1150 |
|
|
1151 |
355216 |
// Looks up `v` in d_arithVal. On a hit, stores the mapped value in `val` and
// returns true; on a miss, returns false and leaves `val` untouched.
bool NlModel::hasLinearModelValue(TNode v, Node& val) const
{
  const auto entry = d_arithVal.find(v);
  if (entry == d_arithVal.end())
  {
    return false;
  }
  val = entry->second;
  return true;
}
1161 |
|
|
1162 |
|
} // namespace nl |
1163 |
|
} // namespace arith |
1164 |
|
} // namespace theory |
1165 |
31137 |
} // namespace cvc5 |