GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/nl/nl_model.cpp Lines: 578 722 80.1 %
Date: 2021-11-05 Branches: 1382 3465 39.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Gereon Kremer, Tim King
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Model object for the non-linear extension class.
14
 */
15
16
#include "theory/arith/nl/nl_model.h"
17
18
#include "expr/node_algorithm.h"
19
#include "options/arith_options.h"
20
#include "options/smt_options.h"
21
#include "options/theory_options.h"
22
#include "theory/arith/arith_msum.h"
23
#include "theory/arith/arith_utilities.h"
24
#include "theory/arith/nl/nl_lemma_utils.h"
25
#include "theory/theory_model.h"
26
#include "theory/rewriter.h"
27
28
using namespace cvc5::kind;
29
30
namespace cvc5 {
31
namespace theory {
32
namespace arith {
33
namespace nl {
34
35
9700
/**
 * Create an empty non-linear model.
 *
 * Caches the constant nodes true, false, 0, 1 and 2, which the check-model
 * routines below compare against and build terms from.
 */
NlModel::NlModel() : d_used_approx(false)
{
  NodeManager* nm = NodeManager::currentNM();
  d_true = nm->mkConst(true);
  d_false = nm->mkConst(false);
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
  d_two = nm->mkConst(Rational(2));
}
43
44
11735
// Nothing to release explicitly; all members clean up after themselves.
NlModel::~NlModel() = default;
45
46
4708
/**
 * Point this object at a new model.
 *
 * @param m the theory model to read values from
 * @param arithModel the arithmetic model values; copied into d_arithVal
 */
void NlModel::reset(TheoryModel* m, const std::map<Node, Node>& arithModel)
{
  // Values cached for the previous model are no longer meaningful.
  d_concreteModelCache.clear();
  d_abstractModelCache.clear();
  d_model = m;
  d_arithVal = arithModel;
}
53
54
4724
void NlModel::resetCheck()
55
{
56
4724
  d_used_approx = false;
57
4724
  d_check_model_solved.clear();
58
4724
  d_check_model_bounds.clear();
59
4724
  d_check_model_witnesses.clear();
60
4724
  d_substitutions.clear();
61
4724
}
62
63
730007
/**
 * Compute the concrete model value of n.
 *
 * In concrete mode computeModelValue never consults hasLinearModelValue;
 * the value is rebuilt bottom-up from the leaves of n.
 */
Node NlModel::computeConcreteModelValue(TNode n)
{
  const bool concreteMode = true;
  return computeModelValue(n, concreteMode);
}
67
68
356274
/**
 * Compute the abstract model value of n.
 *
 * In abstract mode computeModelValue first asks hasLinearModelValue for a
 * value of each non-constant (sub)term before evaluating it structurally.
 */
Node NlModel::computeAbstractModelValue(TNode n)
{
  const bool concreteMode = false;
  return computeModelValue(n, concreteMode);
}
72
73
3326036
/**
 * Compute (and memoize) the model value of n.
 *
 * - Constants evaluate to themselves.
 * - In abstract mode (isConcrete == false), a value supplied by
 *   hasLinearModelValue is used when available.
 * - Leaves are looked up with getValueInternal, except PI. PI's exact value
 *   cannot be computed, so PI is returned as-is.
 * - Terms whose kind does not belong to arithmetic, Boolean or builtin are
 *   looked up directly with getValueInternal.
 * - Any other term is rebuilt from the model values of its children (and
 *   its operator, if parameterized) and then rewritten.
 *
 * Results are cached separately for the concrete and abstract modes.
 *
 * @param n the term to evaluate
 * @param isConcrete whether to compute the concrete (true) or abstract
 *        (false) value
 * @return the model value of n
 */
Node NlModel::computeModelValue(TNode n, bool isConcrete)
{
  auto& cache = isConcrete ? d_concreteModelCache : d_abstractModelCache;
  const auto hit = cache.find(n);
  if (hit != cache.end())
  {
    return hit->second;
  }
  Trace("nl-ext-mv-debug") << "computeModelValue " << n
                           << ", isConcrete=" << isConcrete << std::endl;
  Node value;
  if (n.isConst())
  {
    value = n;
  }
  else if (isConcrete || !hasLinearModelValue(n, value))
  {
    // No shortcut applies: evaluate the term ourselves.
    if (n.getNumChildren() == 0)
    {
      if (n.getKind() == PI)
      {
        // PI stands for its own (non-computable) exact value.
        value = n;
      }
      else
      {
        value = getValueInternal(n);
      }
    }
    else
    {
      TheoryId owner = theory::kindToTheoryId(n.getKind());
      bool evaluateHere = owner == THEORY_ARITH || owner == THEORY_BOOL
                          || owner == THEORY_BUILTIN;
      if (!evaluateHere)
      {
        // foreign terms are looked up directly in the model
        value = getValueInternal(n);
      }
      else
      {
        std::vector<Node> args;
        if (n.getMetaKind() == metakind::PARAMETERIZED)
        {
          args.emplace_back(n.getOperator());
        }
        for (TNode child : n)
        {
          args.emplace_back(computeModelValue(child, isConcrete));
        }
        value = Rewriter::rewrite(
            NodeManager::currentNM()->mkNode(n.getKind(), args));
      }
    }
  }
  Trace("nl-ext-mv-debug") << "computed " << (isConcrete ? "M" : "M_A") << "["
                           << n << "] = " << value << std::endl;
  cache[n] = value;
  return value;
}
133
134
113901
/**
 * Compare the model values of i and j.
 *
 * When both values are constants, the result of compareValue is returned.
 * Otherwise a constant value orders above a non-constant one, and two
 * non-constant values compare equal.
 *
 * @return -1, 0 or 1
 */
int NlModel::compare(TNode i, TNode j, bool isConcrete, bool isAbsolute)
{
  if (i == j)
  {
    return 0;
  }
  Node vi = computeModelValue(i, isConcrete);
  Node vj = computeModelValue(j, isConcrete);
  const bool iConst = vi.isConst();
  const bool jConst = vj.isConst();
  if (iConst && jConst)
  {
    return compareValue(vi, vj, isAbsolute);
  }
  if (iConst)
  {
    return 1;
  }
  if (jConst)
  {
    return -1;
  }
  return 0;
}
152
153
129372
/**
 * Compare two constant rationals i and j.
 *
 * @param isAbsolute if true, compare |i| with |j| instead of i with j
 * @return -1 if i < j, 0 if equal, 1 if i > j (w.r.t. the chosen measure)
 */
int NlModel::compareValue(TNode i, TNode j, bool isAbsolute) const
{
  Assert(i.isConst() && j.isConst());
  if (i == j)
  {
    return 0;
  }
  const Rational& ri = i.getConst<Rational>();
  const Rational& rj = j.getConst<Rational>();
  if (isAbsolute)
  {
    Rational magI = ri.abs();
    Rational magJ = rj.abs();
    if (magI == magJ)
    {
      return 0;
    }
    return magI < magJ ? -1 : 1;
  }
  return ri < rj ? -1 : 1;
}
172
173
270
bool NlModel::checkModel(const std::vector<Node>& assertions,
174
                         unsigned d,
175
                         std::vector<NlLemma>& lemmas)
176
{
177
270
  Trace("nl-ext-cm-debug") << "  solve for equalities..." << std::endl;
178
4451
  for (const Node& atom : assertions)
179
  {
180
    // see if it corresponds to a univariate polynomial equation of degree two
181
4181
    if (atom.getKind() == EQUAL)
182
    {
183
571
      if (!solveEqualitySimple(atom, d, lemmas))
184
      {
185
        // no chance we will satisfy this equality
186
354
        Trace("nl-ext-cm") << "...check-model : failed to solve equality : "
187
177
                           << atom << std::endl;
188
      }
189
    }
190
  }
191
192
  // all remaining variables are constrained to their exact model values
193
540
  Trace("nl-ext-cm-debug") << "  set exact bounds for remaining variables..."
194
270
                           << std::endl;
195
540
  std::unordered_set<TNode> visited;
196
540
  std::vector<TNode> visit;
197
540
  TNode cur;
198
4451
  for (const Node& a : assertions)
199
  {
200
4181
    visit.push_back(a);
201
15806
    do
202
    {
203
19987
      cur = visit.back();
204
19987
      visit.pop_back();
205
19987
      if (visited.find(cur) == visited.end())
206
      {
207
11555
        visited.insert(cur);
208
11555
        if (cur.getType().isReal() && !cur.isConst())
209
        {
210
3234
          Kind k = cur.getKind();
211
5482
          if (k != MULT && k != PLUS && k != NONLINEAR_MULT
212
4125
              && !isTranscendentalKind(k))
213
          {
214
            // if we have not set an approximate bound for it
215
598
            if (!hasAssignment(cur))
216
            {
217
              // set its exact model value in the substitution
218
604
              Node curv = computeConcreteModelValue(cur);
219
302
              if (Trace.isOn("nl-ext-cm"))
220
              {
221
                Trace("nl-ext-cm")
222
                    << "check-model-bound : exact : " << cur << " = ";
223
                printRationalApprox("nl-ext-cm", curv);
224
                Trace("nl-ext-cm") << std::endl;
225
              }
226
302
              bool ret = addSubstitution(cur, curv);
227
302
              AlwaysAssert(ret);
228
            }
229
          }
230
        }
231
27361
        for (const Node& cn : cur)
232
        {
233
15806
          visit.push_back(cn);
234
        }
235
      }
236
19987
    } while (!visit.empty());
237
  }
238
239
270
  Trace("nl-ext-cm-debug") << "  check assertions..." << std::endl;
240
540
  std::vector<Node> check_assertions;
241
4451
  for (const Node& a : assertions)
242
  {
243
4181
    if (d_check_model_solved.find(a) == d_check_model_solved.end())
244
    {
245
7574
      Node av = a;
246
      // apply the substitution to a
247
3787
      if (!d_substitutions.empty())
248
      {
249
3220
        av = Rewriter::rewrite(arithSubstitute(av, d_substitutions));
250
      }
251
      // simple check literal
252
3787
      if (!simpleCheckModelLit(av))
253
      {
254
810
        Trace("nl-ext-cm") << "...check-model : assertion failed : " << a
255
405
                           << std::endl;
256
405
        check_assertions.push_back(av);
257
810
        Trace("nl-ext-cm-debug")
258
405
            << "...check-model : failed assertion, value : " << av << std::endl;
259
      }
260
    }
261
  }
262
263
270
  if (!check_assertions.empty())
264
  {
265
202
    Trace("nl-ext-cm") << "...simple check failed." << std::endl;
266
    // TODO (#1450) check model for general case
267
202
    return false;
268
  }
269
68
  Trace("nl-ext-cm") << "...simple check succeeded!" << std::endl;
270
68
  return true;
271
}
272
273
635
/**
 * Record the check-model substitution v -> s.
 *
 * The new binding is first applied to the right-hand sides of all existing
 * substitutions (rewriting those that change), and is then added itself.
 *
 * Fails (returns false) if:
 *  - v already has a substitution (this is also an assertion failure), or
 *  - v has an approximate bound [lower, upper] from addBound and s is not a
 *    constant within that bound.
 *
 * @param v the variable to substitute
 * @param s the term to substitute for v
 * @return true iff the substitution was recorded
 */
bool NlModel::addSubstitution(TNode v, TNode s)
{
  Trace("nl-ext-model") << "* check model substitution : " << v << " -> " << s
                        << std::endl;
  // should not set exact bound more than once
  if (d_substitutions.contains(v))
  {
    Trace("nl-ext-model") << "...ERROR: already has value." << std::endl;
    // this should never happen since substitutions should be applied eagerly
    Assert(false);
    return false;
  }
  // if we previously had an approximate bound, the exact bound should be in
  // its range
  auto itb = d_check_model_bounds.find(v);
  if (itb != d_check_model_bounds.end())
  {
    if (!s.isConst())
    {
      // we cannot confirm a non-constant value lies within the bound
      Trace("nl-ext-model")
          << "...ERROR: non-constant value for variable with a bound."
          << std::endl;
      return false;
    }
    const Rational& sval = s.getConst<Rational>();
    // Out of range means strictly below the lower or strictly above the
    // upper bound (the previous check had both comparisons inverted).
    if (sval < itb->second.first.getConst<Rational>()
        || sval > itb->second.second.getConst<Rational>())
    {
      Trace("nl-ext-model")
          << "...ERROR: already has bound which is out of range." << std::endl;
      return false;
    }
  }
  Assert(d_check_model_witnesses.find(v) == d_check_model_witnesses.end())
      << "We tried to add a substitution where we already had a witness term."
      << std::endl;
  // apply v -> s to the existing substitutions so they stay fully reduced
  Subs tmp;
  tmp.add(v, s);
  for (auto& sub : d_substitutions.d_subs)
  {
    Node ms = arithSubstitute(sub, tmp);
    if (ms != sub)
    {
      sub = Rewriter::rewrite(ms);
    }
  }
  d_substitutions.add(v, s);
  return true;
}
316
317
396
/**
 * Record the approximate bound l <= v <= u for check-model.
 *
 * A degenerate interval (l == u) is delegated to addSubstitution. Setting a
 * bound for a variable that already has a substitution is an error.
 *
 * @param v the variable being bounded
 * @param l the constant lower bound
 * @param u the constant upper bound (l <= u)
 * @return true iff the bound (or substitution) was recorded
 */
bool NlModel::addBound(TNode v, TNode l, TNode u)
{
  Trace("nl-ext-model") << "* check model bound : " << v << " -> [" << l << " "
                        << u << "]" << std::endl;
  if (l == u)
  {
    // a point interval is an exact value
    return addSubstitution(v, l);
  }
  if (d_substitutions.contains(v))
  {
    // should not set a bound for a value that is exact
    Trace("nl-ext-model")
        << "...ERROR: setting bound for variable that already has exact value."
        << std::endl;
    Assert(false);
    return false;
  }
  Assert(l.isConst());
  Assert(u.isConst());
  Assert(l.getConst<Rational>() <= u.getConst<Rational>());
  std::pair<Node, Node>& range = d_check_model_bounds[v];
  range.first = l;
  range.second = u;
  if (Trace.isOn("nl-ext-cm"))
  {
    Trace("nl-ext-cm") << "check-model-bound : approximate : ";
    printRationalApprox("nl-ext-cm", l);
    Trace("nl-ext-cm") << " <= " << v << " <= ";
    printRationalApprox("nl-ext-cm", u);
    Trace("nl-ext-cm") << std::endl;
  }
  return true;
}
349
350
7
/**
 * Record w as a witness term for v in check-model.
 *
 * Fails (and asserts) if v already has a substitution. If v already has a
 * witness, the existing entry is kept (emplace does not overwrite).
 *
 * @return true iff v had no substitution
 */
bool NlModel::addWitness(TNode v, TNode w)
{
  Trace("nl-ext-model") << "* check model witness : " << v << " -> " << w
                        << std::endl;
  const bool alreadyExact = d_substitutions.contains(v);
  if (!alreadyExact)
  {
    d_check_model_witnesses.emplace(v, w);
    return true;
  }
  Trace("nl-ext-model") << "...ERROR: setting witness for variable that "
                           "already has a constant value."
                        << std::endl;
  Assert(false);
  return false;
}
366
367
269
/** Mark that the current check-model relied on an approximation. */
void NlModel::setUsedApproximate()
{
  d_used_approx = true;
}
368
369
18
/** Whether the current check-model relied on an approximation. */
bool NlModel::usedApproximate() const
{
  const bool approximated = d_used_approx;
  return approximated;
}
370
371
663
/**
 * Try to satisfy the equality eq in check-model by solving it.
 *
 * The current substitutions are applied to eq first. If the result is a
 * constant, eq is already decided. Otherwise eq is decomposed into a
 * monomial sum and handled as follows:
 *  - If it is not univariate (a foreign factor, a second variable, or a
 *    non-square nonlinear monomial), try to isolate an unassigned variable
 *    and record it as a substitution. Failing that, fix an unassigned
 *    variable to its concrete model value and recurse.
 *  - If it is a*x^2 + b*x + c = 0 with a == 0, record the exact solution
 *    x = -c/b as a substitution.
 *  - If a != 0, compute both approximate roots using getApproximateSqrt and
 *    record as a bound the root interval whose midpoint is closest to x's
 *    concrete model value.
 * Integer-typed variables are never solved for in the univariate case.
 *
 * On success, eq is recorded in d_check_model_solved.
 *
 * @param eq the equality to solve
 * @param d added to 15 and passed as the last argument of
 *        getApproximateSqrt (NOTE(review): presumably a precision/iteration
 *        limit; its meaning is not visible here)
 * @param lemmas receives a conflict lemma (the negation of the substituted
 *        equality) when the quadratic's discriminant is negative
 * @return true if eq already holds or a substitution/bound solving it was
 *         recorded; false otherwise
 */
bool NlModel::solveEqualitySimple(Node eq,
                                  unsigned d,
                                  std::vector<NlLemma>& lemmas)
{
  Node seq = eq;
  if (!d_substitutions.empty())
  {
    seq = arithSubstitute(eq, d_substitutions);
    seq = Rewriter::rewrite(seq);
    if (seq.isConst())
    {
      if (seq.getConst<bool>())
      {
        // already true
        d_check_model_solved[eq] = Node::null();
        return true;
      }
      return false;
    }
  }
  Trace("nl-ext-cms") << "simple solve equality " << seq << "..." << std::endl;
  Assert(seq.getKind() == EQUAL);
  std::map<Node, Node> msum;
  if (!ArithMSum::getMonomialSumLit(seq, msum))
  {
    Trace("nl-ext-cms") << "...fail, could not determine monomial sum."
                        << std::endl;
    return false;
  }
  bool is_valid = true;
  // the variable we will solve a quadratic equation for
  Node var;
  // coefficients of var^2, var and the constant term, respectively
  Node a = d_zero;
  Node b = d_zero;
  Node c = d_zero;
  NodeManager* nm = NodeManager::currentNM();
  // the list of variables that occur as a monomial in msum, and whose value
  // is so far unconstrained in the model.
  std::unordered_set<Node> unc_vars;
  // the list of variables that occur as a factor in a monomial, and whose
  // value is so far unconstrained in the model.
  std::unordered_set<Node> unc_vars_factor;
  for (std::pair<const Node, Node>& m : msum)
  {
    Node v = m.first;
    // a null coefficient stands for 1
    Node coeff = m.second.isNull() ? d_one : m.second;
    if (v.isNull())
    {
      // the null monomial is the constant term
      c = coeff;
    }
    else if (v.getKind() == NONLINEAR_MULT)
    {
      if (v.getNumChildren() == 2 && v[0].isVar() && v[0] == v[1]
          && (var.isNull() || var == v[0]))
      {
        // may solve quadratic
        a = coeff;
        var = v[0];
      }
      else
      {
        is_valid = false;
        Trace("nl-ext-cms-debug")
            << "...invalid due to non-linear monomial " << v << std::endl;
        // may wish to set an exact bound for a factor and repeat
        for (const Node& vc : v)
        {
          unc_vars_factor.insert(vc);
        }
      }
    }
    else if (!v.isVar() || (!var.isNull() && var != v))
    {
      Trace("nl-ext-cms-debug")
          << "...invalid due to factor " << v << std::endl;
      // cannot solve multivariate
      if (is_valid)
      {
        is_valid = false;
        // if b is non-zero, then var is also an unconstrained variable
        if (b != d_zero)
        {
          unc_vars.insert(var);
          unc_vars_factor.insert(var);
        }
      }
      // if v is unconstrained, we may turn this equality into a substitution
      unc_vars.insert(v);
      unc_vars_factor.insert(v);
    }
    else
    {
      // set the variable to solve for
      b = coeff;
      var = v;
    }
  }
  if (!is_valid)
  {
    // see if we can solve for a variable?
    for (const Node& uv : unc_vars)
    {
      Trace("nl-ext-cm-debug") << "check subs var : " << uv << std::endl;
      // cannot already have a bound
      if (uv.isVar() && !hasAssignment(uv))
      {
        Node slv;
        Node veqc;
        if (ArithMSum::isolate(uv, msum, veqc, slv, EQUAL) != 0)
        {
          Assert(!slv.isNull());
          // Currently do not support substitution-with-coefficients.
          // We also ensure types are correct here, which avoids substituting
          // a term of non-integer type for a variable of integer type.
          if (veqc.isNull() && !expr::hasSubterm(slv, uv)
              && slv.getType().isSubtypeOf(uv.getType()))
          {
            Trace("nl-ext-cm")
                << "check-model-subs : " << uv << " -> " << slv << std::endl;
            bool ret = addSubstitution(uv, slv);
            if (ret)
            {
              Trace("nl-ext-cms") << "...success, model substitution " << uv
                                  << " -> " << slv << std::endl;
              d_check_model_solved[eq] = uv;
            }
            // only the first usable variable is tried
            return ret;
          }
        }
      }
    }
    // see if we can assign a variable to a constant
    for (const Node& uvf : unc_vars_factor)
    {
      Trace("nl-ext-cm-debug") << "check set var : " << uvf << std::endl;
      // cannot already have a bound
      if (uvf.isVar() && !hasAssignment(uvf))
      {
        Node uvfv = computeConcreteModelValue(uvf);
        if (Trace.isOn("nl-ext-cm"))
        {
          Trace("nl-ext-cm") << "check-model-bound : exact : " << uvf << " = ";
          printRationalApprox("nl-ext-cm", uvfv);
          Trace("nl-ext-cm") << std::endl;
        }
        bool ret = addSubstitution(uvf, uvfv);
        // recurse: with uvf fixed, eq has one variable fewer
        return ret ? solveEqualitySimple(eq, d, lemmas) : false;
      }
    }
    Trace("nl-ext-cms") << "...fail due to constrained invalid terms."
                        << std::endl;
    return false;
  }
  else if (var.isNull() || var.getType().isInteger())
  {
    // cannot solve quadratic equations for integer variables
    Trace("nl-ext-cms") << "...fail due to variable to solve for." << std::endl;
    return false;
  }

  // we are linear, it is simple
  if (a == d_zero)
  {
    if (b == d_zero)
    {
      Trace("nl-ext-cms") << "...fail due to zero a/b." << std::endl;
      Assert(false);
      return false;
    }
    // b * var + c = 0  =>  var = -c / b
    Node val = nm->mkConst(-c.getConst<Rational>() / b.getConst<Rational>());
    if (Trace.isOn("nl-ext-cm"))
    {
      Trace("nl-ext-cm") << "check-model-bound : exact : " << var << " = ";
      printRationalApprox("nl-ext-cm", val);
      Trace("nl-ext-cm") << std::endl;
    }
    bool ret = addSubstitution(var, val);
    if (ret)
    {
      Trace("nl-ext-cms") << "...success, solved linear." << std::endl;
      d_check_model_solved[eq] = var;
    }
    return ret;
  }
  Trace("nl-ext-quad") << "Solve quadratic : " << seq << std::endl;
  Trace("nl-ext-quad") << "  a : " << a << std::endl;
  Trace("nl-ext-quad") << "  b : " << b << std::endl;
  Trace("nl-ext-quad") << "  c : " << c << std::endl;
  Node two_a = nm->mkNode(MULT, d_two, a);
  two_a = Rewriter::rewrite(two_a);
  // discriminant: b*b - 2*(2a)*c = b^2 - 4ac
  Node sqrt_val = nm->mkNode(
      MINUS, nm->mkNode(MULT, b, b), nm->mkNode(MULT, d_two, two_a, c));
  sqrt_val = Rewriter::rewrite(sqrt_val);
  Trace("nl-ext-quad") << "Will approximate sqrt " << sqrt_val << std::endl;
  Assert(sqrt_val.isConst());
  // if it is negative, then we are in conflict
  if (sqrt_val.getConst<Rational>().sgn() == -1)
  {
    Node conf = seq.negate();
    Trace("nl-ext-lemma") << "NlModel::Lemma : quadratic no root : " << conf
                          << std::endl;
    lemmas.emplace_back(InferenceId::ARITH_NL_CM_QUADRATIC_EQ, conf);
    Trace("nl-ext-cms") << "...fail due to negative discriminant." << std::endl;
    return false;
  }
  if (hasAssignment(var))
  {
    Trace("nl-ext-cms") << "...fail due to bounds on variable to solve for."
                        << std::endl;
    // two quadratic equations for same variable, give up
    return false;
  }
  // approximate the square root of sqrt_val
  Node l, u;
  if (!getApproximateSqrt(sqrt_val, l, u, 15 + d))
  {
    Trace("nl-ext-cms") << "...fail, could not approximate sqrt." << std::endl;
    return false;
  }
  d_used_approx = true;
  Trace("nl-ext-quad") << "...got " << l << " <= sqrt(" << sqrt_val
                       << ") <= " << u << std::endl;
  Node negb = nm->mkConst(-b.getConst<Rational>());
  Node coeffa = nm->mkConst(Rational(1) / two_a.getConst<Rational>());
  // two possible bound regions
  // bounds[r] is the interval for root r; diff_bound[r] is the distance
  // from the model value of var to that interval's midpoint
  Node bounds[2][2];
  Node diff_bound[2];
  Node m_var = computeConcreteModelValue(var);
  Assert(m_var.isConst());
  for (unsigned r = 0; r < 2; r++)
  {
    for (unsigned b2 = 0; b2 < 2; b2++)
    {
      Node val = b2 == 0 ? l : u;
      // (-b +- approx_sqrt( b^2 - 4ac ))/2a
      Node approx = nm->mkNode(
          MULT, coeffa, nm->mkNode(r == 0 ? MINUS : PLUS, negb, val));
      approx = Rewriter::rewrite(approx);
      bounds[r][b2] = approx;
      Assert(approx.isConst());
    }
    if (bounds[r][0].getConst<Rational>() > bounds[r][1].getConst<Rational>())
    {
      // ensure bound is (lower, upper)
      Node tmp = bounds[r][0];
      bounds[r][0] = bounds[r][1];
      bounds[r][1] = tmp;
    }
    Node diff =
        nm->mkNode(MINUS,
                   m_var,
                   nm->mkNode(MULT,
                              nm->mkConst(Rational(1) / Rational(2)),
                              nm->mkNode(PLUS, bounds[r][0], bounds[r][1])));
    if (Trace.isOn("nl-ext-cm-debug"))
    {
      Trace("nl-ext-cm-debug") << "Bound option #" << r << " : ";
      printRationalApprox("nl-ext-cm-debug", bounds[r][0]);
      Trace("nl-ext-cm-debug") << "...";
      printRationalApprox("nl-ext-cm-debug", bounds[r][1]);
      Trace("nl-ext-cm-debug") << std::endl;
    }
    diff = Rewriter::rewrite(diff);
    Assert(diff.isConst());
    diff = nm->mkConst(diff.getConst<Rational>().abs());
    diff_bound[r] = diff;
    if (Trace.isOn("nl-ext-cm-debug"))
    {
      Trace("nl-ext-cm-debug") << "...diff from model value (";
      printRationalApprox("nl-ext-cm-debug", m_var);
      Trace("nl-ext-cm-debug") << ") is ";
      printRationalApprox("nl-ext-cm-debug", diff_bound[r]);
      Trace("nl-ext-cm-debug") << std::endl;
    }
  }
  // take the one that var is closer to in the model
  // (on a tie, root interval #1 is chosen)
  Node cmp = nm->mkNode(GEQ, diff_bound[0], diff_bound[1]);
  cmp = Rewriter::rewrite(cmp);
  Assert(cmp.isConst());
  unsigned r_use_index = cmp == d_true ? 1 : 0;
  if (Trace.isOn("nl-ext-cm"))
  {
    Trace("nl-ext-cm") << "check-model-bound : approximate (sqrt) : ";
    printRationalApprox("nl-ext-cm", bounds[r_use_index][0]);
    Trace("nl-ext-cm") << " <= " << var << " <= ";
    printRationalApprox("nl-ext-cm", bounds[r_use_index][1]);
    Trace("nl-ext-cm") << std::endl;
  }
  bool ret = addBound(var, bounds[r_use_index][0], bounds[r_use_index][1]);
  if (ret)
  {
    d_check_model_solved[eq] = var;
    Trace("nl-ext-cms") << "...success, solved quadratic." << std::endl;
  }
  return ret;
}
668
669
4483
/**
 * Check whether the literal lit holds in the current check-model.
 *
 * This method does not apply d_substitutions itself; callers pass an
 * already-substituted literal.
 *  - Constant literals return their value.
 *  - (Dis)equalities are split into the two GEQ literals atom[i] >= atom[1-i]
 *    (negated when pol is false) and checked recursively.
 *  - GEQ literals are first checked by interval analysis
 *    (simpleCheckModelMsum). If that fails, each univariate quadratic
 *    a*v^2 + b*v in the monomial sum whose variable v has an approximate
 *    bound (and does not occur in the remaining "invalid" monomials) is
 *    replaced by the worst-case value of v on that bound, and the resulting
 *    literal is checked recursively.
 *  - Any other literal fails.
 *
 * @param lit the (substituted) literal to check
 * @return true if lit was shown to hold, false if it could not be shown
 */
bool NlModel::simpleCheckModelLit(Node lit)
{
  Trace("nl-ext-cms") << "*** Simple check-model lit for " << lit << "..."
                      << std::endl;
  if (lit.isConst())
  {
    Trace("nl-ext-cms") << "  return constant." << std::endl;
    return lit.getConst<bool>();
  }
  NodeManager* nm = NodeManager::currentNM();
  bool pol = lit.getKind() != kind::NOT;
  Node atom = lit.getKind() == kind::NOT ? lit[0] : lit;

  if (atom.getKind() == EQUAL)
  {
    // x = a is ( x >= a ^ x <= a )
    for (unsigned i = 0; i < 2; i++)
    {
      Node lit2 = nm->mkNode(GEQ, atom[i], atom[1 - i]);
      if (!pol)
      {
        lit2 = lit2.negate();
      }
      lit2 = Rewriter::rewrite(lit2);
      bool success = simpleCheckModelLit(lit2);
      if (success != pol)
      {
        // false != true -> one conjunct of equality is false, we fail
        // true != false -> one disjunct of disequality is true, we succeed
        return success;
      }
    }
    // both checks passed and polarity is true, or both checks failed and
    // polarity is false
    return pol;
  }
  else if (atom.getKind() != GEQ)
  {
    Trace("nl-ext-cms") << "  failed due to unknown literal." << std::endl;
    return false;
  }
  // get the monomial sum
  std::map<Node, Node> msum;
  if (!ArithMSum::getMonomialSumLit(atom, msum))
  {
    Trace("nl-ext-cms") << "  failed due to get msum." << std::endl;
    return false;
  }
  // simple interval analysis
  if (simpleCheckModelMsum(msum, pol))
  {
    return true;
  }
  // can also try reasoning about univariate quadratic equations
  Trace("nl-ext-cms-debug")
      << "* Try univariate quadratic analysis..." << std::endl;
  // monomials that are neither a variable nor the square of a variable
  std::vector<Node> vs_invalid;
  // variables occurring linearly or squared
  std::unordered_set<Node> vs;
  // coefficient of v^2, and of v, for each v in vs
  std::map<Node, Node> v_a;
  std::map<Node, Node> v_b;
  // get coefficients...
  for (std::pair<const Node, Node>& m : msum)
  {
    Node v = m.first;
    if (!v.isNull())
    {
      if (v.isVar())
      {
        v_b[v] = m.second.isNull() ? d_one : m.second;
        vs.insert(v);
      }
      else if (v.getKind() == NONLINEAR_MULT && v.getNumChildren() == 2
               && v[0] == v[1] && v[0].isVar())
      {
        v_a[v[0]] = m.second.isNull() ? d_one : m.second;
        vs.insert(v[0]);
      }
      else
      {
        vs_invalid.push_back(v);
      }
    }
  }
  // solve the valid variables...
  Node invalid_vsum = vs_invalid.empty() ? d_zero
                                         : (vs_invalid.size() == 1
                                                ? vs_invalid[0]
                                                : nm->mkNode(PLUS, vs_invalid));
  // substitution to try
  Subs qsub;
  for (const Node& v : vs)
  {
    // is it a valid variable? (bounded, and not occurring elsewhere)
    std::map<Node, std::pair<Node, Node>>::iterator bit =
        d_check_model_bounds.find(v);
    if (!expr::hasSubterm(invalid_vsum, v) && bit != d_check_model_bounds.end())
    {
      std::map<Node, Node>::iterator it = v_a.find(v);
      if (it != v_a.end())
      {
        Node a = it->second;
        Assert(a.isConst());
        int asgn = a.getConst<Rational>().sgn();
        Assert(asgn != 0);
        // t = a*v*v (+ b*v): the part of the sum that depends on v
        Node t = nm->mkNode(MULT, a, v, v);
        Node b = d_zero;
        it = v_b.find(v);
        if (it != v_b.end())
        {
          b = it->second;
          t = nm->mkNode(PLUS, t, nm->mkNode(MULT, b, v));
        }
        t = Rewriter::rewrite(t);
        Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic "
                                  << t << "..." << std::endl;
        Trace("nl-ext-cms-debug") << "    a = " << a << std::endl;
        Trace("nl-ext-cms-debug") << "    b = " << b << std::endl;
        // find maximal/minimal value on the interval
        Node apex = nm->mkNode(
            DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a));
        apex = Rewriter::rewrite(apex);
        Assert(apex.isConst());
        // for lower, upper, whether we are greater than the apex
        bool cmp[2];
        Node boundn[2];
        for (unsigned r = 0; r < 2; r++)
        {
          boundn[r] = r == 0 ? bit->second.first : bit->second.second;
          Node cmpn = nm->mkNode(GT, boundn[r], apex);
          cmpn = Rewriter::rewrite(cmpn);
          Assert(cmpn.isConst());
          cmp[r] = cmpn.getConst<bool>();
        }
        Trace("nl-ext-cms-debug") << "  apex " << apex << std::endl;
        Trace("nl-ext-cms-debug")
            << "  lower " << boundn[0] << ", cmp: " << cmp[0] << std::endl;
        Trace("nl-ext-cms-debug")
            << "  upper " << boundn[1] << ", cmp: " << cmp[1] << std::endl;
        Assert(boundn[0].getConst<Rational>()
               <= boundn[1].getConst<Rational>());
        // All cases below choose the worst-case value of t on the interval:
        // its minimum when pol is true and its maximum when pol is false.
        Node s;
        qsub.add(v, Node());
        if (cmp[0] != cmp[1])
        {
          Assert(!cmp[0] && cmp[1]);
          // does the sign match the bound?
          if ((asgn == 1) == pol)
          {
            // the apex is the max/min value
            s = apex;
            Trace("nl-ext-cms-debug") << "  ...set to apex." << std::endl;
          }
          else
          {
            // it is one of the endpoints, plug in and compare
            Node tcmpn[2];
            for (unsigned r = 0; r < 2; r++)
            {
              qsub.d_subs.back() = boundn[r];
              Node ts = arithSubstitute(t, qsub);
              tcmpn[r] = Rewriter::rewrite(ts);
            }
            Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]);
            Trace("nl-ext-cms-debug")
                << "  ...both sides of apex, compare " << tcmp << std::endl;
            tcmp = Rewriter::rewrite(tcmp);
            Assert(tcmp.isConst());
            // tcmp is t(lower) < t(upper). For pol = true the worst case is
            // the smaller value, for pol = false the larger one, so the
            // lower endpoint is the worst case exactly when tcmp == pol.
            // (This choice was previously inverted, selecting the best case.)
            unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 0 : 1;
            Trace("nl-ext-cms-debug")
                << "  ...set to " << (bindex_use == 1 ? "upper" : "lower")
                << std::endl;
            s = boundn[bindex_use];
          }
        }
        else
        {
          // both to one side of the apex
          // we figure out which bound to use (lower or upper) based on
          // three factors:
          // (1) whether a's sign is positive,
          // (2) whether we are greater than the apex of the parabola,
          // (3) the polarity of the constraint, i.e. >= or <=.
          // there are 8 cases of these factors, which we test here.
          unsigned bindex_use = (((asgn == 1) == cmp[0]) == pol) ? 0 : 1;
          Trace("nl-ext-cms-debug")
              << "  ...set to " << (bindex_use == 1 ? "upper" : "lower")
              << std::endl;
          s = boundn[bindex_use];
        }
        Assert(!s.isNull());
        qsub.d_subs.back() = s;
        Trace("nl-ext-cms") << "* set bound based on quadratic : " << v
                            << " -> " << s << std::endl;
      }
    }
  }
  if (!qsub.empty())
  {
    Node slit = arithSubstitute(lit, qsub);
    slit = Rewriter::rewrite(slit);
    return simpleCheckModelLit(slit);
  }
  return false;
}
873
874
2090
bool NlModel::simpleCheckModelMsum(const std::map<Node, Node>& msum, bool pol)
875
{
876
2090
  Trace("nl-ext-cms-debug") << "* Try simple interval analysis..." << std::endl;
877
2090
  NodeManager* nm = NodeManager::currentNM();
878
  // map from transcendental functions to whether they were set to lower
879
  // bound
880
2090
  bool simpleSuccess = true;
881
4180
  std::map<Node, bool> set_bound;
882
4180
  std::vector<Node> sum_bound;
883
6198
  for (const std::pair<const Node, Node>& m : msum)
884
  {
885
8270
    Node v = m.first;
886
4162
    if (v.isNull())
887
    {
888
1748
      sum_bound.push_back(m.second.isNull() ? d_one : m.second);
889
    }
890
    else
891
    {
892
2414
      Trace("nl-ext-cms-debug") << "- monomial : " << v << std::endl;
893
      // --- whether we should set a lower bound for this monomial
894
      bool set_lower =
895
2414
          (m.second.isNull() || m.second.getConst<Rational>().sgn() == 1)
896
2414
          == pol;
897
4828
      Trace("nl-ext-cms-debug")
898
2414
          << "set bound to " << (set_lower ? "lower" : "upper") << std::endl;
899
900
      // --- Collect variables and factors in v
901
4774
      std::vector<Node> vars;
902
4774
      std::vector<unsigned> factors;
903
2414
      if (v.getKind() == NONLINEAR_MULT)
904
      {
905
169
        unsigned last_start = 0;
906
507
        for (unsigned i = 0, nchildren = v.getNumChildren(); i < nchildren; i++)
907
        {
908
          // are we at the end?
909
338
          if (i + 1 == nchildren || v[i + 1] != v[i])
910
          {
911
169
            unsigned vfact = 1 + (i - last_start);
912
169
            last_start = (i + 1);
913
169
            vars.push_back(v[i]);
914
169
            factors.push_back(vfact);
915
          }
916
        }
917
      }
918
      else
919
      {
920
2245
        vars.push_back(v);
921
2245
        factors.push_back(1);
922
      }
923
924
      // --- Get the lower and upper bounds and sign information.
925
      // Whether we have an (odd) number of negative factors in vars, apart
926
      // from the variable at choose_index.
927
2414
      bool has_neg_factor = false;
928
2414
      int choose_index = -1;
929
4774
      std::vector<Node> ls;
930
4774
      std::vector<Node> us;
931
      // the relevant sign information for variables with odd exponents:
932
      //   1: both signs of the interval of this variable are positive,
933
      //  -1: both signs of the interval of this variable are negative.
934
4774
      std::vector<int> signs;
935
2414
      Trace("nl-ext-cms-debug") << "get sign information..." << std::endl;
936
4828
      for (unsigned i = 0, size = vars.size(); i < size; i++)
937
      {
938
4828
        Node vc = vars[i];
939
2414
        unsigned vcfact = factors[i];
940
2414
        if (Trace.isOn("nl-ext-cms-debug"))
941
        {
942
          Trace("nl-ext-cms-debug") << "-- " << vc;
943
          if (vcfact > 1)
944
          {
945
            Trace("nl-ext-cms-debug") << "^" << vcfact;
946
          }
947
          Trace("nl-ext-cms-debug") << " ";
948
        }
949
        std::map<Node, std::pair<Node, Node>>::iterator bit =
950
2414
            d_check_model_bounds.find(vc);
951
        // if there is a model bound for this term
952
2414
        if (bit != d_check_model_bounds.end())
953
        {
954
4828
          Node l = bit->second.first;
955
4828
          Node u = bit->second.second;
956
2414
          ls.push_back(l);
957
2414
          us.push_back(u);
958
2414
          int vsign = 0;
959
2414
          if (vcfact % 2 == 1)
960
          {
961
2245
            vsign = 1;
962
2245
            int lsgn = l.getConst<Rational>().sgn();
963
2245
            int usgn = u.getConst<Rational>().sgn();
964
4490
            Trace("nl-ext-cms-debug")
965
2245
                << "bound_sign(" << lsgn << "," << usgn << ") ";
966
2245
            if (lsgn == -1)
967
            {
968
349
              if (usgn < 1)
969
              {
970
                // must have a negative factor
971
349
                has_neg_factor = !has_neg_factor;
972
349
                vsign = -1;
973
              }
974
              else if (choose_index == -1)
975
              {
976
                // set the choose index to this
977
                choose_index = i;
978
                vsign = 0;
979
              }
980
              else
981
              {
982
                // ambiguous, can't determine the bound
983
                Trace("nl-ext-cms")
984
                    << "  failed due to ambiguious monomial." << std::endl;
985
                return false;
986
              }
987
            }
988
          }
989
2414
          Trace("nl-ext-cms-debug") << " -> " << vsign << std::endl;
990
2414
          signs.push_back(vsign);
991
        }
992
        else
993
        {
994
          Assert(d_check_model_witnesses.find(vc)
995
                 == d_check_model_witnesses.end())
996
              << "No variable should be assigned a witness term if we get "
997
                 "here. "
998
              << vc << " is, though." << std::endl;
999
          Trace("nl-ext-cms-debug") << std::endl;
1000
          Trace("nl-ext-cms")
1001
              << "  failed due to unknown bound for " << vc << std::endl;
1002
          // should either assign a model bound or eliminate the variable
1003
          // via substitution
1004
          Assert(false);
1005
          return false;
1006
        }
1007
      }
1008
      // whether we will try to minimize/maximize (-1/1) the absolute value
1009
2414
      int setAbs = (set_lower == has_neg_factor) ? 1 : -1;
1010
4828
      Trace("nl-ext-cms-debug")
1011
2414
          << "set absolute value to " << (setAbs == 1 ? "maximal" : "minimal")
1012
2414
          << std::endl;
1013
1014
4774
      std::vector<Node> vbs;
1015
2414
      Trace("nl-ext-cms-debug") << "set bounds..." << std::endl;
1016
4774
      for (unsigned i = 0, size = vars.size(); i < size; i++)
1017
      {
1018
4774
        Node vc = vars[i];
1019
2414
        unsigned vcfact = factors[i];
1020
4774
        Node l = ls[i];
1021
4774
        Node u = us[i];
1022
        bool vc_set_lower;
1023
2414
        int vcsign = signs[i];
1024
4828
        Trace("nl-ext-cms-debug")
1025
2414
            << "Bounds for " << vc << " : " << l << ", " << u
1026
2414
            << ", sign : " << vcsign << ", factor : " << vcfact << std::endl;
1027
2414
        if (l == u)
1028
        {
1029
          // by convention, always say it is lower if they are the same
1030
          vc_set_lower = true;
1031
          Trace("nl-ext-cms-debug")
1032
              << "..." << vc << " equal bound, set to lower" << std::endl;
1033
        }
1034
        else
1035
        {
1036
2414
          if (vcfact % 2 == 0)
1037
          {
1038
            // minimize or maximize its absolute value
1039
338
            Rational la = l.getConst<Rational>().abs();
1040
338
            Rational ua = u.getConst<Rational>().abs();
1041
169
            if (la == ua)
1042
            {
1043
              // by convention, always say it is lower if abs are the same
1044
              vc_set_lower = true;
1045
              Trace("nl-ext-cms-debug")
1046
                  << "..." << vc << " equal abs, set to lower" << std::endl;
1047
            }
1048
            else
1049
            {
1050
169
              vc_set_lower = (la > ua) == (setAbs == 1);
1051
            }
1052
          }
1053
2245
          else if (signs[i] == 0)
1054
          {
1055
            // we choose this index to match the overall set_lower
1056
            vc_set_lower = set_lower;
1057
          }
1058
          else
1059
          {
1060
2245
            vc_set_lower = (signs[i] != setAbs);
1061
          }
1062
4828
          Trace("nl-ext-cms-debug")
1063
2414
              << "..." << vc << " set to " << (vc_set_lower ? "lower" : "upper")
1064
2414
              << std::endl;
1065
        }
1066
        // check whether this is a conflicting bound
1067
2414
        std::map<Node, bool>::iterator itsb = set_bound.find(vc);
1068
2414
        if (itsb == set_bound.end())
1069
        {
1070
2344
          set_bound[vc] = vc_set_lower;
1071
        }
1072
70
        else if (itsb->second != vc_set_lower)
1073
        {
1074
108
          Trace("nl-ext-cms")
1075
54
              << "  failed due to conflicting bound for " << vc << std::endl;
1076
54
          return false;
1077
        }
1078
        // must over/under approximate based on vc_set_lower, computed above
1079
4720
        Node vb = vc_set_lower ? l : u;
1080
4835
        for (unsigned i2 = 0; i2 < vcfact; i2++)
1081
        {
1082
2475
          vbs.push_back(vb);
1083
        }
1084
      }
1085
2360
      if (!simpleSuccess)
1086
      {
1087
        break;
1088
      }
1089
4720
      Node vbound = vbs.size() == 1 ? vbs[0] : nm->mkNode(MULT, vbs);
1090
2360
      sum_bound.push_back(ArithMSum::mkCoeffTerm(m.second, vbound));
1091
    }
1092
  }
1093
  // if the exact bound was computed via simple analysis above
1094
  // make the bound
1095
4072
  Node bound;
1096
2036
  if (sum_bound.size() > 1)
1097
  {
1098
1850
    bound = nm->mkNode(kind::PLUS, sum_bound);
1099
  }
1100
186
  else if (sum_bound.size() == 1)
1101
  {
1102
186
    bound = sum_bound[0];
1103
  }
1104
  else
1105
  {
1106
    bound = d_zero;
1107
  }
1108
  // make the comparison
1109
4072
  Node comp = nm->mkNode(kind::GEQ, bound, d_zero);
1110
2036
  if (!pol)
1111
  {
1112
1277
    comp = comp.negate();
1113
  }
1114
2036
  Trace("nl-ext-cms") << "  comparison is : " << comp << std::endl;
1115
2036
  comp = Rewriter::rewrite(comp);
1116
2036
  Assert(comp.isConst());
1117
2036
  Trace("nl-ext-cms") << "  returned : " << comp << std::endl;
1118
2036
  return comp == d_true;
1119
}
1120
1121
12
/**
 * Compute rational bounds l <= sqrt(c) <= u by bisection.
 *
 * @param c a constant rational.
 * @param l set to a lower bound of sqrt(c) on success.
 * @param u set to an upper bound of sqrt(c) on success.
 * @param iter the maximal number of bisection steps; each step halves the
 *   width of the interval [l, u].
 * @return true if bounds were computed; false if c is negative, in which case
 *   sqrt(c) is not a real number and l, u are left unchanged.
 */
bool NlModel::getApproximateSqrt(Node c, Node& l, Node& u, unsigned iter) const
{
  Assert(c.isConst());
  if (c == d_one || c == d_zero)
  {
    l = c;
    u = c;
    return true;
  }
  const Rational& rc = c.getConst<Rational>();
  if (rc.sgn() < 0)
  {
    // no real square root; bisecting would produce meaningless bounds
    return false;
  }

  // for positive c, sqrt(c) lies between c and 1 (in either order)
  Rational rl = rc < Rational(1) ? rc : Rational(1);
  Rational ru = rc < Rational(1) ? Rational(1) : rc;
  const Rational half(1, 2);
  for (unsigned count = 0; count < iter; ++count)
  {
    Rational curr = half * (rl + ru);
    Rational curr_sq = curr * curr;
    if (curr_sq == rc)
    {
      // found the exact square root
      rl = curr;
      ru = curr;
      break;
    }
    // keep the invariant rl^2 < rc < ru^2
    if (curr_sq < rc)
    {
      rl = curr;
    }
    else
    {
      ru = curr;
    }
  }

  NodeManager* nm = NodeManager::currentNM();
  l = nm->mkConst(rl);
  u = nm->mkConst(ru);
  return true;
}
1162
1163
117048
void NlModel::printModelValue(const char* c, Node n, unsigned prec) const
1164
{
1165
117048
  if (Trace.isOn(c))
1166
  {
1167
    Trace(c) << "  " << n << " -> ";
1168
    const Node& aval = d_abstractModelCache.at(n);
1169
    if (aval.isConst())
1170
    {
1171
      printRationalApprox(c, aval, prec);
1172
    }
1173
    else
1174
    {
1175
      Trace(c) << "?";
1176
    }
1177
    Trace(c) << " [actual: ";
1178
    const Node& cval = d_concreteModelCache.at(n);
1179
    if (cval.isConst())
1180
    {
1181
      printRationalApprox(c, cval, prec);
1182
    }
1183
    else
1184
    {
1185
      Trace(c) << "?";
1186
    }
1187
    Trace(c) << " ]" << std::endl;
1188
  }
1189
117048
}
1190
1191
585
void NlModel::getModelValueRepair(
1192
    std::map<Node, Node>& arithModel,
1193
    std::map<Node, std::pair<Node, Node>>& approximations,
1194
    std::map<Node, Node>& witnesses,
1195
    bool witnessToValue)
1196
{
1197
585
  Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl;
1198
  // If we extended the model with entries x -> 0 for unconstrained values,
1199
  // we first update the map to the extended one.
1200
585
  if (d_arithVal.size() > arithModel.size())
1201
  {
1202
6
    arithModel = d_arithVal;
1203
  }
1204
  // Record the approximations we used. This code calls the
1205
  // recordApproximation method of the model, which overrides the model
1206
  // values for variables that we solved for, using techniques specific to
1207
  // this class.
1208
585
  NodeManager* nm = NodeManager::currentNM();
1209
23
  for (const std::pair<const Node, std::pair<Node, Node>>& cb :
1210
585
       d_check_model_bounds)
1211
  {
1212
46
    Node l = cb.second.first;
1213
46
    Node u = cb.second.second;
1214
46
    Node pred;
1215
46
    Node v = cb.first;
1216
23
    if (l != u)
1217
    {
1218
23
      pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v));
1219
23
      Trace("nl-model") << v << " approximated as " << pred << std::endl;
1220
46
      Node witness;
1221
23
      if (witnessToValue)
1222
      {
1223
        // witness is the midpoint
1224
        witness = nm->mkNode(
1225
            MULT, nm->mkConst(Rational(1, 2)), nm->mkNode(PLUS, l, u));
1226
        witness = Rewriter::rewrite(witness);
1227
        Trace("nl-model") << v << " witness is " << witness << std::endl;
1228
      }
1229
23
      approximations[v] = std::pair<Node, Node>(pred, witness);
1230
    }
1231
    else
1232
    {
1233
      // overwrite
1234
      arithModel[v] = l;
1235
      Trace("nl-model") << v << " exact approximation is " << l << std::endl;
1236
    }
1237
  }
1238
592
  for (const auto& vw : d_check_model_witnesses)
1239
  {
1240
7
    Trace("nl-model") << vw.first << " witness is " << vw.second << std::endl;
1241
7
    witnesses.emplace(vw.first, vw.second);
1242
  }
1243
  // Also record the exact values we used. An exact value can be seen as a
1244
  // special kind approximation of the form (witness x. x = exact_value).
1245
  // Notice that the above term gets rewritten such that the choice function
1246
  // is eliminated.
1247
846
  for (size_t i = 0; i < d_substitutions.size(); ++i)
1248
  {
1249
    // overwrite
1250
261
    arithModel[d_substitutions.d_vars[i]] = d_substitutions.d_subs[i];
1251
522
    Trace("nl-model") << d_substitutions.d_vars[i] << " solved is "
1252
261
                      << d_substitutions.d_subs[i] << std::endl;
1253
  }
1254
1255
  // multiplication terms should not be given values; their values are
1256
  // implied by the monomials that they consist of
1257
1170
  std::vector<Node> amErase;
1258
12542
  for (const std::pair<const Node, Node>& am : arithModel)
1259
  {
1260
11957
    if (am.first.getKind() == NONLINEAR_MULT)
1261
    {
1262
2179
      amErase.push_back(am.first);
1263
    }
1264
  }
1265
2764
  for (const Node& ae : amErase)
1266
  {
1267
2179
    arithModel.erase(ae);
1268
  }
1269
585
}
1270
1271
62398
/**
 * Return the value of n in the linear arithmetic model. Constants evaluate to
 * themselves. A term without an entry in d_arithVal is considered
 * unconstrained: it evaluates to zero, and n -> 0 is recorded in d_arithVal so
 * that this assumption becomes part of the overall model.
 */
Node NlModel::getValueInternal(TNode n)
{
  if (n.isConst())
  {
    return n;
  }
  auto found = d_arithVal.find(n);
  if (found == d_arithVal.end())
  {
    // unconstrained in the model: assume zero and remember that assumption
    d_arithVal[n] = d_zero;
    return d_zero;
  }
  AlwaysAssert(found->second.isConst());
  return found->second;
}
1289
1290
765
/**
 * Return whether v has been assigned by the model check. This holds if v has
 * interval bounds, a witness term, or a solved substitution.
 */
bool NlModel::hasAssignment(Node v) const
{
  return d_check_model_bounds.find(v) != d_check_model_bounds.end()
         || d_check_model_witnesses.find(v) != d_check_model_witnesses.end()
         || d_substitutions.contains(v);
}
1302
1303
252553
/**
 * Look up v in the linear arithmetic model (d_arithVal). On success, store its
 * value in val and return true; otherwise return false and leave val
 * untouched.
 */
bool NlModel::hasLinearModelValue(TNode v, Node& val) const
{
  const auto entry = d_arithVal.find(v);
  if (entry == d_arithVal.end())
  {
    return false;
  }
  val = entry->second;
  return true;
}
1313
1314
}  // namespace nl
1315
}  // namespace arith
1316
}  // namespace theory
1317
31125
}  // namespace cvc5