1 |
|
/****************************************************************************** |
2 |
|
* Top contributors (to current version): |
3 |
|
* Andrew Reynolds, Gereon Kremer, Tim King |
4 |
|
* |
5 |
|
* This file is part of the cvc5 project. |
6 |
|
* |
7 |
|
* Copyright (c) 2009-2021 by the authors listed in the file AUTHORS |
8 |
|
* in the top-level source directory and their institutional affiliations. |
9 |
|
* All rights reserved. See the file COPYING in the top-level source |
10 |
|
* directory for licensing information. |
11 |
|
* **************************************************************************** |
12 |
|
* |
13 |
|
* Model object for the non-linear extension class. |
14 |
|
*/ |
15 |
|
|
16 |
|
#include "theory/arith/nl/nl_model.h" |
17 |
|
|
18 |
|
#include "expr/node_algorithm.h" |
19 |
|
#include "options/arith_options.h" |
20 |
|
#include "options/smt_options.h" |
21 |
|
#include "options/theory_options.h" |
22 |
|
#include "theory/arith/arith_msum.h" |
23 |
|
#include "theory/arith/arith_utilities.h" |
24 |
|
#include "theory/arith/nl/nl_lemma_utils.h" |
25 |
|
#include "theory/theory_model.h" |
26 |
|
#include "theory/rewriter.h" |
27 |
|
|
28 |
|
using namespace cvc5::kind; |
29 |
|
|
30 |
|
namespace cvc5 { |
31 |
|
namespace theory { |
32 |
|
namespace arith { |
33 |
|
namespace nl { |
34 |
|
|
35 |
5150 |
NlModel::NlModel(context::Context* c) : d_used_approx(false) |
36 |
|
{ |
37 |
5150 |
d_true = NodeManager::currentNM()->mkConst(true); |
38 |
5150 |
d_false = NodeManager::currentNM()->mkConst(false); |
39 |
5150 |
d_zero = NodeManager::currentNM()->mkConst(Rational(0)); |
40 |
5150 |
d_one = NodeManager::currentNM()->mkConst(Rational(1)); |
41 |
5150 |
d_two = NodeManager::currentNM()->mkConst(Rational(2)); |
42 |
5150 |
} |
43 |
|
|
44 |
6396 |
NlModel::~NlModel() {} |
45 |
|
|
46 |
2987 |
void NlModel::reset(TheoryModel* m, std::map<Node, Node>& arithModel) |
47 |
|
{ |
48 |
2987 |
d_model = m; |
49 |
2987 |
d_mv[0].clear(); |
50 |
2987 |
d_mv[1].clear(); |
51 |
2987 |
d_arithVal.clear(); |
52 |
|
// process arithModel |
53 |
2987 |
std::map<Node, Node>::iterator it; |
54 |
51868 |
for (const std::pair<const Node, Node>& m2 : arithModel) |
55 |
|
{ |
56 |
48881 |
d_arithVal[m2.first] = m2.second; |
57 |
|
} |
58 |
2987 |
} |
59 |
|
|
60 |
3003 |
void NlModel::resetCheck() |
61 |
|
{ |
62 |
3003 |
d_used_approx = false; |
63 |
3003 |
d_check_model_solved.clear(); |
64 |
3003 |
d_check_model_bounds.clear(); |
65 |
3003 |
d_check_model_witnesses.clear(); |
66 |
3003 |
d_check_model_vars.clear(); |
67 |
3003 |
d_check_model_subs.clear(); |
68 |
3003 |
} |
69 |
|
|
70 |
313422 |
Node NlModel::computeConcreteModelValue(Node n) |
71 |
|
{ |
72 |
313422 |
return computeModelValue(n, true); |
73 |
|
} |
74 |
|
|
75 |
194436 |
Node NlModel::computeAbstractModelValue(Node n) |
76 |
|
{ |
77 |
194436 |
return computeModelValue(n, false); |
78 |
|
} |
79 |
|
|
80 |
1572796 |
Node NlModel::computeModelValue(Node n, bool isConcrete) |
81 |
|
{ |
82 |
1572796 |
unsigned index = isConcrete ? 0 : 1; |
83 |
1572796 |
std::map<Node, Node>::iterator it = d_mv[index].find(n); |
84 |
1572796 |
if (it != d_mv[index].end()) |
85 |
|
{ |
86 |
979304 |
return it->second; |
87 |
|
} |
88 |
1186984 |
Trace("nl-ext-mv-debug") << "computeModelValue " << n << ", index=" << index |
89 |
593492 |
<< std::endl; |
90 |
1186984 |
Node ret; |
91 |
593492 |
Kind nk = n.getKind(); |
92 |
593492 |
if (n.isConst()) |
93 |
|
{ |
94 |
30611 |
ret = n; |
95 |
|
} |
96 |
562881 |
else if (!isConcrete && hasTerm(n)) |
97 |
|
{ |
98 |
|
// use model value for abstraction |
99 |
39371 |
ret = getRepresentative(n); |
100 |
|
} |
101 |
523510 |
else if (n.getNumChildren() == 0) |
102 |
|
{ |
103 |
|
// we are interested in the exact value of PI, which cannot be computed. |
104 |
|
// hence, we return PI itself when asked for the concrete value. |
105 |
22075 |
if (nk == PI) |
106 |
|
{ |
107 |
484 |
ret = n; |
108 |
|
} |
109 |
|
else |
110 |
|
{ |
111 |
21591 |
ret = getValueInternal(n); |
112 |
|
} |
113 |
|
} |
114 |
|
else |
115 |
|
{ |
116 |
|
// otherwise, compute true value |
117 |
501435 |
TheoryId ctid = theory::kindToTheoryId(nk); |
118 |
501435 |
if (ctid != THEORY_ARITH && ctid != THEORY_BOOL && ctid != THEORY_BUILTIN) |
119 |
|
{ |
120 |
|
// we directly look up terms not belonging to arithmetic |
121 |
6253 |
ret = getValueInternal(n); |
122 |
|
} |
123 |
|
else |
124 |
|
{ |
125 |
990364 |
std::vector<Node> children; |
126 |
495182 |
if (n.getMetaKind() == metakind::PARAMETERIZED) |
127 |
|
{ |
128 |
205 |
children.push_back(n.getOperator()); |
129 |
|
} |
130 |
1404846 |
for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++) |
131 |
|
{ |
132 |
1819328 |
Node mc = computeModelValue(n[i], isConcrete); |
133 |
909664 |
children.push_back(mc); |
134 |
|
} |
135 |
495182 |
ret = NodeManager::currentNM()->mkNode(nk, children); |
136 |
495182 |
ret = Rewriter::rewrite(ret); |
137 |
|
} |
138 |
|
} |
139 |
1186984 |
Trace("nl-ext-mv-debug") << "computed " << (index == 0 ? "M" : "M_A") << "[" |
140 |
593492 |
<< n << "] = " << ret << std::endl; |
141 |
593492 |
d_mv[index][n] = ret; |
142 |
593492 |
return ret; |
143 |
|
} |
144 |
|
|
145 |
121473 |
bool NlModel::hasTerm(Node n) const |
146 |
|
{ |
147 |
121473 |
return d_arithVal.find(n) != d_arithVal.end(); |
148 |
|
} |
149 |
|
|
150 |
39371 |
Node NlModel::getRepresentative(Node n) const |
151 |
|
{ |
152 |
39371 |
if (n.isConst()) |
153 |
|
{ |
154 |
|
return n; |
155 |
|
} |
156 |
39371 |
std::map<Node, Node>::const_iterator it = d_arithVal.find(n); |
157 |
39371 |
if (it != d_arithVal.end()) |
158 |
|
{ |
159 |
39371 |
AlwaysAssert(it->second.isConst()); |
160 |
39371 |
return it->second; |
161 |
|
} |
162 |
|
return d_model->getRepresentative(n); |
163 |
|
} |
164 |
|
|
165 |
27844 |
Node NlModel::getValueInternal(Node n) |
166 |
|
{ |
167 |
27844 |
if (n.isConst()) |
168 |
|
{ |
169 |
|
return n; |
170 |
|
} |
171 |
27844 |
std::map<Node, Node>::const_iterator it = d_arithVal.find(n); |
172 |
27844 |
if (it != d_arithVal.end()) |
173 |
|
{ |
174 |
27812 |
AlwaysAssert(it->second.isConst()); |
175 |
27812 |
return it->second; |
176 |
|
} |
177 |
|
// It is unconstrained in the model, return 0. We additionally add it |
178 |
|
// to mapping from the linear solver. This ensures that if the nonlinear |
179 |
|
// solver assumes that n = 0, then this assumption is recorded in the overall |
180 |
|
// model. |
181 |
32 |
d_arithVal[n] = d_zero; |
182 |
32 |
return d_zero; |
183 |
|
} |
184 |
|
|
185 |
72441 |
int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute) |
186 |
|
{ |
187 |
144882 |
Node ci = computeModelValue(i, isConcrete); |
188 |
144882 |
Node cj = computeModelValue(j, isConcrete); |
189 |
72441 |
if (ci.isConst()) |
190 |
|
{ |
191 |
72441 |
if (cj.isConst()) |
192 |
|
{ |
193 |
72441 |
return compareValue(ci, cj, isAbsolute); |
194 |
|
} |
195 |
|
return 1; |
196 |
|
} |
197 |
|
return cj.isConst() ? -1 : 0; |
198 |
|
} |
199 |
|
|
200 |
82833 |
int NlModel::compareValue(Node i, Node j, bool isAbsolute) const |
201 |
|
{ |
202 |
82833 |
Assert(i.isConst() && j.isConst()); |
203 |
|
int ret; |
204 |
82833 |
if (i == j) |
205 |
|
{ |
206 |
11334 |
ret = 0; |
207 |
|
} |
208 |
71499 |
else if (!isAbsolute) |
209 |
|
{ |
210 |
5886 |
ret = i.getConst<Rational>() < j.getConst<Rational>() ? 1 : -1; |
211 |
|
} |
212 |
|
else |
213 |
|
{ |
214 |
196839 |
ret = (i.getConst<Rational>().abs() == j.getConst<Rational>().abs() |
215 |
127345 |
? 0 |
216 |
250809 |
: (i.getConst<Rational>().abs() < j.getConst<Rational>().abs() |
217 |
185196 |
? 1 |
218 |
|
: -1)); |
219 |
|
} |
220 |
82833 |
return ret; |
221 |
|
} |
222 |
|
|
223 |
265 |
bool NlModel::checkModel(const std::vector<Node>& assertions, |
224 |
|
unsigned d, |
225 |
|
std::vector<NlLemma>& lemmas) |
226 |
|
{ |
227 |
265 |
Trace("nl-ext-cm-debug") << " solve for equalities..." << std::endl; |
228 |
3652 |
for (const Node& atom : assertions) |
229 |
|
{ |
230 |
|
// see if it corresponds to a univariate polynomial equation of degree two |
231 |
3387 |
if (atom.getKind() == EQUAL) |
232 |
|
{ |
233 |
549 |
if (!solveEqualitySimple(atom, d, lemmas)) |
234 |
|
{ |
235 |
|
// no chance we will satisfy this equality |
236 |
318 |
Trace("nl-ext-cm") << "...check-model : failed to solve equality : " |
237 |
159 |
<< atom << std::endl; |
238 |
|
} |
239 |
|
} |
240 |
|
} |
241 |
|
|
242 |
|
// all remaining variables are constrained to their exact model values |
243 |
530 |
Trace("nl-ext-cm-debug") << " set exact bounds for remaining variables..." |
244 |
265 |
<< std::endl; |
245 |
530 |
std::unordered_set<TNode> visited; |
246 |
530 |
std::vector<TNode> visit; |
247 |
530 |
TNode cur; |
248 |
3652 |
for (const Node& a : assertions) |
249 |
|
{ |
250 |
3387 |
visit.push_back(a); |
251 |
13625 |
do |
252 |
|
{ |
253 |
17012 |
cur = visit.back(); |
254 |
17012 |
visit.pop_back(); |
255 |
17012 |
if (visited.find(cur) == visited.end()) |
256 |
|
{ |
257 |
10241 |
visited.insert(cur); |
258 |
10241 |
if (cur.getType().isReal() && !cur.isConst()) |
259 |
|
{ |
260 |
3068 |
Kind k = cur.getKind(); |
261 |
5212 |
if (k != MULT && k != PLUS && k != NONLINEAR_MULT |
262 |
3980 |
&& !isTranscendentalKind(k)) |
263 |
|
{ |
264 |
|
// if we have not set an approximate bound for it |
265 |
618 |
if (!hasCheckModelAssignment(cur)) |
266 |
|
{ |
267 |
|
// set its exact model value in the substitution |
268 |
634 |
Node curv = computeConcreteModelValue(cur); |
269 |
317 |
if (Trace.isOn("nl-ext-cm")) |
270 |
|
{ |
271 |
|
Trace("nl-ext-cm") |
272 |
|
<< "check-model-bound : exact : " << cur << " = "; |
273 |
|
printRationalApprox("nl-ext-cm", curv); |
274 |
|
Trace("nl-ext-cm") << std::endl; |
275 |
|
} |
276 |
317 |
bool ret = addCheckModelSubstitution(cur, curv); |
277 |
317 |
AlwaysAssert(ret); |
278 |
|
} |
279 |
|
} |
280 |
|
} |
281 |
23866 |
for (const Node& cn : cur) |
282 |
|
{ |
283 |
13625 |
visit.push_back(cn); |
284 |
|
} |
285 |
|
} |
286 |
17012 |
} while (!visit.empty()); |
287 |
|
} |
288 |
|
|
289 |
265 |
Trace("nl-ext-cm-debug") << " check assertions..." << std::endl; |
290 |
530 |
std::vector<Node> check_assertions; |
291 |
3652 |
for (const Node& a : assertions) |
292 |
|
{ |
293 |
3387 |
if (d_check_model_solved.find(a) == d_check_model_solved.end()) |
294 |
|
{ |
295 |
5994 |
Node av = a; |
296 |
|
// apply the substitution to a |
297 |
2997 |
if (!d_check_model_vars.empty()) |
298 |
|
{ |
299 |
2594 |
av = arithSubstitute(av, d_check_model_vars, d_check_model_subs); |
300 |
2594 |
av = Rewriter::rewrite(av); |
301 |
|
} |
302 |
|
// simple check literal |
303 |
2997 |
if (!simpleCheckModelLit(av)) |
304 |
|
{ |
305 |
778 |
Trace("nl-ext-cm") << "...check-model : assertion failed : " << a |
306 |
389 |
<< std::endl; |
307 |
389 |
check_assertions.push_back(av); |
308 |
778 |
Trace("nl-ext-cm-debug") |
309 |
389 |
<< "...check-model : failed assertion, value : " << av << std::endl; |
310 |
|
} |
311 |
|
} |
312 |
|
} |
313 |
|
|
314 |
265 |
if (!check_assertions.empty()) |
315 |
|
{ |
316 |
199 |
Trace("nl-ext-cm") << "...simple check failed." << std::endl; |
317 |
|
// TODO (#1450) check model for general case |
318 |
199 |
return false; |
319 |
|
} |
320 |
66 |
Trace("nl-ext-cm") << "...simple check succeeded!" << std::endl; |
321 |
66 |
return true; |
322 |
|
} |
323 |
|
|
324 |
632 |
bool NlModel::addCheckModelSubstitution(TNode v, TNode s) |
325 |
|
{ |
326 |
|
// should not substitute the same variable twice |
327 |
1264 |
Trace("nl-ext-model") << "* check model substitution : " << v << " -> " << s |
328 |
632 |
<< std::endl; |
329 |
|
// should not set exact bound more than once |
330 |
1896 |
if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v) |
331 |
1896 |
!= d_check_model_vars.end()) |
332 |
|
{ |
333 |
|
Trace("nl-ext-model") << "...ERROR: already has value." << std::endl; |
334 |
|
// this should never happen since substitutions should be applied eagerly |
335 |
|
Assert(false); |
336 |
|
return false; |
337 |
|
} |
338 |
|
// if we previously had an approximate bound, the exact bound should be in its |
339 |
|
// range |
340 |
|
std::map<Node, std::pair<Node, Node>>::iterator itb = |
341 |
632 |
d_check_model_bounds.find(v); |
342 |
632 |
if (itb != d_check_model_bounds.end()) |
343 |
|
{ |
344 |
|
if (s.getConst<Rational>() >= itb->second.first.getConst<Rational>() |
345 |
|
|| s.getConst<Rational>() <= itb->second.second.getConst<Rational>()) |
346 |
|
{ |
347 |
|
Trace("nl-ext-model") |
348 |
|
<< "...ERROR: already has bound which is out of range." << std::endl; |
349 |
|
return false; |
350 |
|
} |
351 |
|
} |
352 |
632 |
Assert(d_check_model_witnesses.find(v) == d_check_model_witnesses.end()) |
353 |
|
<< "We tried to add a substitution where we already had a witness term." |
354 |
|
<< std::endl; |
355 |
1264 |
std::vector<Node> varsTmp; |
356 |
632 |
varsTmp.push_back(v); |
357 |
1264 |
std::vector<Node> subsTmp; |
358 |
632 |
subsTmp.push_back(s); |
359 |
1899 |
for (unsigned i = 0, size = d_check_model_subs.size(); i < size; i++) |
360 |
|
{ |
361 |
2534 |
Node ms = d_check_model_subs[i]; |
362 |
2534 |
Node mss = arithSubstitute(ms, varsTmp, subsTmp); |
363 |
1267 |
if (mss != ms) |
364 |
|
{ |
365 |
43 |
mss = Rewriter::rewrite(mss); |
366 |
|
} |
367 |
1267 |
d_check_model_subs[i] = mss; |
368 |
|
} |
369 |
632 |
d_check_model_vars.push_back(v); |
370 |
632 |
d_check_model_subs.push_back(s); |
371 |
632 |
return true; |
372 |
|
} |
373 |
|
|
374 |
398 |
bool NlModel::addCheckModelBound(TNode v, TNode l, TNode u) |
375 |
|
{ |
376 |
796 |
Trace("nl-ext-model") << "* check model bound : " << v << " -> [" << l << " " |
377 |
398 |
<< u << "]" << std::endl; |
378 |
398 |
if (l == u) |
379 |
|
{ |
380 |
|
// bound is exact, can add as substitution |
381 |
|
return addCheckModelSubstitution(v, l); |
382 |
|
} |
383 |
|
// should not set a bound for a value that is exact |
384 |
1194 |
if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v) |
385 |
1194 |
!= d_check_model_vars.end()) |
386 |
|
{ |
387 |
|
Trace("nl-ext-model") |
388 |
|
<< "...ERROR: setting bound for variable that already has exact value." |
389 |
|
<< std::endl; |
390 |
|
Assert(false); |
391 |
|
return false; |
392 |
|
} |
393 |
398 |
Assert(l.isConst()); |
394 |
398 |
Assert(u.isConst()); |
395 |
398 |
Assert(l.getConst<Rational>() <= u.getConst<Rational>()); |
396 |
398 |
d_check_model_bounds[v] = std::pair<Node, Node>(l, u); |
397 |
398 |
if (Trace.isOn("nl-ext-cm")) |
398 |
|
{ |
399 |
|
Trace("nl-ext-cm") << "check-model-bound : approximate : "; |
400 |
|
printRationalApprox("nl-ext-cm", l); |
401 |
|
Trace("nl-ext-cm") << " <= " << v << " <= "; |
402 |
|
printRationalApprox("nl-ext-cm", u); |
403 |
|
Trace("nl-ext-cm") << std::endl; |
404 |
|
} |
405 |
398 |
return true; |
406 |
|
} |
407 |
|
|
408 |
7 |
bool NlModel::addCheckModelWitness(TNode v, TNode w) |
409 |
|
{ |
410 |
14 |
Trace("nl-ext-model") << "* check model witness : " << v << " -> " << w |
411 |
7 |
<< std::endl; |
412 |
|
// should not set a witness for a value that is already set |
413 |
21 |
if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v) |
414 |
21 |
!= d_check_model_vars.end()) |
415 |
|
{ |
416 |
|
Trace("nl-ext-model") << "...ERROR: setting witness for variable that " |
417 |
|
"already has a constant value." |
418 |
|
<< std::endl; |
419 |
|
Assert(false); |
420 |
|
return false; |
421 |
|
} |
422 |
7 |
d_check_model_witnesses.emplace(v, w); |
423 |
7 |
return true; |
424 |
|
} |
425 |
|
|
426 |
784 |
bool NlModel::hasCheckModelAssignment(Node v) const |
427 |
|
{ |
428 |
784 |
if (d_check_model_bounds.find(v) != d_check_model_bounds.end()) |
429 |
|
{ |
430 |
10 |
return true; |
431 |
|
} |
432 |
774 |
if (d_check_model_witnesses.find(v) != d_check_model_witnesses.end()) |
433 |
|
{ |
434 |
|
return true; |
435 |
|
} |
436 |
1548 |
return std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v) |
437 |
2322 |
!= d_check_model_vars.end(); |
438 |
|
} |
439 |
|
|
440 |
271 |
void NlModel::setUsedApproximate() { d_used_approx = true; } |
441 |
|
|
442 |
16 |
bool NlModel::usedApproximate() const { return d_used_approx; } |
443 |
|
|
444 |
642 |
bool NlModel::solveEqualitySimple(Node eq, |
445 |
|
unsigned d, |
446 |
|
std::vector<NlLemma>& lemmas) |
447 |
|
{ |
448 |
1284 |
Node seq = eq; |
449 |
642 |
if (!d_check_model_vars.empty()) |
450 |
|
{ |
451 |
464 |
seq = arithSubstitute(eq, d_check_model_vars, d_check_model_subs); |
452 |
464 |
seq = Rewriter::rewrite(seq); |
453 |
464 |
if (seq.isConst()) |
454 |
|
{ |
455 |
182 |
if (seq.getConst<bool>()) |
456 |
|
{ |
457 |
|
// already true |
458 |
182 |
d_check_model_solved[eq] = Node::null(); |
459 |
182 |
return true; |
460 |
|
} |
461 |
|
return false; |
462 |
|
} |
463 |
|
} |
464 |
460 |
Trace("nl-ext-cms") << "simple solve equality " << seq << "..." << std::endl; |
465 |
460 |
Assert(seq.getKind() == EQUAL); |
466 |
920 |
std::map<Node, Node> msum; |
467 |
460 |
if (!ArithMSum::getMonomialSumLit(seq, msum)) |
468 |
|
{ |
469 |
|
Trace("nl-ext-cms") << "...fail, could not determine monomial sum." |
470 |
|
<< std::endl; |
471 |
|
return false; |
472 |
|
} |
473 |
460 |
bool is_valid = true; |
474 |
|
// the variable we will solve a quadratic equation for |
475 |
920 |
Node var; |
476 |
920 |
Node a = d_zero; |
477 |
920 |
Node b = d_zero; |
478 |
920 |
Node c = d_zero; |
479 |
460 |
NodeManager* nm = NodeManager::currentNM(); |
480 |
|
// the list of variables that occur as a monomial in msum, and whose value |
481 |
|
// is so far unconstrained in the model. |
482 |
920 |
std::unordered_set<Node> unc_vars; |
483 |
|
// the list of variables that occur as a factor in a monomial, and whose |
484 |
|
// value is so far unconstrained in the model. |
485 |
920 |
std::unordered_set<Node> unc_vars_factor; |
486 |
1331 |
for (std::pair<const Node, Node>& m : msum) |
487 |
|
{ |
488 |
1742 |
Node v = m.first; |
489 |
1742 |
Node coeff = m.second.isNull() ? d_one : m.second; |
490 |
871 |
if (v.isNull()) |
491 |
|
{ |
492 |
218 |
c = coeff; |
493 |
|
} |
494 |
653 |
else if (v.getKind() == NONLINEAR_MULT) |
495 |
|
{ |
496 |
398 |
if (v.getNumChildren() == 2 && v[0].isVar() && v[0] == v[1] |
497 |
306 |
&& (var.isNull() || var == v[0])) |
498 |
|
{ |
499 |
|
// may solve quadratic |
500 |
36 |
a = coeff; |
501 |
36 |
var = v[0]; |
502 |
|
} |
503 |
|
else |
504 |
|
{ |
505 |
97 |
is_valid = false; |
506 |
194 |
Trace("nl-ext-cms-debug") |
507 |
97 |
<< "...invalid due to non-linear monomial " << v << std::endl; |
508 |
|
// may wish to set an exact bound for a factor and repeat |
509 |
292 |
for (const Node& vc : v) |
510 |
|
{ |
511 |
195 |
unc_vars_factor.insert(vc); |
512 |
|
} |
513 |
|
} |
514 |
|
} |
515 |
520 |
else if (!v.isVar() || (!var.isNull() && var != v)) |
516 |
|
{ |
517 |
576 |
Trace("nl-ext-cms-debug") |
518 |
288 |
<< "...invalid due to factor " << v << std::endl; |
519 |
|
// cannot solve multivariate |
520 |
288 |
if (is_valid) |
521 |
|
{ |
522 |
206 |
is_valid = false; |
523 |
|
// if b is non-zero, then var is also an unconstrained variable |
524 |
206 |
if (b != d_zero) |
525 |
|
{ |
526 |
55 |
unc_vars.insert(var); |
527 |
55 |
unc_vars_factor.insert(var); |
528 |
|
} |
529 |
|
} |
530 |
|
// if v is unconstrained, we may turn this equality into a substitution |
531 |
288 |
unc_vars.insert(v); |
532 |
288 |
unc_vars_factor.insert(v); |
533 |
|
} |
534 |
|
else |
535 |
|
{ |
536 |
|
// set the variable to solve for |
537 |
232 |
b = coeff; |
538 |
232 |
var = v; |
539 |
|
} |
540 |
|
} |
541 |
460 |
if (!is_valid) |
542 |
|
{ |
543 |
|
// see if we can solve for a variable? |
544 |
548 |
for (const Node& uv : unc_vars) |
545 |
|
{ |
546 |
312 |
Trace("nl-ext-cm-debug") << "check subs var : " << uv << std::endl; |
547 |
|
// cannot already have a bound |
548 |
312 |
if (uv.isVar() && !hasCheckModelAssignment(uv)) |
549 |
|
{ |
550 |
63 |
Node slv; |
551 |
63 |
Node veqc; |
552 |
63 |
if (ArithMSum::isolate(uv, msum, veqc, slv, EQUAL) != 0) |
553 |
|
{ |
554 |
63 |
Assert(!slv.isNull()); |
555 |
|
// Currently do not support substitution-with-coefficients. |
556 |
|
// We also ensure types are correct here, which avoids substituting |
557 |
|
// a term of non-integer type for a variable of integer type. |
558 |
189 |
if (veqc.isNull() && !expr::hasSubterm(slv, uv) |
559 |
189 |
&& slv.getType().isSubtypeOf(uv.getType())) |
560 |
|
{ |
561 |
126 |
Trace("nl-ext-cm") |
562 |
63 |
<< "check-model-subs : " << uv << " -> " << slv << std::endl; |
563 |
63 |
bool ret = addCheckModelSubstitution(uv, slv); |
564 |
63 |
if (ret) |
565 |
|
{ |
566 |
126 |
Trace("nl-ext-cms") << "...success, model substitution " << uv |
567 |
63 |
<< " -> " << slv << std::endl; |
568 |
63 |
d_check_model_solved[eq] = uv; |
569 |
|
} |
570 |
63 |
return ret; |
571 |
|
} |
572 |
|
} |
573 |
|
} |
574 |
|
} |
575 |
|
// see if we can assign a variable to a constant |
576 |
467 |
for (const Node& uvf : unc_vars_factor) |
577 |
|
{ |
578 |
324 |
Trace("nl-ext-cm-debug") << "check set var : " << uvf << std::endl; |
579 |
|
// cannot already have a bound |
580 |
324 |
if (uvf.isVar() && !hasCheckModelAssignment(uvf)) |
581 |
|
{ |
582 |
186 |
Node uvfv = computeConcreteModelValue(uvf); |
583 |
93 |
if (Trace.isOn("nl-ext-cm")) |
584 |
|
{ |
585 |
|
Trace("nl-ext-cm") << "check-model-bound : exact : " << uvf << " = "; |
586 |
|
printRationalApprox("nl-ext-cm", uvfv); |
587 |
|
Trace("nl-ext-cm") << std::endl; |
588 |
|
} |
589 |
93 |
bool ret = addCheckModelSubstitution(uvf, uvfv); |
590 |
|
// recurse |
591 |
93 |
return ret ? solveEqualitySimple(eq, d, lemmas) : false; |
592 |
|
} |
593 |
|
} |
594 |
286 |
Trace("nl-ext-cms") << "...fail due to constrained invalid terms." |
595 |
143 |
<< std::endl; |
596 |
143 |
return false; |
597 |
|
} |
598 |
161 |
else if (var.isNull() || var.getType().isInteger()) |
599 |
|
{ |
600 |
|
// cannot solve quadratic equations for integer variables |
601 |
16 |
Trace("nl-ext-cms") << "...fail due to variable to solve for." << std::endl; |
602 |
16 |
return false; |
603 |
|
} |
604 |
|
|
605 |
|
// we are linear, it is simple |
606 |
145 |
if (a == d_zero) |
607 |
|
{ |
608 |
135 |
if (b == d_zero) |
609 |
|
{ |
610 |
|
Trace("nl-ext-cms") << "...fail due to zero a/b." << std::endl; |
611 |
|
Assert(false); |
612 |
|
return false; |
613 |
|
} |
614 |
270 |
Node val = nm->mkConst(-c.getConst<Rational>() / b.getConst<Rational>()); |
615 |
135 |
if (Trace.isOn("nl-ext-cm")) |
616 |
|
{ |
617 |
|
Trace("nl-ext-cm") << "check-model-bound : exact : " << var << " = "; |
618 |
|
printRationalApprox("nl-ext-cm", val); |
619 |
|
Trace("nl-ext-cm") << std::endl; |
620 |
|
} |
621 |
135 |
bool ret = addCheckModelSubstitution(var, val); |
622 |
135 |
if (ret) |
623 |
|
{ |
624 |
135 |
Trace("nl-ext-cms") << "...success, solved linear." << std::endl; |
625 |
135 |
d_check_model_solved[eq] = var; |
626 |
|
} |
627 |
135 |
return ret; |
628 |
|
} |
629 |
10 |
Trace("nl-ext-quad") << "Solve quadratic : " << seq << std::endl; |
630 |
10 |
Trace("nl-ext-quad") << " a : " << a << std::endl; |
631 |
10 |
Trace("nl-ext-quad") << " b : " << b << std::endl; |
632 |
10 |
Trace("nl-ext-quad") << " c : " << c << std::endl; |
633 |
20 |
Node two_a = nm->mkNode(MULT, d_two, a); |
634 |
10 |
two_a = Rewriter::rewrite(two_a); |
635 |
|
Node sqrt_val = nm->mkNode( |
636 |
20 |
MINUS, nm->mkNode(MULT, b, b), nm->mkNode(MULT, d_two, two_a, c)); |
637 |
10 |
sqrt_val = Rewriter::rewrite(sqrt_val); |
638 |
10 |
Trace("nl-ext-quad") << "Will approximate sqrt " << sqrt_val << std::endl; |
639 |
10 |
Assert(sqrt_val.isConst()); |
640 |
|
// if it is negative, then we are in conflict |
641 |
10 |
if (sqrt_val.getConst<Rational>().sgn() == -1) |
642 |
|
{ |
643 |
|
Node conf = seq.negate(); |
644 |
|
Trace("nl-ext-lemma") << "NlModel::Lemma : quadratic no root : " << conf |
645 |
|
<< std::endl; |
646 |
|
lemmas.emplace_back(InferenceId::ARITH_NL_CM_QUADRATIC_EQ, conf); |
647 |
|
Trace("nl-ext-cms") << "...fail due to negative discriminant." << std::endl; |
648 |
|
return false; |
649 |
|
} |
650 |
10 |
if (hasCheckModelAssignment(var)) |
651 |
|
{ |
652 |
|
Trace("nl-ext-cms") << "...fail due to bounds on variable to solve for." |
653 |
|
<< std::endl; |
654 |
|
// two quadratic equations for same variable, give up |
655 |
|
return false; |
656 |
|
} |
657 |
|
// approximate the square root of sqrt_val |
658 |
20 |
Node l, u; |
659 |
10 |
if (!getApproximateSqrt(sqrt_val, l, u, 15 + d)) |
660 |
|
{ |
661 |
|
Trace("nl-ext-cms") << "...fail, could not approximate sqrt." << std::endl; |
662 |
|
return false; |
663 |
|
} |
664 |
10 |
d_used_approx = true; |
665 |
20 |
Trace("nl-ext-quad") << "...got " << l << " <= sqrt(" << sqrt_val |
666 |
10 |
<< ") <= " << u << std::endl; |
667 |
20 |
Node negb = nm->mkConst(-b.getConst<Rational>()); |
668 |
20 |
Node coeffa = nm->mkConst(Rational(1) / two_a.getConst<Rational>()); |
669 |
|
// two possible bound regions |
670 |
20 |
Node bounds[2][2]; |
671 |
20 |
Node diff_bound[2]; |
672 |
20 |
Node m_var = computeConcreteModelValue(var); |
673 |
10 |
Assert(m_var.isConst()); |
674 |
30 |
for (unsigned r = 0; r < 2; r++) |
675 |
|
{ |
676 |
60 |
for (unsigned b2 = 0; b2 < 2; b2++) |
677 |
|
{ |
678 |
80 |
Node val = b2 == 0 ? l : u; |
679 |
|
// (-b +- approx_sqrt( b^2 - 4ac ))/2a |
680 |
|
Node approx = nm->mkNode( |
681 |
80 |
MULT, coeffa, nm->mkNode(r == 0 ? MINUS : PLUS, negb, val)); |
682 |
40 |
approx = Rewriter::rewrite(approx); |
683 |
40 |
bounds[r][b2] = approx; |
684 |
40 |
Assert(approx.isConst()); |
685 |
|
} |
686 |
20 |
if (bounds[r][0].getConst<Rational>() > bounds[r][1].getConst<Rational>()) |
687 |
|
{ |
688 |
|
// ensure bound is (lower, upper) |
689 |
20 |
Node tmp = bounds[r][0]; |
690 |
10 |
bounds[r][0] = bounds[r][1]; |
691 |
10 |
bounds[r][1] = tmp; |
692 |
|
} |
693 |
|
Node diff = |
694 |
|
nm->mkNode(MINUS, |
695 |
|
m_var, |
696 |
80 |
nm->mkNode(MULT, |
697 |
40 |
nm->mkConst(Rational(1) / Rational(2)), |
698 |
80 |
nm->mkNode(PLUS, bounds[r][0], bounds[r][1]))); |
699 |
20 |
if (Trace.isOn("nl-ext-cm-debug")) |
700 |
|
{ |
701 |
|
Trace("nl-ext-cm-debug") << "Bound option #" << r << " : "; |
702 |
|
printRationalApprox("nl-ext-cm-debug", bounds[r][0]); |
703 |
|
Trace("nl-ext-cm-debug") << "..."; |
704 |
|
printRationalApprox("nl-ext-cm-debug", bounds[r][1]); |
705 |
|
Trace("nl-ext-cm-debug") << std::endl; |
706 |
|
} |
707 |
20 |
diff = Rewriter::rewrite(diff); |
708 |
20 |
Assert(diff.isConst()); |
709 |
20 |
diff = nm->mkConst(diff.getConst<Rational>().abs()); |
710 |
20 |
diff_bound[r] = diff; |
711 |
20 |
if (Trace.isOn("nl-ext-cm-debug")) |
712 |
|
{ |
713 |
|
Trace("nl-ext-cm-debug") << "...diff from model value ("; |
714 |
|
printRationalApprox("nl-ext-cm-debug", m_var); |
715 |
|
Trace("nl-ext-cm-debug") << ") is "; |
716 |
|
printRationalApprox("nl-ext-cm-debug", diff_bound[r]); |
717 |
|
Trace("nl-ext-cm-debug") << std::endl; |
718 |
|
} |
719 |
|
} |
720 |
|
// take the one that var is closer to in the model |
721 |
20 |
Node cmp = nm->mkNode(GEQ, diff_bound[0], diff_bound[1]); |
722 |
10 |
cmp = Rewriter::rewrite(cmp); |
723 |
10 |
Assert(cmp.isConst()); |
724 |
10 |
unsigned r_use_index = cmp == d_true ? 1 : 0; |
725 |
10 |
if (Trace.isOn("nl-ext-cm")) |
726 |
|
{ |
727 |
|
Trace("nl-ext-cm") << "check-model-bound : approximate (sqrt) : "; |
728 |
|
printRationalApprox("nl-ext-cm", bounds[r_use_index][0]); |
729 |
|
Trace("nl-ext-cm") << " <= " << var << " <= "; |
730 |
|
printRationalApprox("nl-ext-cm", bounds[r_use_index][1]); |
731 |
|
Trace("nl-ext-cm") << std::endl; |
732 |
|
} |
733 |
|
bool ret = |
734 |
10 |
addCheckModelBound(var, bounds[r_use_index][0], bounds[r_use_index][1]); |
735 |
10 |
if (ret) |
736 |
|
{ |
737 |
10 |
d_check_model_solved[eq] = var; |
738 |
10 |
Trace("nl-ext-cms") << "...success, solved quadratic." << std::endl; |
739 |
|
} |
740 |
10 |
return ret; |
741 |
|
} |
742 |
|
|
743 |
3670 |
bool NlModel::simpleCheckModelLit(Node lit) |
744 |
|
{ |
745 |
7340 |
Trace("nl-ext-cms") << "*** Simple check-model lit for " << lit << "..." |
746 |
3670 |
<< std::endl; |
747 |
3670 |
if (lit.isConst()) |
748 |
|
{ |
749 |
1484 |
Trace("nl-ext-cms") << " return constant." << std::endl; |
750 |
1484 |
return lit.getConst<bool>(); |
751 |
|
} |
752 |
2186 |
NodeManager* nm = NodeManager::currentNM(); |
753 |
2186 |
bool pol = lit.getKind() != kind::NOT; |
754 |
4372 |
Node atom = lit.getKind() == kind::NOT ? lit[0] : lit; |
755 |
|
|
756 |
2186 |
if (atom.getKind() == EQUAL) |
757 |
|
{ |
758 |
|
// x = a is ( x >= a ^ x <= a ) |
759 |
604 |
for (unsigned i = 0; i < 2; i++) |
760 |
|
{ |
761 |
850 |
Node lit2 = nm->mkNode(GEQ, atom[i], atom[1 - i]); |
762 |
604 |
if (!pol) |
763 |
|
{ |
764 |
524 |
lit2 = lit2.negate(); |
765 |
|
} |
766 |
604 |
lit2 = Rewriter::rewrite(lit2); |
767 |
604 |
bool success = simpleCheckModelLit(lit2); |
768 |
604 |
if (success != pol) |
769 |
|
{ |
770 |
|
// false != true -> one conjunct of equality is false, we fail |
771 |
|
// true != false -> one disjunct of disequality is true, we succeed |
772 |
358 |
return success; |
773 |
|
} |
774 |
|
} |
775 |
|
// both checks passed and polarity is true, or both checks failed and |
776 |
|
// polarity is false |
777 |
|
return pol; |
778 |
|
} |
779 |
1828 |
else if (atom.getKind() != GEQ) |
780 |
|
{ |
781 |
|
Trace("nl-ext-cms") << " failed due to unknown literal." << std::endl; |
782 |
|
return false; |
783 |
|
} |
784 |
|
// get the monomial sum |
785 |
3656 |
std::map<Node, Node> msum; |
786 |
1828 |
if (!ArithMSum::getMonomialSumLit(atom, msum)) |
787 |
|
{ |
788 |
|
Trace("nl-ext-cms") << " failed due to get msum." << std::endl; |
789 |
|
return false; |
790 |
|
} |
791 |
|
// simple interval analysis |
792 |
1828 |
if (simpleCheckModelMsum(msum, pol)) |
793 |
|
{ |
794 |
1265 |
return true; |
795 |
|
} |
796 |
|
// can also try reasoning about univariate quadratic equations |
797 |
1126 |
Trace("nl-ext-cms-debug") |
798 |
563 |
<< "* Try univariate quadratic analysis..." << std::endl; |
799 |
1126 |
std::vector<Node> vs_invalid; |
800 |
1126 |
std::unordered_set<Node> vs; |
801 |
1126 |
std::map<Node, Node> v_a; |
802 |
1126 |
std::map<Node, Node> v_b; |
803 |
|
// get coefficients... |
804 |
1674 |
for (std::pair<const Node, Node>& m : msum) |
805 |
|
{ |
806 |
2222 |
Node v = m.first; |
807 |
1111 |
if (!v.isNull()) |
808 |
|
{ |
809 |
676 |
if (v.isVar()) |
810 |
|
{ |
811 |
77 |
v_b[v] = m.second.isNull() ? d_one : m.second; |
812 |
77 |
vs.insert(v); |
813 |
|
} |
814 |
1267 |
else if (v.getKind() == NONLINEAR_MULT && v.getNumChildren() == 2 |
815 |
1267 |
&& v[0] == v[1] && v[0].isVar()) |
816 |
|
{ |
817 |
69 |
v_a[v[0]] = m.second.isNull() ? d_one : m.second; |
818 |
69 |
vs.insert(v[0]); |
819 |
|
} |
820 |
|
else |
821 |
|
{ |
822 |
530 |
vs_invalid.push_back(v); |
823 |
|
} |
824 |
|
} |
825 |
|
} |
826 |
|
// solve the valid variables... |
827 |
563 |
Node invalid_vsum = vs_invalid.empty() ? d_zero |
828 |
450 |
: (vs_invalid.size() == 1 |
829 |
370 |
? vs_invalid[0] |
830 |
1946 |
: nm->mkNode(PLUS, vs_invalid)); |
831 |
|
// substitution to try |
832 |
1126 |
std::vector<Node> qvars; |
833 |
1126 |
std::vector<Node> qsubs; |
834 |
676 |
for (const Node& v : vs) |
835 |
|
{ |
836 |
|
// is it a valid variable? |
837 |
|
std::map<Node, std::pair<Node, Node>>::iterator bit = |
838 |
113 |
d_check_model_bounds.find(v); |
839 |
113 |
if (!expr::hasSubterm(invalid_vsum, v) && bit != d_check_model_bounds.end()) |
840 |
|
{ |
841 |
113 |
std::map<Node, Node>::iterator it = v_a.find(v); |
842 |
113 |
if (it != v_a.end()) |
843 |
|
{ |
844 |
138 |
Node a = it->second; |
845 |
69 |
Assert(a.isConst()); |
846 |
69 |
int asgn = a.getConst<Rational>().sgn(); |
847 |
69 |
Assert(asgn != 0); |
848 |
138 |
Node t = nm->mkNode(MULT, a, v, v); |
849 |
138 |
Node b = d_zero; |
850 |
69 |
it = v_b.find(v); |
851 |
69 |
if (it != v_b.end()) |
852 |
|
{ |
853 |
33 |
b = it->second; |
854 |
33 |
t = nm->mkNode(PLUS, t, nm->mkNode(MULT, b, v)); |
855 |
|
} |
856 |
69 |
t = Rewriter::rewrite(t); |
857 |
138 |
Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic " |
858 |
69 |
<< t << "..." << std::endl; |
859 |
69 |
Trace("nl-ext-cms-debug") << " a = " << a << std::endl; |
860 |
69 |
Trace("nl-ext-cms-debug") << " b = " << b << std::endl; |
861 |
|
// find maximal/minimal value on the interval |
862 |
|
Node apex = nm->mkNode( |
863 |
138 |
DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a)); |
864 |
69 |
apex = Rewriter::rewrite(apex); |
865 |
69 |
Assert(apex.isConst()); |
866 |
|
// for lower, upper, whether we are greater than the apex |
867 |
|
bool cmp[2]; |
868 |
138 |
Node boundn[2]; |
869 |
207 |
for (unsigned r = 0; r < 2; r++) |
870 |
|
{ |
871 |
138 |
boundn[r] = r == 0 ? bit->second.first : bit->second.second; |
872 |
276 |
Node cmpn = nm->mkNode(GT, boundn[r], apex); |
873 |
138 |
cmpn = Rewriter::rewrite(cmpn); |
874 |
138 |
Assert(cmpn.isConst()); |
875 |
138 |
cmp[r] = cmpn.getConst<bool>(); |
876 |
|
} |
877 |
69 |
Trace("nl-ext-cms-debug") << " apex " << apex << std::endl; |
878 |
138 |
Trace("nl-ext-cms-debug") |
879 |
69 |
<< " lower " << boundn[0] << ", cmp: " << cmp[0] << std::endl; |
880 |
138 |
Trace("nl-ext-cms-debug") |
881 |
69 |
<< " upper " << boundn[1] << ", cmp: " << cmp[1] << std::endl; |
882 |
69 |
Assert(boundn[0].getConst<Rational>() |
883 |
|
<= boundn[1].getConst<Rational>()); |
884 |
138 |
Node s; |
885 |
69 |
qvars.push_back(v); |
886 |
69 |
if (cmp[0] != cmp[1]) |
887 |
|
{ |
888 |
|
Assert(!cmp[0] && cmp[1]); |
889 |
|
// does the sign match the bound? |
890 |
|
if ((asgn == 1) == pol) |
891 |
|
{ |
892 |
|
// the apex is the max/min value |
893 |
|
s = apex; |
894 |
|
Trace("nl-ext-cms-debug") << " ...set to apex." << std::endl; |
895 |
|
} |
896 |
|
else |
897 |
|
{ |
898 |
|
// it is one of the endpoints, plug in and compare |
899 |
|
Node tcmpn[2]; |
900 |
|
for (unsigned r = 0; r < 2; r++) |
901 |
|
{ |
902 |
|
qsubs.push_back(boundn[r]); |
903 |
|
Node ts = arithSubstitute(t, qvars, qsubs); |
904 |
|
tcmpn[r] = Rewriter::rewrite(ts); |
905 |
|
qsubs.pop_back(); |
906 |
|
} |
907 |
|
Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]); |
908 |
|
Trace("nl-ext-cms-debug") |
909 |
|
<< " ...both sides of apex, compare " << tcmp << std::endl; |
910 |
|
tcmp = Rewriter::rewrite(tcmp); |
911 |
|
Assert(tcmp.isConst()); |
912 |
|
unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 1 : 0; |
913 |
|
Trace("nl-ext-cms-debug") |
914 |
|
<< " ...set to " << (bindex_use == 1 ? "upper" : "lower") |
915 |
|
<< std::endl; |
916 |
|
s = boundn[bindex_use]; |
917 |
|
} |
918 |
|
} |
919 |
|
else |
920 |
|
{ |
921 |
|
// both to one side of the apex |
922 |
|
// we figure out which bound to use (lower or upper) based on |
923 |
|
// three factors: |
924 |
|
// (1) whether a's sign is positive, |
925 |
|
// (2) whether we are greater than the apex of the parabola, |
926 |
|
// (3) the polarity of the constraint, i.e. >= or <=. |
927 |
|
// there are 8 cases of these factors, which we test here. |
928 |
69 |
unsigned bindex_use = (((asgn == 1) == cmp[0]) == pol) ? 0 : 1; |
929 |
138 |
Trace("nl-ext-cms-debug") |
930 |
69 |
<< " ...set to " << (bindex_use == 1 ? "upper" : "lower") |
931 |
69 |
<< std::endl; |
932 |
69 |
s = boundn[bindex_use]; |
933 |
|
} |
934 |
69 |
Assert(!s.isNull()); |
935 |
69 |
qsubs.push_back(s); |
936 |
138 |
Trace("nl-ext-cms") << "* set bound based on quadratic : " << v |
937 |
69 |
<< " -> " << s << std::endl; |
938 |
|
} |
939 |
|
} |
940 |
|
} |
941 |
563 |
if (!qvars.empty()) |
942 |
|
{ |
943 |
69 |
Assert(qvars.size() == qsubs.size()); |
944 |
138 |
Node slit = arithSubstitute(lit, qvars, qsubs); |
945 |
69 |
slit = Rewriter::rewrite(slit); |
946 |
69 |
return simpleCheckModelLit(slit); |
947 |
|
} |
948 |
494 |
return false; |
949 |
|
} |
950 |
|
|
951 |
1828 |
bool NlModel::simpleCheckModelMsum(const std::map<Node, Node>& msum, bool pol) |
952 |
|
{ |
953 |
1828 |
Trace("nl-ext-cms-debug") << "* Try simple interval analysis..." << std::endl; |
954 |
1828 |
NodeManager* nm = NodeManager::currentNM(); |
955 |
|
// map from transcendental functions to whether they were set to lower |
956 |
|
// bound |
957 |
1828 |
bool simpleSuccess = true; |
958 |
3656 |
std::map<Node, bool> set_bound; |
959 |
3656 |
std::vector<Node> sum_bound; |
960 |
5482 |
for (const std::pair<const Node, Node>& m : msum) |
961 |
|
{ |
962 |
7341 |
Node v = m.first; |
963 |
3687 |
if (v.isNull()) |
964 |
|
{ |
965 |
1564 |
sum_bound.push_back(m.second.isNull() ? d_one : m.second); |
966 |
|
} |
967 |
|
else |
968 |
|
{ |
969 |
2123 |
Trace("nl-ext-cms-debug") << "- monomial : " << v << std::endl; |
970 |
|
// --- whether we should set a lower bound for this monomial |
971 |
|
bool set_lower = |
972 |
2123 |
(m.second.isNull() || m.second.getConst<Rational>().sgn() == 1) |
973 |
2123 |
== pol; |
974 |
4246 |
Trace("nl-ext-cms-debug") |
975 |
2123 |
<< "set bound to " << (set_lower ? "lower" : "upper") << std::endl; |
976 |
|
|
977 |
|
// --- Collect variables and factors in v |
978 |
4213 |
std::vector<Node> vars; |
979 |
4213 |
std::vector<unsigned> factors; |
980 |
2123 |
if (v.getKind() == NONLINEAR_MULT) |
981 |
|
{ |
982 |
117 |
unsigned last_start = 0; |
983 |
351 |
for (unsigned i = 0, nchildren = v.getNumChildren(); i < nchildren; i++) |
984 |
|
{ |
985 |
|
// are we at the end? |
986 |
234 |
if (i + 1 == nchildren || v[i + 1] != v[i]) |
987 |
|
{ |
988 |
117 |
unsigned vfact = 1 + (i - last_start); |
989 |
117 |
last_start = (i + 1); |
990 |
117 |
vars.push_back(v[i]); |
991 |
117 |
factors.push_back(vfact); |
992 |
|
} |
993 |
|
} |
994 |
|
} |
995 |
|
else |
996 |
|
{ |
997 |
2006 |
vars.push_back(v); |
998 |
2006 |
factors.push_back(1); |
999 |
|
} |
1000 |
|
|
1001 |
|
// --- Get the lower and upper bounds and sign information. |
1002 |
|
// Whether we have an (odd) number of negative factors in vars, apart |
1003 |
|
// from the variable at choose_index. |
1004 |
2123 |
bool has_neg_factor = false; |
1005 |
2123 |
int choose_index = -1; |
1006 |
4213 |
std::vector<Node> ls; |
1007 |
4213 |
std::vector<Node> us; |
1008 |
|
// the relevant sign information for variables with odd exponents: |
1009 |
|
// 1: both signs of the interval of this variable are positive, |
1010 |
|
// -1: both signs of the interval of this variable are negative. |
1011 |
4213 |
std::vector<int> signs; |
1012 |
2123 |
Trace("nl-ext-cms-debug") << "get sign information..." << std::endl; |
1013 |
4246 |
for (unsigned i = 0, size = vars.size(); i < size; i++) |
1014 |
|
{ |
1015 |
4246 |
Node vc = vars[i]; |
1016 |
2123 |
unsigned vcfact = factors[i]; |
1017 |
2123 |
if (Trace.isOn("nl-ext-cms-debug")) |
1018 |
|
{ |
1019 |
|
Trace("nl-ext-cms-debug") << "-- " << vc; |
1020 |
|
if (vcfact > 1) |
1021 |
|
{ |
1022 |
|
Trace("nl-ext-cms-debug") << "^" << vcfact; |
1023 |
|
} |
1024 |
|
Trace("nl-ext-cms-debug") << " "; |
1025 |
|
} |
1026 |
|
std::map<Node, std::pair<Node, Node>>::iterator bit = |
1027 |
2123 |
d_check_model_bounds.find(vc); |
1028 |
|
// if there is a model bound for this term |
1029 |
2123 |
if (bit != d_check_model_bounds.end()) |
1030 |
|
{ |
1031 |
4246 |
Node l = bit->second.first; |
1032 |
4246 |
Node u = bit->second.second; |
1033 |
2123 |
ls.push_back(l); |
1034 |
2123 |
us.push_back(u); |
1035 |
2123 |
int vsign = 0; |
1036 |
2123 |
if (vcfact % 2 == 1) |
1037 |
|
{ |
1038 |
2006 |
vsign = 1; |
1039 |
2006 |
int lsgn = l.getConst<Rational>().sgn(); |
1040 |
2006 |
int usgn = u.getConst<Rational>().sgn(); |
1041 |
4012 |
Trace("nl-ext-cms-debug") |
1042 |
2006 |
<< "bound_sign(" << lsgn << "," << usgn << ") "; |
1043 |
2006 |
if (lsgn == -1) |
1044 |
|
{ |
1045 |
345 |
if (usgn < 1) |
1046 |
|
{ |
1047 |
|
// must have a negative factor |
1048 |
345 |
has_neg_factor = !has_neg_factor; |
1049 |
345 |
vsign = -1; |
1050 |
|
} |
1051 |
|
else if (choose_index == -1) |
1052 |
|
{ |
1053 |
|
// set the choose index to this |
1054 |
|
choose_index = i; |
1055 |
|
vsign = 0; |
1056 |
|
} |
1057 |
|
else |
1058 |
|
{ |
1059 |
|
// ambiguous, can't determine the bound |
1060 |
|
Trace("nl-ext-cms") |
1061 |
|
<< " failed due to ambiguious monomial." << std::endl; |
1062 |
|
return false; |
1063 |
|
} |
1064 |
|
} |
1065 |
|
} |
1066 |
2123 |
Trace("nl-ext-cms-debug") << " -> " << vsign << std::endl; |
1067 |
2123 |
signs.push_back(vsign); |
1068 |
|
} |
1069 |
|
else |
1070 |
|
{ |
1071 |
|
Assert(d_check_model_witnesses.find(vc) |
1072 |
|
== d_check_model_witnesses.end()) |
1073 |
|
<< "No variable should be assigned a witness term if we get " |
1074 |
|
"here. " |
1075 |
|
<< vc << " is, though." << std::endl; |
1076 |
|
Trace("nl-ext-cms-debug") << std::endl; |
1077 |
|
Trace("nl-ext-cms") |
1078 |
|
<< " failed due to unknown bound for " << vc << std::endl; |
1079 |
|
// should either assign a model bound or eliminate the variable |
1080 |
|
// via substitution |
1081 |
|
Assert(false); |
1082 |
|
return false; |
1083 |
|
} |
1084 |
|
} |
1085 |
|
// whether we will try to minimize/maximize (-1/1) the absolute value |
1086 |
2123 |
int setAbs = (set_lower == has_neg_factor) ? 1 : -1; |
1087 |
4246 |
Trace("nl-ext-cms-debug") |
1088 |
2123 |
<< "set absolute value to " << (setAbs == 1 ? "maximal" : "minimal") |
1089 |
2123 |
<< std::endl; |
1090 |
|
|
1091 |
4213 |
std::vector<Node> vbs; |
1092 |
2123 |
Trace("nl-ext-cms-debug") << "set bounds..." << std::endl; |
1093 |
4213 |
for (unsigned i = 0, size = vars.size(); i < size; i++) |
1094 |
|
{ |
1095 |
4213 |
Node vc = vars[i]; |
1096 |
2123 |
unsigned vcfact = factors[i]; |
1097 |
4213 |
Node l = ls[i]; |
1098 |
4213 |
Node u = us[i]; |
1099 |
|
bool vc_set_lower; |
1100 |
2123 |
int vcsign = signs[i]; |
1101 |
4246 |
Trace("nl-ext-cms-debug") |
1102 |
2123 |
<< "Bounds for " << vc << " : " << l << ", " << u |
1103 |
2123 |
<< ", sign : " << vcsign << ", factor : " << vcfact << std::endl; |
1104 |
2123 |
if (l == u) |
1105 |
|
{ |
1106 |
|
// by convention, always say it is lower if they are the same |
1107 |
|
vc_set_lower = true; |
1108 |
|
Trace("nl-ext-cms-debug") |
1109 |
|
<< "..." << vc << " equal bound, set to lower" << std::endl; |
1110 |
|
} |
1111 |
|
else |
1112 |
|
{ |
1113 |
2123 |
if (vcfact % 2 == 0) |
1114 |
|
{ |
1115 |
|
// minimize or maximize its absolute value |
1116 |
234 |
Rational la = l.getConst<Rational>().abs(); |
1117 |
234 |
Rational ua = u.getConst<Rational>().abs(); |
1118 |
117 |
if (la == ua) |
1119 |
|
{ |
1120 |
|
// by convention, always say it is lower if abs are the same |
1121 |
|
vc_set_lower = true; |
1122 |
|
Trace("nl-ext-cms-debug") |
1123 |
|
<< "..." << vc << " equal abs, set to lower" << std::endl; |
1124 |
|
} |
1125 |
|
else |
1126 |
|
{ |
1127 |
117 |
vc_set_lower = (la > ua) == (setAbs == 1); |
1128 |
|
} |
1129 |
|
} |
1130 |
2006 |
else if (signs[i] == 0) |
1131 |
|
{ |
1132 |
|
// we choose this index to match the overall set_lower |
1133 |
|
vc_set_lower = set_lower; |
1134 |
|
} |
1135 |
|
else |
1136 |
|
{ |
1137 |
2006 |
vc_set_lower = (signs[i] != setAbs); |
1138 |
|
} |
1139 |
4246 |
Trace("nl-ext-cms-debug") |
1140 |
2123 |
<< "..." << vc << " set to " << (vc_set_lower ? "lower" : "upper") |
1141 |
2123 |
<< std::endl; |
1142 |
|
} |
1143 |
|
// check whether this is a conflicting bound |
1144 |
2123 |
std::map<Node, bool>::iterator itsb = set_bound.find(vc); |
1145 |
2123 |
if (itsb == set_bound.end()) |
1146 |
|
{ |
1147 |
2082 |
set_bound[vc] = vc_set_lower; |
1148 |
|
} |
1149 |
41 |
else if (itsb->second != vc_set_lower) |
1150 |
|
{ |
1151 |
66 |
Trace("nl-ext-cms") |
1152 |
33 |
<< " failed due to conflicting bound for " << vc << std::endl; |
1153 |
33 |
return false; |
1154 |
|
} |
1155 |
|
// must over/under approximate based on vc_set_lower, computed above |
1156 |
4180 |
Node vb = vc_set_lower ? l : u; |
1157 |
4264 |
for (unsigned i2 = 0; i2 < vcfact; i2++) |
1158 |
|
{ |
1159 |
2174 |
vbs.push_back(vb); |
1160 |
|
} |
1161 |
|
} |
1162 |
2090 |
if (!simpleSuccess) |
1163 |
|
{ |
1164 |
|
break; |
1165 |
|
} |
1166 |
4180 |
Node vbound = vbs.size() == 1 ? vbs[0] : nm->mkNode(MULT, vbs); |
1167 |
2090 |
sum_bound.push_back(ArithMSum::mkCoeffTerm(m.second, vbound)); |
1168 |
|
} |
1169 |
|
} |
1170 |
|
// if the exact bound was computed via simple analysis above |
1171 |
|
// make the bound |
1172 |
3590 |
Node bound; |
1173 |
1795 |
if (sum_bound.size() > 1) |
1174 |
|
{ |
1175 |
1662 |
bound = nm->mkNode(kind::PLUS, sum_bound); |
1176 |
|
} |
1177 |
133 |
else if (sum_bound.size() == 1) |
1178 |
|
{ |
1179 |
133 |
bound = sum_bound[0]; |
1180 |
|
} |
1181 |
|
else |
1182 |
|
{ |
1183 |
|
bound = d_zero; |
1184 |
|
} |
1185 |
|
// make the comparison |
1186 |
3590 |
Node comp = nm->mkNode(kind::GEQ, bound, d_zero); |
1187 |
1795 |
if (!pol) |
1188 |
|
{ |
1189 |
1197 |
comp = comp.negate(); |
1190 |
|
} |
1191 |
1795 |
Trace("nl-ext-cms") << " comparison is : " << comp << std::endl; |
1192 |
1795 |
comp = Rewriter::rewrite(comp); |
1193 |
1795 |
Assert(comp.isConst()); |
1194 |
1795 |
Trace("nl-ext-cms") << " returned : " << comp << std::endl; |
1195 |
1795 |
return comp == d_true; |
1196 |
|
} |
1197 |
|
|
1198 |
10 |
/**
 * Compute rational bounds [l, u] on the square root of the constant c by
 * bisection, running at most iter halving steps.
 *
 * The constants 0 and 1 are returned as exact bounds (l == u == c). The
 * search stops early if a midpoint squares exactly to c. The initial bracket
 * is [c, 1] when c < 1 and [1, c] otherwise, which presumes c >= 0.
 *
 * @return always true
 */
bool NlModel::getApproximateSqrt(Node c, Node& l, Node& u, unsigned iter) const
{
  Assert(c.isConst());
  if (c == d_zero || c == d_one)
  {
    // 0 and 1 are their own square roots
    l = c;
    u = c;
    return true;
  }
  const Rational target = c.getConst<Rational>();
  const Rational one(1);
  const Rational half(1, 2);

  Rational lo = one;
  Rational hi = target;
  if (target < one)
  {
    lo = target;
    hi = one;
  }

  for (unsigned step = 0; step < iter; ++step)
  {
    const Rational mid = half * (lo + hi);
    const Rational midSq = mid * mid;
    if (midSq == target)
    {
      // exact root found: collapse the interval
      lo = mid;
      hi = mid;
      break;
    }
    if (midSq < target)
    {
      lo = mid;
    }
    else
    {
      hi = mid;
    }
  }

  NodeManager* nm = NodeManager::currentNM();
  l = nm->mkConst(lo);
  u = nm->mkConst(hi);
  return true;
}
1239 |
|
|
1240 |
44582 |
void NlModel::printModelValue(const char* c, Node n, unsigned prec) const |
1241 |
|
{ |
1242 |
44582 |
if (Trace.isOn(c)) |
1243 |
|
{ |
1244 |
|
Trace(c) << " " << n << " -> "; |
1245 |
|
for (int i = 1; i >= 0; --i) |
1246 |
|
{ |
1247 |
|
std::map<Node, Node>::const_iterator it = d_mv[i].find(n); |
1248 |
|
Assert(it != d_mv[i].end()); |
1249 |
|
if (it->second.isConst()) |
1250 |
|
{ |
1251 |
|
printRationalApprox(c, it->second, prec); |
1252 |
|
} |
1253 |
|
else |
1254 |
|
{ |
1255 |
|
Trace(c) << "?"; |
1256 |
|
} |
1257 |
|
Trace(c) << (i == 1 ? " [actual: " : " ]"); |
1258 |
|
} |
1259 |
|
Trace(c) << std::endl; |
1260 |
|
} |
1261 |
44582 |
} |
1262 |
|
|
1263 |
283 |
void NlModel::getModelValueRepair( |
1264 |
|
std::map<Node, Node>& arithModel, |
1265 |
|
std::map<Node, std::pair<Node, Node>>& approximations, |
1266 |
|
std::map<Node, Node>& witnesses) |
1267 |
|
{ |
1268 |
283 |
Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl; |
1269 |
|
// If we extended the model with entries x -> 0 for unconstrained values, |
1270 |
|
// we first update the map to the extended one. |
1271 |
283 |
if (d_arithVal.size() > arithModel.size()) |
1272 |
|
{ |
1273 |
4 |
arithModel = d_arithVal; |
1274 |
|
} |
1275 |
|
// Record the approximations we used. This code calls the |
1276 |
|
// recordApproximation method of the model, which overrides the model |
1277 |
|
// values for variables that we solved for, using techniques specific to |
1278 |
|
// this class. |
1279 |
283 |
NodeManager* nm = NodeManager::currentNM(); |
1280 |
27 |
for (const std::pair<const Node, std::pair<Node, Node>>& cb : |
1281 |
283 |
d_check_model_bounds) |
1282 |
|
{ |
1283 |
54 |
Node l = cb.second.first; |
1284 |
54 |
Node u = cb.second.second; |
1285 |
54 |
Node pred; |
1286 |
54 |
Node v = cb.first; |
1287 |
27 |
if (l != u) |
1288 |
|
{ |
1289 |
27 |
pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v)); |
1290 |
27 |
Trace("nl-model") << v << " approximated as " << pred << std::endl; |
1291 |
54 |
Node witness; |
1292 |
27 |
if (options::modelWitnessValue()) |
1293 |
|
{ |
1294 |
|
// witness is the midpoint |
1295 |
|
witness = nm->mkNode( |
1296 |
|
MULT, nm->mkConst(Rational(1, 2)), nm->mkNode(PLUS, l, u)); |
1297 |
|
witness = Rewriter::rewrite(witness); |
1298 |
|
Trace("nl-model") << v << " witness is " << witness << std::endl; |
1299 |
|
} |
1300 |
27 |
approximations[v] = std::pair<Node, Node>(pred, witness); |
1301 |
|
} |
1302 |
|
else |
1303 |
|
{ |
1304 |
|
// overwrite |
1305 |
|
arithModel[v] = l; |
1306 |
|
Trace("nl-model") << v << " exact approximation is " << l << std::endl; |
1307 |
|
} |
1308 |
|
} |
1309 |
290 |
for (const auto& vw : d_check_model_witnesses) |
1310 |
|
{ |
1311 |
7 |
Trace("nl-model") << vw.first << " witness is " << vw.second << std::endl; |
1312 |
7 |
witnesses.emplace(vw.first, vw.second); |
1313 |
|
} |
1314 |
|
// Also record the exact values we used. An exact value can be seen as a |
1315 |
|
// special kind approximation of the form (witness x. x = exact_value). |
1316 |
|
// Notice that the above term gets rewritten such that the choice function |
1317 |
|
// is eliminated. |
1318 |
551 |
for (size_t i = 0, num = d_check_model_vars.size(); i < num; i++) |
1319 |
|
{ |
1320 |
536 |
Node v = d_check_model_vars[i]; |
1321 |
536 |
Node s = d_check_model_subs[i]; |
1322 |
|
// overwrite |
1323 |
268 |
arithModel[v] = s; |
1324 |
268 |
Trace("nl-model") << v << " solved is " << s << std::endl; |
1325 |
|
} |
1326 |
|
|
1327 |
|
// multiplication terms should not be given values; their values are |
1328 |
|
// implied by the monomials that they consist of |
1329 |
566 |
std::vector<Node> amErase; |
1330 |
6302 |
for (const std::pair<const Node, Node>& am : arithModel) |
1331 |
|
{ |
1332 |
6019 |
if (am.first.getKind() == NONLINEAR_MULT) |
1333 |
|
{ |
1334 |
1451 |
amErase.push_back(am.first); |
1335 |
|
} |
1336 |
|
} |
1337 |
1734 |
for (const Node& ae : amErase) |
1338 |
|
{ |
1339 |
1451 |
arithModel.erase(ae); |
1340 |
|
} |
1341 |
283 |
} |
1342 |
|
|
1343 |
|
} // namespace nl |
1344 |
|
} // namespace arith |
1345 |
|
} // namespace theory |
1346 |
29349 |
} // namespace cvc5 |