GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/nl/nl_model.cpp Lines: 597 745 80.1 %
Date: 2021-09-17 Branches: 1389 3539 39.2 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Gereon Kremer, Tim King
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Model object for the non-linear extension class.
14
 */
15
16
#include "theory/arith/nl/nl_model.h"
17
18
#include "expr/node_algorithm.h"
19
#include "options/arith_options.h"
20
#include "options/smt_options.h"
21
#include "options/theory_options.h"
22
#include "theory/arith/arith_msum.h"
23
#include "theory/arith/arith_utilities.h"
24
#include "theory/arith/nl/nl_lemma_utils.h"
25
#include "theory/theory_model.h"
26
#include "theory/rewriter.h"
27
28
using namespace cvc5::kind;
29
30
namespace cvc5 {
31
namespace theory {
32
namespace arith {
33
namespace nl {
34
35
5203
/**
 * Constructs the model object. The approximation flag starts out false and
 * the frequently used Boolean and rational constants are created up front.
 */
NlModel::NlModel() : d_used_approx(false)
{
  // fetch the node manager once and reuse it for all constants
  NodeManager* nm = NodeManager::currentNM();
  d_true = nm->mkConst(true);
  d_false = nm->mkConst(false);
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
  d_two = nm->mkConst(Rational(2));
}
43
44
7067
/** Destructor: nothing to release beyond what the members clean up. */
NlModel::~NlModel() = default;
45
46
4433
/**
 * Resets this object for a new model.
 *
 * Stores the theory model pointer, drops both caches of computed model values
 * (concrete and abstract) and replaces the recorded arithmetic values with
 * the entries of arithModel.
 *
 * @param m the theory model to store in d_model
 * @param arithModel map whose entries are copied into d_arithVal
 */
void NlModel::reset(TheoryModel* m, std::map<Node, Node>& arithModel)
{
  d_model = m;
  d_mv[0].clear();
  d_mv[1].clear();
  d_arithVal.clear();
  // copy the given arithmetic model into our own mapping
  for (const std::pair<const Node, Node>& entry : arithModel)
  {
    d_arithVal[entry.first] = entry.second;
  }
}
59
60
4449
void NlModel::resetCheck()
61
{
62
4449
  d_used_approx = false;
63
4449
  d_check_model_solved.clear();
64
4449
  d_check_model_bounds.clear();
65
4449
  d_check_model_witnesses.clear();
66
4449
  d_check_model_vars.clear();
67
4449
  d_check_model_subs.clear();
68
4449
}
69
70
688050
/** Returns the concrete model value of n, see computeModelValue. */
Node NlModel::computeConcreteModelValue(Node n)
{
  const bool isConcrete = true;
  return computeModelValue(n, isConcrete);
}
74
75
342125
/** Returns the abstract model value of n, see computeModelValue. */
Node NlModel::computeAbstractModelValue(Node n)
{
  const bool isConcrete = false;
  return computeModelValue(n, isConcrete);
}
79
80
3205235
/**
 * Computes (and caches) the model value of n.
 *
 * Results are cached in d_mv[0] for concrete values and d_mv[1] for abstract
 * values. The value is determined as follows:
 * - constants evaluate to themselves;
 * - for abstract values, a term that has an entry in the arithmetic model
 *   evaluates to its representative;
 * - a term without children evaluates to itself if it is PI, and to its
 *   internal value otherwise;
 * - an application whose kind does not belong to arithmetic, Booleans or the
 *   builtin theory is looked up directly;
 * - any other application is rebuilt from the values of its children and
 *   rewritten.
 *
 * @param n the term to evaluate
 * @param isConcrete whether to compute the concrete (true) or abstract value
 * @return the computed value
 */
Node NlModel::computeModelValue(Node n, bool isConcrete)
{
  const unsigned index = isConcrete ? 0 : 1;
  std::map<Node, Node>& cache = d_mv[index];
  std::map<Node, Node>::iterator cached = cache.find(n);
  if (cached != cache.end())
  {
    return cached->second;
  }
  Trace("nl-ext-mv-debug") << "computeModelValue " << n << ", index=" << index
                           << std::endl;
  Node ret;
  const Kind nk = n.getKind();
  if (n.isConst())
  {
    ret = n;
  }
  else if (!isConcrete && hasTerm(n))
  {
    // use model value for abstraction
    ret = getRepresentative(n);
  }
  else if (n.getNumChildren() == 0)
  {
    // The exact value of PI cannot be computed, hence PI itself is returned
    // when its value is requested; other leaves are looked up.
    ret = (nk == PI) ? n : getValueInternal(n);
  }
  else
  {
    const TheoryId ctid = theory::kindToTheoryId(nk);
    const bool evaluateHere = ctid == THEORY_ARITH || ctid == THEORY_BOOL
                              || ctid == THEORY_BUILTIN;
    if (evaluateHere)
    {
      // rebuild the application over the values of its children
      std::vector<Node> children;
      if (n.getMetaKind() == metakind::PARAMETERIZED)
      {
        children.push_back(n.getOperator());
      }
      for (const Node& child : n)
      {
        children.push_back(computeModelValue(child, isConcrete));
      }
      ret = Rewriter::rewrite(NodeManager::currentNM()->mkNode(nk, children));
    }
    else
    {
      // we directly look up terms not belonging to arithmetic
      ret = getValueInternal(n);
    }
  }
  Trace("nl-ext-mv-debug") << "computed " << (index == 0 ? "M" : "M_A") << "["
                           << n << "] = " << ret << std::endl;
  cache[n] = ret;
  return ret;
}
144
145
247334
/** Returns true if n has an entry in the recorded arithmetic model. */
bool NlModel::hasTerm(Node n) const
{
  return d_arithVal.count(n) > 0;
}
149
150
74420
/**
 * Returns the representative of n: n itself if it is a constant, its
 * (constant) value in the recorded arithmetic model if it has one, and
 * otherwise the representative reported by the theory model.
 */
Node NlModel::getRepresentative(Node n) const
{
  if (!n.isConst())
  {
    std::map<Node, Node>::const_iterator known = d_arithVal.find(n);
    if (known == d_arithVal.end())
    {
      // not known to the arithmetic model, ask the theory model
      return d_model->getRepresentative(n);
    }
    AlwaysAssert(known->second.isConst());
    return known->second;
  }
  return n;
}
164
165
54944
/**
 * Returns the value of n: n itself if it is a constant, otherwise its
 * (constant) value in the recorded arithmetic model. If n has no recorded
 * value, zero is recorded for it and returned.
 */
Node NlModel::getValueInternal(Node n)
{
  if (!n.isConst())
  {
    std::map<Node, Node>::const_iterator known = d_arithVal.find(n);
    if (known == d_arithVal.end())
    {
      // n is unconstrained in the model, so its value is zero. The choice is
      // also written back to the mapping from the linear solver, which
      // ensures that if the nonlinear solver assumes n = 0, this assumption
      // is recorded in the overall model.
      d_arithVal[n] = d_zero;
      return d_zero;
    }
    AlwaysAssert(known->second.isConst());
    return known->second;
  }
  return n;
}
184
185
110444
/**
 * Compares the model values of i and j.
 *
 * If both values are constants, the result is that of compareValue.
 * Otherwise the result is 1 if only the value of i is a constant, -1 if only
 * the value of j is a constant, and 0 if neither is.
 *
 * @param i first term
 * @param j second term
 * @param isConcrete whether concrete or abstract model values are compared
 * @param isAbsolute whether absolute values are compared
 */
int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute)
{
  Node ci = computeModelValue(i, isConcrete);
  Node cj = computeModelValue(j, isConcrete);
  const bool iIsConst = ci.isConst();
  const bool jIsConst = cj.isConst();
  if (iIsConst && jIsConst)
  {
    return compareValue(ci, cj, isAbsolute);
  }
  if (iIsConst)
  {
    return 1;
  }
  return jIsConst ? -1 : 0;
}
199
200
125252
/**
 * Compares the constants i and j, optionally by absolute value.
 *
 * Note the sign convention: the result is 1 when (the absolute value of) i is
 * smaller than that of j, -1 when it is larger, and 0 when they are equal.
 *
 * @param i first constant
 * @param j second constant
 * @param isAbsolute whether absolute values are compared
 */
int NlModel::compareValue(Node i, Node j, bool isAbsolute) const
{
  Assert(i.isConst() && j.isConst());
  if (i == j)
  {
    return 0;
  }
  if (!isAbsolute)
  {
    return i.getConst<Rational>() < j.getConst<Rational>() ? 1 : -1;
  }
  // distinct constants may still agree in absolute value
  const Rational absI = i.getConst<Rational>().abs();
  const Rational absJ = j.getConst<Rational>().abs();
  if (absI == absJ)
  {
    return 0;
  }
  return absI < absJ ? 1 : -1;
}
222
223
262
bool NlModel::checkModel(const std::vector<Node>& assertions,
224
                         unsigned d,
225
                         std::vector<NlLemma>& lemmas)
226
{
227
262
  Trace("nl-ext-cm-debug") << "  solve for equalities..." << std::endl;
228
4443
  for (const Node& atom : assertions)
229
  {
230
    // see if it corresponds to a univariate polynomial equation of degree two
231
4181
    if (atom.getKind() == EQUAL)
232
    {
233
546
      if (!solveEqualitySimple(atom, d, lemmas))
234
      {
235
        // no chance we will satisfy this equality
236
342
        Trace("nl-ext-cm") << "...check-model : failed to solve equality : "
237
171
                           << atom << std::endl;
238
      }
239
    }
240
  }
241
242
  // all remaining variables are constrained to their exact model values
243
524
  Trace("nl-ext-cm-debug") << "  set exact bounds for remaining variables..."
244
262
                           << std::endl;
245
524
  std::unordered_set<TNode> visited;
246
524
  std::vector<TNode> visit;
247
524
  TNode cur;
248
4443
  for (const Node& a : assertions)
249
  {
250
4181
    visit.push_back(a);
251
15785
    do
252
    {
253
19966
      cur = visit.back();
254
19966
      visit.pop_back();
255
19966
      if (visited.find(cur) == visited.end())
256
      {
257
11548
        visited.insert(cur);
258
11548
        if (cur.getType().isReal() && !cur.isConst())
259
        {
260
3208
          Kind k = cur.getKind();
261
5419
          if (k != MULT && k != PLUS && k != NONLINEAR_MULT
262
4095
              && !isTranscendentalKind(k))
263
          {
264
            // if we have not set an approximate bound for it
265
594
            if (!hasCheckModelAssignment(cur))
266
            {
267
              // set its exact model value in the substitution
268
614
              Node curv = computeConcreteModelValue(cur);
269
307
              if (Trace.isOn("nl-ext-cm"))
270
              {
271
                Trace("nl-ext-cm")
272
                    << "check-model-bound : exact : " << cur << " = ";
273
                printRationalApprox("nl-ext-cm", curv);
274
                Trace("nl-ext-cm") << std::endl;
275
              }
276
307
              bool ret = addCheckModelSubstitution(cur, curv);
277
307
              AlwaysAssert(ret);
278
            }
279
          }
280
        }
281
27333
        for (const Node& cn : cur)
282
        {
283
15785
          visit.push_back(cn);
284
        }
285
      }
286
19966
    } while (!visit.empty());
287
  }
288
289
262
  Trace("nl-ext-cm-debug") << "  check assertions..." << std::endl;
290
524
  std::vector<Node> check_assertions;
291
4443
  for (const Node& a : assertions)
292
  {
293
4181
    if (d_check_model_solved.find(a) == d_check_model_solved.end())
294
    {
295
7612
      Node av = a;
296
      // apply the substitution to a
297
3806
      if (!d_check_model_vars.empty())
298
      {
299
3239
        av = arithSubstitute(av, d_check_model_vars, d_check_model_subs);
300
3239
        av = Rewriter::rewrite(av);
301
      }
302
      // simple check literal
303
3806
      if (!simpleCheckModelLit(av))
304
      {
305
794
        Trace("nl-ext-cm") << "...check-model : assertion failed : " << a
306
397
                           << std::endl;
307
397
        check_assertions.push_back(av);
308
794
        Trace("nl-ext-cm-debug")
309
397
            << "...check-model : failed assertion, value : " << av << std::endl;
310
      }
311
    }
312
  }
313
314
262
  if (!check_assertions.empty())
315
  {
316
200
    Trace("nl-ext-cm") << "...simple check failed." << std::endl;
317
    // TODO (#1450) check model for general case
318
200
    return false;
319
  }
320
62
  Trace("nl-ext-cm") << "...simple check succeeded!" << std::endl;
321
62
  return true;
322
}
323
324
609
/**
 * Adds the substitution v -> s to the check-model substitution.
 *
 * The new substitution is first applied to the right-hand sides of all
 * substitutions recorded so far (rewriting those that change), and is then
 * appended.
 *
 * @param v the variable to substitute
 * @param s the term substituted for v
 * @return false if v already has an exact value, or if v has an approximate
 * bound and s is not a constant inside that bound; true otherwise
 */
bool NlModel::addCheckModelSubstitution(TNode v, TNode s)
{
  // should not substitute the same variable twice
  Trace("nl-ext-model") << "* check model substitution : " << v << " -> " << s
                        << std::endl;
  // should not set exact bound more than once
  if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
      != d_check_model_vars.end())
  {
    Trace("nl-ext-model") << "...ERROR: already has value." << std::endl;
    // this should never happen since substitutions should be applied eagerly
    Assert(false);
    return false;
  }
  // if we previously had an approximate bound, the exact bound should be in its
  // range
  std::map<Node, std::pair<Node, Node>>::iterator itb =
      d_check_model_bounds.find(v);
  if (itb != d_check_model_bounds.end())
  {
    // The value is acceptable only if lower <= s <= upper. The previous test
    // (s >= lower || s <= upper) held for every value, so any variable with
    // a bound was rejected. A non-constant s cannot be compared against the
    // bound and is conservatively rejected as well.
    const bool inRange =
        s.isConst()
        && s.getConst<Rational>() >= itb->second.first.getConst<Rational>()
        && s.getConst<Rational>() <= itb->second.second.getConst<Rational>();
    if (!inRange)
    {
      Trace("nl-ext-model")
          << "...ERROR: already has bound which is out of range." << std::endl;
      return false;
    }
  }
  Assert(d_check_model_witnesses.find(v) == d_check_model_witnesses.end())
      << "We tried to add a substitution where we already had a witness term."
      << std::endl;
  // apply the new substitution to the existing right-hand sides
  std::vector<Node> varsTmp;
  varsTmp.push_back(v);
  std::vector<Node> subsTmp;
  subsTmp.push_back(s);
  for (Node& ms : d_check_model_subs)
  {
    Node mss = arithSubstitute(ms, varsTmp, subsTmp);
    if (mss != ms)
    {
      ms = Rewriter::rewrite(mss);
    }
  }
  d_check_model_vars.push_back(v);
  d_check_model_subs.push_back(s);
  return true;
}
373
374
393
/**
 * Records the approximate bound l <= v <= u for the check-model run.
 *
 * If l and u are the same node the bound is exact and is added as a
 * substitution instead.
 *
 * @param v the variable being bounded
 * @param l the lower bound
 * @param u the upper bound
 * @return false if v already has an exact value (or if adding the exact
 * substitution failed); true otherwise
 */
bool NlModel::addCheckModelBound(TNode v, TNode l, TNode u)
{
  Trace("nl-ext-model") << "* check model bound : " << v << " -> [" << l << " "
                        << u << "]" << std::endl;
  if (l == u)
  {
    // bound is exact, can add as substitution
    return addCheckModelSubstitution(v, l);
  }
  // should not set a bound for a value that is exact
  const bool hasExactValue =
      std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
      != d_check_model_vars.end();
  if (hasExactValue)
  {
    Trace("nl-ext-model")
        << "...ERROR: setting bound for variable that already has exact value."
        << std::endl;
    Assert(false);
    return false;
  }
  Assert(l.isConst());
  Assert(u.isConst());
  Assert(l.getConst<Rational>() <= u.getConst<Rational>());
  // store (lower, upper), overwriting any earlier bound for v
  std::pair<Node, Node>& slot = d_check_model_bounds[v];
  slot.first = l;
  slot.second = u;
  if (Trace.isOn("nl-ext-cm"))
  {
    Trace("nl-ext-cm") << "check-model-bound : approximate : ";
    printRationalApprox("nl-ext-cm", l);
    Trace("nl-ext-cm") << " <= " << v << " <= ";
    printRationalApprox("nl-ext-cm", u);
    Trace("nl-ext-cm") << std::endl;
  }
  return true;
}
407
408
7
/**
 * Records w as the witness term for v in the check-model run.
 *
 * @param v the variable
 * @param w the witness term for v
 * @return false if v already has an exact value; true otherwise
 */
bool NlModel::addCheckModelWitness(TNode v, TNode w)
{
  Trace("nl-ext-model") << "* check model witness : " << v << " -> " << w
                        << std::endl;
  // should not set a witness for a value that is already set
  const bool hasExactValue =
      std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
      != d_check_model_vars.end();
  if (hasExactValue)
  {
    Trace("nl-ext-model") << "...ERROR: setting witness for variable that "
                             "already has a constant value."
                          << std::endl;
    Assert(false);
    return false;
  }
  d_check_model_witnesses.emplace(v, w);
  return true;
}
425
426
754
/**
 * Returns true if v has been given an approximate bound, a witness, or an
 * exact value (substitution) in the current check-model run.
 */
bool NlModel::hasCheckModelAssignment(Node v) const
{
  return d_check_model_bounds.find(v) != d_check_model_bounds.end()
         || d_check_model_witnesses.find(v) != d_check_model_witnesses.end()
         || std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
                != d_check_model_vars.end();
}
439
440
269
/** Records that an approximation was used. */
void NlModel::setUsedApproximate()
{
  d_used_approx = true;
}
441
442
16
/** Returns whether an approximation has been recorded as used. */
bool NlModel::usedApproximate() const
{
  return d_used_approx;
}
443
444
636
bool NlModel::solveEqualitySimple(Node eq,
445
                                  unsigned d,
446
                                  std::vector<NlLemma>& lemmas)
447
{
448
1272
  Node seq = eq;
449
636
  if (!d_check_model_vars.empty())
450
  {
451
454
    seq = arithSubstitute(eq, d_check_model_vars, d_check_model_subs);
452
454
    seq = Rewriter::rewrite(seq);
453
454
    if (seq.isConst())
454
    {
455
178
      if (seq.getConst<bool>())
456
      {
457
        // already true
458
178
        d_check_model_solved[eq] = Node::null();
459
178
        return true;
460
      }
461
      return false;
462
    }
463
  }
464
458
  Trace("nl-ext-cms") << "simple solve equality " << seq << "..." << std::endl;
465
458
  Assert(seq.getKind() == EQUAL);
466
916
  std::map<Node, Node> msum;
467
458
  if (!ArithMSum::getMonomialSumLit(seq, msum))
468
  {
469
    Trace("nl-ext-cms") << "...fail, could not determine monomial sum."
470
                        << std::endl;
471
    return false;
472
  }
473
458
  bool is_valid = true;
474
  // the variable we will solve a quadratic equation for
475
916
  Node var;
476
916
  Node a = d_zero;
477
916
  Node b = d_zero;
478
916
  Node c = d_zero;
479
458
  NodeManager* nm = NodeManager::currentNM();
480
  // the list of variables that occur as a monomial in msum, and whose value
481
  // is so far unconstrained in the model.
482
916
  std::unordered_set<Node> unc_vars;
483
  // the list of variables that occur as a factor in a monomial, and whose
484
  // value is so far unconstrained in the model.
485
916
  std::unordered_set<Node> unc_vars_factor;
486
1320
  for (std::pair<const Node, Node>& m : msum)
487
  {
488
1724
    Node v = m.first;
489
1724
    Node coeff = m.second.isNull() ? d_one : m.second;
490
862
    if (v.isNull())
491
    {
492
217
      c = coeff;
493
    }
494
645
    else if (v.getKind() == NONLINEAR_MULT)
495
    {
496
387
      if (v.getNumChildren() == 2 && v[0].isVar() && v[0] == v[1]
497
297
          && (var.isNull() || var == v[0]))
498
      {
499
        // may solve quadratic
500
35
        a = coeff;
501
35
        var = v[0];
502
      }
503
      else
504
      {
505
94
        is_valid = false;
506
188
        Trace("nl-ext-cms-debug")
507
94
            << "...invalid due to non-linear monomial " << v << std::endl;
508
        // may wish to set an exact bound for a factor and repeat
509
282
        for (const Node& vc : v)
510
        {
511
188
          unc_vars_factor.insert(vc);
512
        }
513
      }
514
    }
515
516
    else if (!v.isVar() || (!var.isNull() && var != v))
516
    {
517
592
      Trace("nl-ext-cms-debug")
518
296
          << "...invalid due to factor " << v << std::endl;
519
      // cannot solve multivariate
520
296
      if (is_valid)
521
      {
522
216
        is_valid = false;
523
        // if b is non-zero, then var is also an unconstrained variable
524
216
        if (b != d_zero)
525
        {
526
53
          unc_vars.insert(var);
527
53
          unc_vars_factor.insert(var);
528
        }
529
      }
530
      // if v is unconstrained, we may turn this equality into a substitution
531
296
      unc_vars.insert(v);
532
296
      unc_vars_factor.insert(v);
533
    }
534
    else
535
    {
536
      // set the variable to solve for
537
220
      b = coeff;
538
220
      var = v;
539
    }
540
  }
541
458
  if (!is_valid)
542
  {
543
    // see if we can solve for a variable?
544
563
    for (const Node& uv : unc_vars)
545
    {
546
318
      Trace("nl-ext-cm-debug") << "check subs var : " << uv << std::endl;
547
      // cannot already have a bound
548
318
      if (uv.isVar() && !hasCheckModelAssignment(uv))
549
      {
550
61
        Node slv;
551
61
        Node veqc;
552
61
        if (ArithMSum::isolate(uv, msum, veqc, slv, EQUAL) != 0)
553
        {
554
61
          Assert(!slv.isNull());
555
          // Currently do not support substitution-with-coefficients.
556
          // We also ensure types are correct here, which avoids substituting
557
          // a term of non-integer type for a variable of integer type.
558
183
          if (veqc.isNull() && !expr::hasSubterm(slv, uv)
559
183
              && slv.getType().isSubtypeOf(uv.getType()))
560
          {
561
122
            Trace("nl-ext-cm")
562
61
                << "check-model-subs : " << uv << " -> " << slv << std::endl;
563
61
            bool ret = addCheckModelSubstitution(uv, slv);
564
61
            if (ret)
565
            {
566
122
              Trace("nl-ext-cms") << "...success, model substitution " << uv
567
61
                                  << " -> " << slv << std::endl;
568
61
              d_check_model_solved[eq] = uv;
569
            }
570
61
            return ret;
571
          }
572
        }
573
      }
574
    }
575
    // see if we can assign a variable to a constant
576
486
    for (const Node& uvf : unc_vars_factor)
577
    {
578
331
      Trace("nl-ext-cm-debug") << "check set var : " << uvf << std::endl;
579
      // cannot already have a bound
580
331
      if (uvf.isVar() && !hasCheckModelAssignment(uvf))
581
      {
582
180
        Node uvfv = computeConcreteModelValue(uvf);
583
90
        if (Trace.isOn("nl-ext-cm"))
584
        {
585
          Trace("nl-ext-cm") << "check-model-bound : exact : " << uvf << " = ";
586
          printRationalApprox("nl-ext-cm", uvfv);
587
          Trace("nl-ext-cm") << std::endl;
588
        }
589
90
        bool ret = addCheckModelSubstitution(uvf, uvfv);
590
        // recurse
591
90
        return ret ? solveEqualitySimple(eq, d, lemmas) : false;
592
      }
593
    }
594
310
    Trace("nl-ext-cms") << "...fail due to constrained invalid terms."
595
155
                        << std::endl;
596
155
    return false;
597
  }
598
152
  else if (var.isNull() || var.getType().isInteger())
599
  {
600
    // cannot solve quadratic equations for integer variables
601
16
    Trace("nl-ext-cms") << "...fail due to variable to solve for." << std::endl;
602
16
    return false;
603
  }
604
605
  // we are linear, it is simple
606
136
  if (a == d_zero)
607
  {
608
127
    if (b == d_zero)
609
    {
610
      Trace("nl-ext-cms") << "...fail due to zero a/b." << std::endl;
611
      Assert(false);
612
      return false;
613
    }
614
254
    Node val = nm->mkConst(-c.getConst<Rational>() / b.getConst<Rational>());
615
127
    if (Trace.isOn("nl-ext-cm"))
616
    {
617
      Trace("nl-ext-cm") << "check-model-bound : exact : " << var << " = ";
618
      printRationalApprox("nl-ext-cm", val);
619
      Trace("nl-ext-cm") << std::endl;
620
    }
621
127
    bool ret = addCheckModelSubstitution(var, val);
622
127
    if (ret)
623
    {
624
127
      Trace("nl-ext-cms") << "...success, solved linear." << std::endl;
625
127
      d_check_model_solved[eq] = var;
626
    }
627
127
    return ret;
628
  }
629
9
  Trace("nl-ext-quad") << "Solve quadratic : " << seq << std::endl;
630
9
  Trace("nl-ext-quad") << "  a : " << a << std::endl;
631
9
  Trace("nl-ext-quad") << "  b : " << b << std::endl;
632
9
  Trace("nl-ext-quad") << "  c : " << c << std::endl;
633
18
  Node two_a = nm->mkNode(MULT, d_two, a);
634
9
  two_a = Rewriter::rewrite(two_a);
635
  Node sqrt_val = nm->mkNode(
636
18
      MINUS, nm->mkNode(MULT, b, b), nm->mkNode(MULT, d_two, two_a, c));
637
9
  sqrt_val = Rewriter::rewrite(sqrt_val);
638
9
  Trace("nl-ext-quad") << "Will approximate sqrt " << sqrt_val << std::endl;
639
9
  Assert(sqrt_val.isConst());
640
  // if it is negative, then we are in conflict
641
9
  if (sqrt_val.getConst<Rational>().sgn() == -1)
642
  {
643
    Node conf = seq.negate();
644
    Trace("nl-ext-lemma") << "NlModel::Lemma : quadratic no root : " << conf
645
                          << std::endl;
646
    lemmas.emplace_back(InferenceId::ARITH_NL_CM_QUADRATIC_EQ, conf);
647
    Trace("nl-ext-cms") << "...fail due to negative discriminant." << std::endl;
648
    return false;
649
  }
650
9
  if (hasCheckModelAssignment(var))
651
  {
652
    Trace("nl-ext-cms") << "...fail due to bounds on variable to solve for."
653
                        << std::endl;
654
    // two quadratic equations for same variable, give up
655
    return false;
656
  }
657
  // approximate the square root of sqrt_val
658
18
  Node l, u;
659
9
  if (!getApproximateSqrt(sqrt_val, l, u, 15 + d))
660
  {
661
    Trace("nl-ext-cms") << "...fail, could not approximate sqrt." << std::endl;
662
    return false;
663
  }
664
9
  d_used_approx = true;
665
18
  Trace("nl-ext-quad") << "...got " << l << " <= sqrt(" << sqrt_val
666
9
                       << ") <= " << u << std::endl;
667
18
  Node negb = nm->mkConst(-b.getConst<Rational>());
668
18
  Node coeffa = nm->mkConst(Rational(1) / two_a.getConst<Rational>());
669
  // two possible bound regions
670
18
  Node bounds[2][2];
671
18
  Node diff_bound[2];
672
18
  Node m_var = computeConcreteModelValue(var);
673
9
  Assert(m_var.isConst());
674
27
  for (unsigned r = 0; r < 2; r++)
675
  {
676
54
    for (unsigned b2 = 0; b2 < 2; b2++)
677
    {
678
72
      Node val = b2 == 0 ? l : u;
679
      // (-b +- approx_sqrt( b^2 - 4ac ))/2a
680
      Node approx = nm->mkNode(
681
72
          MULT, coeffa, nm->mkNode(r == 0 ? MINUS : PLUS, negb, val));
682
36
      approx = Rewriter::rewrite(approx);
683
36
      bounds[r][b2] = approx;
684
36
      Assert(approx.isConst());
685
    }
686
18
    if (bounds[r][0].getConst<Rational>() > bounds[r][1].getConst<Rational>())
687
    {
688
      // ensure bound is (lower, upper)
689
18
      Node tmp = bounds[r][0];
690
9
      bounds[r][0] = bounds[r][1];
691
9
      bounds[r][1] = tmp;
692
    }
693
    Node diff =
694
        nm->mkNode(MINUS,
695
                   m_var,
696
72
                   nm->mkNode(MULT,
697
36
                              nm->mkConst(Rational(1) / Rational(2)),
698
72
                              nm->mkNode(PLUS, bounds[r][0], bounds[r][1])));
699
18
    if (Trace.isOn("nl-ext-cm-debug"))
700
    {
701
      Trace("nl-ext-cm-debug") << "Bound option #" << r << " : ";
702
      printRationalApprox("nl-ext-cm-debug", bounds[r][0]);
703
      Trace("nl-ext-cm-debug") << "...";
704
      printRationalApprox("nl-ext-cm-debug", bounds[r][1]);
705
      Trace("nl-ext-cm-debug") << std::endl;
706
    }
707
18
    diff = Rewriter::rewrite(diff);
708
18
    Assert(diff.isConst());
709
18
    diff = nm->mkConst(diff.getConst<Rational>().abs());
710
18
    diff_bound[r] = diff;
711
18
    if (Trace.isOn("nl-ext-cm-debug"))
712
    {
713
      Trace("nl-ext-cm-debug") << "...diff from model value (";
714
      printRationalApprox("nl-ext-cm-debug", m_var);
715
      Trace("nl-ext-cm-debug") << ") is ";
716
      printRationalApprox("nl-ext-cm-debug", diff_bound[r]);
717
      Trace("nl-ext-cm-debug") << std::endl;
718
    }
719
  }
720
  // take the one that var is closer to in the model
721
18
  Node cmp = nm->mkNode(GEQ, diff_bound[0], diff_bound[1]);
722
9
  cmp = Rewriter::rewrite(cmp);
723
9
  Assert(cmp.isConst());
724
9
  unsigned r_use_index = cmp == d_true ? 1 : 0;
725
9
  if (Trace.isOn("nl-ext-cm"))
726
  {
727
    Trace("nl-ext-cm") << "check-model-bound : approximate (sqrt) : ";
728
    printRationalApprox("nl-ext-cm", bounds[r_use_index][0]);
729
    Trace("nl-ext-cm") << " <= " << var << " <= ";
730
    printRationalApprox("nl-ext-cm", bounds[r_use_index][1]);
731
    Trace("nl-ext-cm") << std::endl;
732
  }
733
  bool ret =
734
9
      addCheckModelBound(var, bounds[r_use_index][0], bounds[r_use_index][1]);
735
9
  if (ret)
736
  {
737
9
    d_check_model_solved[eq] = var;
738
9
    Trace("nl-ext-cms") << "...success, solved quadratic." << std::endl;
739
  }
740
9
  return ret;
741
}
742
743
4494
/**
 * Simple check of whether the literal lit holds under the current check-model
 * substitution and bounds.
 *
 * Constants evaluate to their Boolean value. An (in)equality over EQUAL is
 * split into two GEQ literals which are checked recursively. A GEQ literal is
 * first checked with simpleCheckModelMsum; if that fails, every bounded
 * variable x that occurs only as a*x*x (+ b*x) is replaced by the apex of the
 * parabola or by one of its bounds, and the resulting literal is checked
 * recursively. Literals of any other kind fail.
 *
 * @param lit the literal to check
 * @return true if the check succeeded
 */
bool NlModel::simpleCheckModelLit(Node lit)
{
  Trace("nl-ext-cms") << "*** Simple check-model lit for " << lit << "..."
                      << std::endl;
  if (lit.isConst())
  {
    Trace("nl-ext-cms") << "  return constant." << std::endl;
    return lit.getConst<bool>();
  }
  NodeManager* nm = NodeManager::currentNM();
  const bool pol = lit.getKind() != kind::NOT;
  Node atom = pol ? lit : lit[0];

  if (atom.getKind() == EQUAL)
  {
    // x = a is ( x >= a ^ x <= a )
    for (unsigned i = 0; i < 2; i++)
    {
      Node half = nm->mkNode(GEQ, atom[i], atom[1 - i]);
      if (!pol)
      {
        half = half.negate();
      }
      half = Rewriter::rewrite(half);
      const bool success = simpleCheckModelLit(half);
      if (success != pol)
      {
        // false != true -> one conjunct of equality is false, we fail
        // true != false -> one disjunct of disequality is true, we succeed
        return success;
      }
    }
    // both checks passed and polarity is true, or both checks failed and
    // polarity is false
    return pol;
  }
  if (atom.getKind() != GEQ)
  {
    Trace("nl-ext-cms") << "  failed due to unknown literal." << std::endl;
    return false;
  }
  // get the monomial sum
  std::map<Node, Node> msum;
  if (!ArithMSum::getMonomialSumLit(atom, msum))
  {
    Trace("nl-ext-cms") << "  failed due to get msum." << std::endl;
    return false;
  }
  // simple interval analysis
  if (simpleCheckModelMsum(msum, pol))
  {
    return true;
  }
  // can also try reasoning about univariate quadratic equations
  Trace("nl-ext-cms-debug")
      << "* Try univariate quadratic analysis..." << std::endl;
  // monomials that are neither a variable nor the square of a variable
  std::vector<Node> otherMonomials;
  // variables occurring as a monomial or as the square in a monomial
  std::unordered_set<Node> candidates;
  // coefficient of v*v and of v, for each candidate v
  std::map<Node, Node> quadCoeff;
  std::map<Node, Node> linCoeff;
  // get coefficients...
  for (std::pair<const Node, Node>& term : msum)
  {
    Node mon = term.first;
    if (mon.isNull())
    {
      continue;
    }
    if (mon.isVar())
    {
      linCoeff[mon] = term.second.isNull() ? d_one : term.second;
      candidates.insert(mon);
    }
    else if (mon.getKind() == NONLINEAR_MULT && mon.getNumChildren() == 2
             && mon[0] == mon[1] && mon[0].isVar())
    {
      quadCoeff[mon[0]] = term.second.isNull() ? d_one : term.second;
      candidates.insert(mon[0]);
    }
    else
    {
      otherMonomials.push_back(mon);
    }
  }
  // solve the valid variables...
  Node invalid_vsum = otherMonomials.empty()
                          ? d_zero
                          : (otherMonomials.size() == 1
                                 ? otherMonomials[0]
                                 : nm->mkNode(PLUS, otherMonomials));
  // substitution to try
  std::vector<Node> qvars;
  std::vector<Node> qsubs;
  for (const Node& v : candidates)
  {
    // is it a valid variable? It must not occur in the remaining monomials
    // and must have an approximate bound.
    std::map<Node, std::pair<Node, Node>>::iterator bit =
        d_check_model_bounds.find(v);
    if (expr::hasSubterm(invalid_vsum, v) || bit == d_check_model_bounds.end())
    {
      continue;
    }
    std::map<Node, Node>::iterator it = quadCoeff.find(v);
    if (it == quadCoeff.end())
    {
      // v does not occur squared
      continue;
    }
    Node a = it->second;
    Assert(a.isConst());
    const int asgn = a.getConst<Rational>().sgn();
    Assert(asgn != 0);
    // the quadratic term a*v*v (+ b*v)
    Node t = nm->mkNode(MULT, a, v, v);
    Node b = d_zero;
    it = linCoeff.find(v);
    if (it != linCoeff.end())
    {
      b = it->second;
      t = nm->mkNode(PLUS, t, nm->mkNode(MULT, b, v));
    }
    t = Rewriter::rewrite(t);
    Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic "
                              << t << "..." << std::endl;
    Trace("nl-ext-cms-debug") << "    a = " << a << std::endl;
    Trace("nl-ext-cms-debug") << "    b = " << b << std::endl;
    // find maximal/minimal value on the interval
    Node apex = Rewriter::rewrite(nm->mkNode(
        DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a)));
    Assert(apex.isConst());
    // for lower, upper, whether we are greater than the apex
    bool cmp[2];
    Node boundn[2];
    for (unsigned r = 0; r < 2; r++)
    {
      boundn[r] = r == 0 ? bit->second.first : bit->second.second;
      Node cmpn = Rewriter::rewrite(nm->mkNode(GT, boundn[r], apex));
      Assert(cmpn.isConst());
      cmp[r] = cmpn.getConst<bool>();
    }
    Trace("nl-ext-cms-debug") << "  apex " << apex << std::endl;
    Trace("nl-ext-cms-debug")
        << "  lower " << boundn[0] << ", cmp: " << cmp[0] << std::endl;
    Trace("nl-ext-cms-debug")
        << "  upper " << boundn[1] << ", cmp: " << cmp[1] << std::endl;
    Assert(boundn[0].getConst<Rational>()
           <= boundn[1].getConst<Rational>());
    // the value chosen for v
    Node s;
    qvars.push_back(v);
    if (cmp[0] == cmp[1])
    {
      // both to one side of the apex
      // we figure out which bound to use (lower or upper) based on
      // three factors:
      // (1) whether a's sign is positive,
      // (2) whether we are greater than the apex of the parabola,
      // (3) the polarity of the constraint, i.e. >= or <=.
      // there are 8 cases of these factors, which we test here.
      const unsigned bindex_use = (((asgn == 1) == cmp[0]) == pol) ? 0 : 1;
      Trace("nl-ext-cms-debug")
          << "  ...set to " << (bindex_use == 1 ? "upper" : "lower")
          << std::endl;
      s = boundn[bindex_use];
    }
    else
    {
      // the interval contains the apex
      Assert(!cmp[0] && cmp[1]);
      // does the sign match the bound?
      if ((asgn == 1) == pol)
      {
        // the apex is the max/min value
        s = apex;
        Trace("nl-ext-cms-debug") << "  ...set to apex." << std::endl;
      }
      else
      {
        // it is one of the endpoints, plug in and compare
        Node tcmpn[2];
        for (unsigned r = 0; r < 2; r++)
        {
          // temporarily extend the substitution with v -> boundn[r]
          qsubs.push_back(boundn[r]);
          Node ts = arithSubstitute(t, qvars, qsubs);
          tcmpn[r] = Rewriter::rewrite(ts);
          qsubs.pop_back();
        }
        Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]);
        Trace("nl-ext-cms-debug")
            << "  ...both sides of apex, compare " << tcmp << std::endl;
        tcmp = Rewriter::rewrite(tcmp);
        Assert(tcmp.isConst());
        const unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 1 : 0;
        Trace("nl-ext-cms-debug")
            << "  ...set to " << (bindex_use == 1 ? "upper" : "lower")
            << std::endl;
        s = boundn[bindex_use];
      }
    }
    Assert(!s.isNull());
    qsubs.push_back(s);
    Trace("nl-ext-cms") << "* set bound based on quadratic : " << v
                        << " -> " << s << std::endl;
  }
  if (qvars.empty())
  {
    return false;
  }
  // check the literal again with the chosen values plugged in
  Assert(qvars.size() == qsubs.size());
  Node slit = arithSubstitute(lit, qvars, qsubs);
  slit = Rewriter::rewrite(slit);
  return simpleCheckModelLit(slit);
}
950
951
2082
bool NlModel::simpleCheckModelMsum(const std::map<Node, Node>& msum, bool pol)
952
{
953
2082
  Trace("nl-ext-cms-debug") << "* Try simple interval analysis..." << std::endl;
954
2082
  NodeManager* nm = NodeManager::currentNM();
955
  // map from transcendental functions to whether they were set to lower
956
  // bound
957
2082
  bool simpleSuccess = true;
958
4164
  std::map<Node, bool> set_bound;
959
4164
  std::vector<Node> sum_bound;
960
6175
  for (const std::pair<const Node, Node>& m : msum)
961
  {
962
8240
    Node v = m.first;
963
4147
    if (v.isNull())
964
    {
965
1741
      sum_bound.push_back(m.second.isNull() ? d_one : m.second);
966
    }
967
    else
968
    {
969
2406
      Trace("nl-ext-cms-debug") << "- monomial : " << v << std::endl;
970
      // --- whether we should set a lower bound for this monomial
971
      bool set_lower =
972
2406
          (m.second.isNull() || m.second.getConst<Rational>().sgn() == 1)
973
2406
          == pol;
974
4812
      Trace("nl-ext-cms-debug")
975
2406
          << "set bound to " << (set_lower ? "lower" : "upper") << std::endl;
976
977
      // --- Collect variables and factors in v
978
4758
      std::vector<Node> vars;
979
4758
      std::vector<unsigned> factors;
980
2406
      if (v.getKind() == NONLINEAR_MULT)
981
      {
982
165
        unsigned last_start = 0;
983
495
        for (unsigned i = 0, nchildren = v.getNumChildren(); i < nchildren; i++)
984
        {
985
          // are we at the end?
986
330
          if (i + 1 == nchildren || v[i + 1] != v[i])
987
          {
988
165
            unsigned vfact = 1 + (i - last_start);
989
165
            last_start = (i + 1);
990
165
            vars.push_back(v[i]);
991
165
            factors.push_back(vfact);
992
          }
993
        }
994
      }
995
      else
996
      {
997
2241
        vars.push_back(v);
998
2241
        factors.push_back(1);
999
      }
1000
1001
      // --- Get the lower and upper bounds and sign information.
1002
      // Whether we have an (odd) number of negative factors in vars, apart
1003
      // from the variable at choose_index.
1004
2406
      bool has_neg_factor = false;
1005
2406
      int choose_index = -1;
1006
4758
      std::vector<Node> ls;
1007
4758
      std::vector<Node> us;
1008
      // the relevant sign information for variables with odd exponents:
1009
      //   1: both signs of the interval of this variable are positive,
1010
      //  -1: both signs of the interval of this variable are negative.
1011
4758
      std::vector<int> signs;
1012
2406
      Trace("nl-ext-cms-debug") << "get sign information..." << std::endl;
1013
4812
      for (unsigned i = 0, size = vars.size(); i < size; i++)
1014
      {
1015
4812
        Node vc = vars[i];
1016
2406
        unsigned vcfact = factors[i];
1017
2406
        if (Trace.isOn("nl-ext-cms-debug"))
1018
        {
1019
          Trace("nl-ext-cms-debug") << "-- " << vc;
1020
          if (vcfact > 1)
1021
          {
1022
            Trace("nl-ext-cms-debug") << "^" << vcfact;
1023
          }
1024
          Trace("nl-ext-cms-debug") << " ";
1025
        }
1026
        std::map<Node, std::pair<Node, Node>>::iterator bit =
1027
2406
            d_check_model_bounds.find(vc);
1028
        // if there is a model bound for this term
1029
2406
        if (bit != d_check_model_bounds.end())
1030
        {
1031
4812
          Node l = bit->second.first;
1032
4812
          Node u = bit->second.second;
1033
2406
          ls.push_back(l);
1034
2406
          us.push_back(u);
1035
2406
          int vsign = 0;
1036
2406
          if (vcfact % 2 == 1)
1037
          {
1038
2241
            vsign = 1;
1039
2241
            int lsgn = l.getConst<Rational>().sgn();
1040
2241
            int usgn = u.getConst<Rational>().sgn();
1041
4482
            Trace("nl-ext-cms-debug")
1042
2241
                << "bound_sign(" << lsgn << "," << usgn << ") ";
1043
2241
            if (lsgn == -1)
1044
            {
1045
345
              if (usgn < 1)
1046
              {
1047
                // must have a negative factor
1048
345
                has_neg_factor = !has_neg_factor;
1049
345
                vsign = -1;
1050
              }
1051
              else if (choose_index == -1)
1052
              {
1053
                // set the choose index to this
1054
                choose_index = i;
1055
                vsign = 0;
1056
              }
1057
              else
1058
              {
1059
                // ambiguous, can't determine the bound
1060
                Trace("nl-ext-cms")
1061
                    << "  failed due to ambiguious monomial." << std::endl;
1062
                return false;
1063
              }
1064
            }
1065
          }
1066
2406
          Trace("nl-ext-cms-debug") << " -> " << vsign << std::endl;
1067
2406
          signs.push_back(vsign);
1068
        }
1069
        else
1070
        {
1071
          Assert(d_check_model_witnesses.find(vc)
1072
                 == d_check_model_witnesses.end())
1073
              << "No variable should be assigned a witness term if we get "
1074
                 "here. "
1075
              << vc << " is, though." << std::endl;
1076
          Trace("nl-ext-cms-debug") << std::endl;
1077
          Trace("nl-ext-cms")
1078
              << "  failed due to unknown bound for " << vc << std::endl;
1079
          // should either assign a model bound or eliminate the variable
1080
          // via substitution
1081
          Assert(false);
1082
          return false;
1083
        }
1084
      }
1085
      // whether we will try to minimize/maximize (-1/1) the absolute value
1086
2406
      int setAbs = (set_lower == has_neg_factor) ? 1 : -1;
1087
4812
      Trace("nl-ext-cms-debug")
1088
2406
          << "set absolute value to " << (setAbs == 1 ? "maximal" : "minimal")
1089
2406
          << std::endl;
1090
1091
4758
      std::vector<Node> vbs;
1092
2406
      Trace("nl-ext-cms-debug") << "set bounds..." << std::endl;
1093
4758
      for (unsigned i = 0, size = vars.size(); i < size; i++)
1094
      {
1095
4758
        Node vc = vars[i];
1096
2406
        unsigned vcfact = factors[i];
1097
4758
        Node l = ls[i];
1098
4758
        Node u = us[i];
1099
        bool vc_set_lower;
1100
2406
        int vcsign = signs[i];
1101
4812
        Trace("nl-ext-cms-debug")
1102
2406
            << "Bounds for " << vc << " : " << l << ", " << u
1103
2406
            << ", sign : " << vcsign << ", factor : " << vcfact << std::endl;
1104
2406
        if (l == u)
1105
        {
1106
          // by convention, always say it is lower if they are the same
1107
          vc_set_lower = true;
1108
          Trace("nl-ext-cms-debug")
1109
              << "..." << vc << " equal bound, set to lower" << std::endl;
1110
        }
1111
        else
1112
        {
1113
2406
          if (vcfact % 2 == 0)
1114
          {
1115
            // minimize or maximize its absolute value
1116
330
            Rational la = l.getConst<Rational>().abs();
1117
330
            Rational ua = u.getConst<Rational>().abs();
1118
165
            if (la == ua)
1119
            {
1120
              // by convention, always say it is lower if abs are the same
1121
              vc_set_lower = true;
1122
              Trace("nl-ext-cms-debug")
1123
                  << "..." << vc << " equal abs, set to lower" << std::endl;
1124
            }
1125
            else
1126
            {
1127
165
              vc_set_lower = (la > ua) == (setAbs == 1);
1128
            }
1129
          }
1130
2241
          else if (signs[i] == 0)
1131
          {
1132
            // we choose this index to match the overall set_lower
1133
            vc_set_lower = set_lower;
1134
          }
1135
          else
1136
          {
1137
2241
            vc_set_lower = (signs[i] != setAbs);
1138
          }
1139
4812
          Trace("nl-ext-cms-debug")
1140
2406
              << "..." << vc << " set to " << (vc_set_lower ? "lower" : "upper")
1141
2406
              << std::endl;
1142
        }
1143
        // check whether this is a conflicting bound
1144
2406
        std::map<Node, bool>::iterator itsb = set_bound.find(vc);
1145
2406
        if (itsb == set_bound.end())
1146
        {
1147
2336
          set_bound[vc] = vc_set_lower;
1148
        }
1149
70
        else if (itsb->second != vc_set_lower)
1150
        {
1151
108
          Trace("nl-ext-cms")
1152
54
              << "  failed due to conflicting bound for " << vc << std::endl;
1153
54
          return false;
1154
        }
1155
        // must over/under approximate based on vc_set_lower, computed above
1156
4704
        Node vb = vc_set_lower ? l : u;
1157
4815
        for (unsigned i2 = 0; i2 < vcfact; i2++)
1158
        {
1159
2463
          vbs.push_back(vb);
1160
        }
1161
      }
1162
2352
      if (!simpleSuccess)
1163
      {
1164
        break;
1165
      }
1166
4704
      Node vbound = vbs.size() == 1 ? vbs[0] : nm->mkNode(MULT, vbs);
1167
2352
      sum_bound.push_back(ArithMSum::mkCoeffTerm(m.second, vbound));
1168
    }
1169
  }
1170
  // if the exact bound was computed via simple analysis above
1171
  // make the bound
1172
4056
  Node bound;
1173
2028
  if (sum_bound.size() > 1)
1174
  {
1175
1843
    bound = nm->mkNode(kind::PLUS, sum_bound);
1176
  }
1177
185
  else if (sum_bound.size() == 1)
1178
  {
1179
185
    bound = sum_bound[0];
1180
  }
1181
  else
1182
  {
1183
    bound = d_zero;
1184
  }
1185
  // make the comparison
1186
4056
  Node comp = nm->mkNode(kind::GEQ, bound, d_zero);
1187
2028
  if (!pol)
1188
  {
1189
1270
    comp = comp.negate();
1190
  }
1191
2028
  Trace("nl-ext-cms") << "  comparison is : " << comp << std::endl;
1192
2028
  comp = Rewriter::rewrite(comp);
1193
2028
  Assert(comp.isConst());
1194
2028
  Trace("nl-ext-cms") << "  returned : " << comp << std::endl;
1195
2028
  return comp == d_true;
1196
}
1197
1198
9
/**
 * Approximate the square root of the constant c by bisection.
 *
 * On success, l and u are set to constants bracketing sqrt(c) and true is
 * returned. The constants d_zero and d_one are their own square roots and are
 * returned exactly. Otherwise we start from the interval [min(c,1), max(c,1)]
 * and halve it at most iter times, stopping early if a midpoint squares
 * exactly to c.
 *
 * If c is negative it has no real square root; in that case we return false
 * and leave l and u untouched.
 */
bool NlModel::getApproximateSqrt(Node c, Node& l, Node& u, unsigned iter) const
{
  Assert(c.isConst());
  if (c == d_one || c == d_zero)
  {
    l = c;
    u = c;
    return true;
  }
  Rational rc = c.getConst<Rational>();
  if (rc.sgn() < 0)
  {
    // The bisection below would collapse towards rc itself, since every
    // squared midpoint is greater than a negative rc. Report failure instead
    // of returning bounds that do not enclose a square root.
    return false;
  }

  // sqrt(rc) lies between rc and 1, whichever order they are in
  const Rational one(1);
  Rational rl = rc < one ? rc : one;
  Rational ru = rc < one ? one : rc;
  const Rational half(1, 2);
  for (unsigned count = 0; count < iter; count++)
  {
    Rational curr = half * (rl + ru);
    Rational curr_sq = curr * curr;
    if (curr_sq == rc)
    {
      // found the exact square root
      rl = curr;
      ru = curr;
      break;
    }
    else if (curr_sq < rc)
    {
      rl = curr;
    }
    else
    {
      ru = curr;
    }
  }

  NodeManager* nm = NodeManager::currentNM();
  l = nm->mkConst(rl);
  u = nm->mkConst(ru);
  return true;
}
1239
1240
103012
void NlModel::printModelValue(const char* c, Node n, unsigned prec) const
1241
{
1242
103012
  if (Trace.isOn(c))
1243
  {
1244
    Trace(c) << "  " << n << " -> ";
1245
    for (int i = 1; i >= 0; --i)
1246
    {
1247
      std::map<Node, Node>::const_iterator it = d_mv[i].find(n);
1248
      Assert(it != d_mv[i].end());
1249
      if (it->second.isConst())
1250
      {
1251
        printRationalApprox(c, it->second, prec);
1252
      }
1253
      else
1254
      {
1255
        Trace(c) << "?";
1256
      }
1257
      Trace(c) << (i == 1 ? " [actual: " : " ]");
1258
    }
1259
    Trace(c) << std::endl;
1260
  }
1261
103012
}
1262
1263
502
void NlModel::getModelValueRepair(
1264
    std::map<Node, Node>& arithModel,
1265
    std::map<Node, std::pair<Node, Node>>& approximations,
1266
    std::map<Node, Node>& witnesses,
1267
    bool witnessToValue)
1268
{
1269
502
  Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl;
1270
  // If we extended the model with entries x -> 0 for unconstrained values,
1271
  // we first update the map to the extended one.
1272
502
  if (d_arithVal.size() > arithModel.size())
1273
  {
1274
4
    arithModel = d_arithVal;
1275
  }
1276
  // Record the approximations we used. This code calls the
1277
  // recordApproximation method of the model, which overrides the model
1278
  // values for variables that we solved for, using techniques specific to
1279
  // this class.
1280
502
  NodeManager* nm = NodeManager::currentNM();
1281
22
  for (const std::pair<const Node, std::pair<Node, Node>>& cb :
1282
502
       d_check_model_bounds)
1283
  {
1284
44
    Node l = cb.second.first;
1285
44
    Node u = cb.second.second;
1286
44
    Node pred;
1287
44
    Node v = cb.first;
1288
22
    if (l != u)
1289
    {
1290
22
      pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v));
1291
22
      Trace("nl-model") << v << " approximated as " << pred << std::endl;
1292
44
      Node witness;
1293
22
      if (witnessToValue)
1294
      {
1295
        // witness is the midpoint
1296
        witness = nm->mkNode(
1297
            MULT, nm->mkConst(Rational(1, 2)), nm->mkNode(PLUS, l, u));
1298
        witness = Rewriter::rewrite(witness);
1299
        Trace("nl-model") << v << " witness is " << witness << std::endl;
1300
      }
1301
22
      approximations[v] = std::pair<Node, Node>(pred, witness);
1302
    }
1303
    else
1304
    {
1305
      // overwrite
1306
      arithModel[v] = l;
1307
      Trace("nl-model") << v << " exact approximation is " << l << std::endl;
1308
    }
1309
  }
1310
509
  for (const auto& vw : d_check_model_witnesses)
1311
  {
1312
7
    Trace("nl-model") << vw.first << " witness is " << vw.second << std::endl;
1313
7
    witnesses.emplace(vw.first, vw.second);
1314
  }
1315
  // Also record the exact values we used. An exact value can be seen as a
1316
  // special kind approximation of the form (witness x. x = exact_value).
1317
  // Notice that the above term gets rewritten such that the choice function
1318
  // is eliminated.
1319
746
  for (size_t i = 0, num = d_check_model_vars.size(); i < num; i++)
1320
  {
1321
488
    Node v = d_check_model_vars[i];
1322
488
    Node s = d_check_model_subs[i];
1323
    // overwrite
1324
244
    arithModel[v] = s;
1325
244
    Trace("nl-model") << v << " solved is " << s << std::endl;
1326
  }
1327
1328
  // multiplication terms should not be given values; their values are
1329
  // implied by the monomials that they consist of
1330
1004
  std::vector<Node> amErase;
1331
10534
  for (const std::pair<const Node, Node>& am : arithModel)
1332
  {
1333
10032
    if (am.first.getKind() == NONLINEAR_MULT)
1334
    {
1335
2020
      amErase.push_back(am.first);
1336
    }
1337
  }
1338
2522
  for (const Node& ae : amErase)
1339
  {
1340
2020
    arithModel.erase(ae);
1341
  }
1342
502
}
1343
1344
}  // namespace nl
1345
}  // namespace arith
1346
}  // namespace theory
1347
29577
}  // namespace cvc5