GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/nl/nl_model.cpp Lines: 596 745 80.0 %
Date: 2021-03-22 Branches: 1408 3549 39.7 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file nl_model.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Gereon Kremer, Tim King
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Model object for the non-linear extension class
13
 **/
14
15
#include "theory/arith/nl/nl_model.h"
16
17
#include "expr/node_algorithm.h"
18
#include "options/arith_options.h"
19
#include "options/smt_options.h"
20
#include "options/theory_options.h"
21
#include "theory/arith/arith_msum.h"
22
#include "theory/arith/arith_utilities.h"
23
#include "theory/arith/nl/nl_lemma_utils.h"
24
#include "theory/theory_model.h"
25
#include "theory/rewriter.h"
26
27
using namespace CVC4::kind;
28
29
namespace CVC4 {
30
namespace theory {
31
namespace arith {
32
namespace nl {
33
34
4389
/**
 * Constructs the model object with the "used approximation" flag cleared
 * and caches the constant nodes (true, false, 0, 1, 2) used by this class.
 * The context argument is not read by the constructor body.
 */
NlModel::NlModel(context::Context* c) : d_used_approx(false)
{
  NodeManager* nm = NodeManager::currentNM();
  // Boolean constants
  d_true = nm->mkConst(true);
  d_false = nm->mkConst(false);
  // rational constants
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
  d_two = nm->mkConst(Rational(2));
}
42
43
4386
/** Destructor; the body is empty, no explicit cleanup is performed here. */
NlModel::~NlModel() {}
44
45
2928
/**
 * Resets this object for a new model.
 *
 * Stores the theory model pointer, clears both model-value caches
 * (d_mv[0] and d_mv[1]) and replaces the contents of d_arithVal with the
 * entries of arithModel.
 *
 * @param m the theory model to remember in d_model
 * @param arithModel the (term, value) pairs to copy into d_arithVal
 */
void NlModel::reset(TheoryModel* m, std::map<Node, Node>& arithModel)
{
  d_model = m;
  d_mv[0].clear();
  d_mv[1].clear();
  d_arithVal.clear();
  // process arithModel: copy every entry into our own map
  for (const std::pair<const Node, Node>& entry : arithModel)
  {
    d_arithVal[entry.first] = entry.second;
  }
}
58
59
2959
void NlModel::resetCheck()
60
{
61
2959
  d_used_approx = false;
62
2959
  d_check_model_solved.clear();
63
2959
  d_check_model_bounds.clear();
64
2959
  d_check_model_witnesses.clear();
65
2959
  d_check_model_vars.clear();
66
2959
  d_check_model_subs.clear();
67
2959
}
68
69
356933
/** Returns computeModelValue(n, isConcrete) with isConcrete set to true. */
Node NlModel::computeConcreteModelValue(Node n)
{
  const bool isConcrete = true;
  return computeModelValue(n, isConcrete);
}
73
74
219959
/** Returns computeModelValue(n, isConcrete) with isConcrete set to false. */
Node NlModel::computeAbstractModelValue(Node n)
{
  const bool isConcrete = false;
  return computeModelValue(n, isConcrete);
}
78
79
1792951
/**
 * Computes the model value of n and memoises it.
 *
 * Results are cached in d_mv[0] when isConcrete is true and in d_mv[1]
 * otherwise. The value is determined as follows:
 *  - a constant is its own value;
 *  - in abstract mode, a term known to hasTerm() gets its representative;
 *  - a term without children gets getValueInternal(n), except PI, which is
 *    returned unchanged;
 *  - a term whose kind does not belong to arithmetic, Boolean or builtin
 *    theories is looked up with getValueInternal(n);
 *  - any other term is rebuilt from the values of its children and
 *    rewritten.
 *
 * @param n the term to evaluate
 * @param isConcrete selects concrete (true) or abstract (false) evaluation
 * @return the computed value
 */
Node NlModel::computeModelValue(Node n, bool isConcrete)
{
  const unsigned index = isConcrete ? 0 : 1;
  std::map<Node, Node>::iterator cached = d_mv[index].find(n);
  if (cached != d_mv[index].end())
  {
    // already computed for this mode
    return cached->second;
  }
  Trace("nl-ext-mv-debug") << "computeModelValue " << n << ", index=" << index
                           << std::endl;
  const Kind nk = n.getKind();
  Node ret;
  if (n.isConst())
  {
    ret = n;
  }
  else if (!isConcrete && hasTerm(n))
  {
    // use model value for abstraction
    ret = getRepresentative(n);
  }
  else if (n.getNumChildren() == 0)
  {
    // we are interested in the exact value of PI, which cannot be computed.
    // hence, we return PI itself when asked for the concrete value.
    ret = (nk == PI) ? n : getValueInternal(n);
  }
  else
  {
    // otherwise, compute true value
    const TheoryId ctid = theory::kindToTheoryId(nk);
    const bool evaluatedHere = ctid == THEORY_ARITH || ctid == THEORY_BOOL
                               || ctid == THEORY_BUILTIN;
    if (!evaluatedHere)
    {
      // we directly look up terms not belonging to arithmetic
      ret = getValueInternal(n);
    }
    else
    {
      // rebuild n over the values of its children, then simplify
      std::vector<Node> children;
      if (n.getMetaKind() == metakind::PARAMETERIZED)
      {
        children.push_back(n.getOperator());
      }
      for (const Node& nc : n)
      {
        children.push_back(computeModelValue(nc, isConcrete));
      }
      ret = Rewriter::rewrite(NodeManager::currentNM()->mkNode(nk, children));
    }
  }
  Trace("nl-ext-mv-debug") << "computed " << (index == 0 ? "M" : "M_A") << "["
                           << n << "] = " << ret << std::endl;
  d_mv[index][n] = ret;
  return ret;
}
143
144
157025
/** Returns true iff n is a key of d_arithVal. */
bool NlModel::hasTerm(Node n) const
{
  return d_arithVal.count(n) > 0;
}
148
149
42774
/**
 * Returns the representative of n: n itself if it is a constant, the value
 * stored for n in d_arithVal if there is one (which must be a constant),
 * and otherwise whatever d_model->getRepresentative(n) returns.
 */
Node NlModel::getRepresentative(Node n) const
{
  if (!n.isConst())
  {
    std::map<Node, Node>::const_iterator pos = d_arithVal.find(n);
    if (pos == d_arithVal.end())
    {
      // unknown to the arithmetic model, defer to the theory model
      return d_model->getRepresentative(n);
    }
    AlwaysAssert(pos->second.isConst());
    return pos->second;
  }
  return n;
}
163
164
30338
/**
 * Returns the value of n: n itself if it is a constant, otherwise the
 * (constant) value stored in d_arithVal. If n has no entry, zero is
 * recorded for it in d_arithVal and returned.
 */
Node NlModel::getValueInternal(Node n)
{
  if (!n.isConst())
  {
    std::map<Node, Node>::const_iterator pos = d_arithVal.find(n);
    if (pos == d_arithVal.end())
    {
      // It is unconstrained in the model, return 0. We additionally add it
      // to mapping from the linear solver. This ensures that if the
      // nonlinear solver assumes that n = 0, then this assumption is
      // recorded in the overall model.
      d_arithVal[n] = d_zero;
      return d_zero;
    }
    AlwaysAssert(pos->second.isConst());
    return pos->second;
  }
  return n;
}
183
184
82178
/**
 * Compares the model values of i and j (concrete or abstract, as selected
 * by isConcrete).
 *
 * If both values are constants the result of compareValue() is returned.
 * If only the value of i is a constant the result is 1, if only the value
 * of j is a constant it is -1, and if neither is a constant it is 0.
 */
int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute)
{
  Node ci = computeModelValue(i, isConcrete);
  Node cj = computeModelValue(j, isConcrete);
  const bool iIsConst = ci.isConst();
  const bool jIsConst = cj.isConst();
  if (iIsConst && jIsConst)
  {
    return compareValue(ci, cj, isAbsolute);
  }
  if (iIsConst)
  {
    return 1;
  }
  return jIsConst ? -1 : 0;
}
198
199
93105
/**
 * Compares two constant nodes.
 *
 * Returns 0 if i and j are the same node. Otherwise the rational values
 * (or their absolute values, when isAbsolute is true) are compared: the
 * result is 1 if the value of i is smaller than that of j, -1 if it is
 * larger, and 0 if the absolute values coincide.
 */
int NlModel::compareValue(Node i, Node j, bool isAbsolute) const
{
  Assert(i.isConst() && j.isConst());
  if (i == j)
  {
    return 0;
  }
  if (!isAbsolute)
  {
    return i.getConst<Rational>() < j.getConst<Rational>() ? 1 : -1;
  }
  // compare magnitudes
  const Rational absI = i.getConst<Rational>().abs();
  const Rational absJ = j.getConst<Rational>().abs();
  if (absI == absJ)
  {
    return 0;
  }
  return absI < absJ ? 1 : -1;
}
221
222
222
bool NlModel::checkModel(const std::vector<Node>& assertions,
223
                         unsigned d,
224
                         std::vector<NlLemma>& lemmas)
225
{
226
222
  Trace("nl-ext-cm-debug") << "  solve for equalities..." << std::endl;
227
3112
  for (const Node& atom : assertions)
228
  {
229
    // see if it corresponds to a univariate polynomial equation of degree two
230
2890
    if (atom.getKind() == EQUAL)
231
    {
232
368
      if (!solveEqualitySimple(atom, d, lemmas))
233
      {
234
        // no chance we will satisfy this equality
235
230
        Trace("nl-ext-cm") << "...check-model : failed to solve equality : "
236
115
                           << atom << std::endl;
237
      }
238
    }
239
  }
240
241
  // all remaining variables are constrained to their exact model values
242
444
  Trace("nl-ext-cm-debug") << "  set exact bounds for remaining variables..."
243
222
                           << std::endl;
244
444
  std::unordered_set<TNode, TNodeHashFunction> visited;
245
444
  std::vector<TNode> visit;
246
444
  TNode cur;
247
3112
  for (const Node& a : assertions)
248
  {
249
2890
    visit.push_back(a);
250
11699
    do
251
    {
252
14589
      cur = visit.back();
253
14589
      visit.pop_back();
254
14589
      if (visited.find(cur) == visited.end())
255
      {
256
8842
        visited.insert(cur);
257
8842
        if (cur.getType().isReal() && !cur.isConst())
258
        {
259
2544
          Kind k = cur.getKind();
260
4282
          if (k != MULT && k != PLUS && k != NONLINEAR_MULT
261
3258
              && !isTranscendentalKind(k))
262
          {
263
            // if we have not set an approximate bound for it
264
420
            if (!hasCheckModelAssignment(cur))
265
            {
266
              // set its exact model value in the substitution
267
370
              Node curv = computeConcreteModelValue(cur);
268
185
              if (Trace.isOn("nl-ext-cm"))
269
              {
270
                Trace("nl-ext-cm")
271
                    << "check-model-bound : exact : " << cur << " = ";
272
                printRationalApprox("nl-ext-cm", curv);
273
                Trace("nl-ext-cm") << std::endl;
274
              }
275
185
              bool ret = addCheckModelSubstitution(cur, curv);
276
185
              AlwaysAssert(ret);
277
            }
278
          }
279
        }
280
20541
        for (const Node& cn : cur)
281
        {
282
11699
          visit.push_back(cn);
283
        }
284
      }
285
14589
    } while (!visit.empty());
286
  }
287
288
222
  Trace("nl-ext-cm-debug") << "  check assertions..." << std::endl;
289
444
  std::vector<Node> check_assertions;
290
3112
  for (const Node& a : assertions)
291
  {
292
2890
    if (d_check_model_solved.find(a) == d_check_model_solved.end())
293
    {
294
5274
      Node av = a;
295
      // apply the substitution to a
296
2637
      if (!d_check_model_vars.empty())
297
      {
298
2087
        av = arithSubstitute(av, d_check_model_vars, d_check_model_subs);
299
2087
        av = Rewriter::rewrite(av);
300
      }
301
      // simple check literal
302
2637
      if (!simpleCheckModelLit(av))
303
      {
304
742
        Trace("nl-ext-cm") << "...check-model : assertion failed : " << a
305
371
                           << std::endl;
306
371
        check_assertions.push_back(av);
307
742
        Trace("nl-ext-cm-debug")
308
371
            << "...check-model : failed assertion, value : " << av << std::endl;
309
      }
310
    }
311
  }
312
313
222
  if (!check_assertions.empty())
314
  {
315
185
    Trace("nl-ext-cm") << "...simple check failed." << std::endl;
316
    // TODO (#1450) check model for general case
317
185
    return false;
318
  }
319
37
  Trace("nl-ext-cm") << "...simple check succeeded!" << std::endl;
320
37
  return true;
321
}
322
323
385
/**
 * Adds the substitution v -> s to the check-model substitution.
 *
 * The new substitution is first applied to the right-hand sides of all
 * substitutions recorded so far (rewriting those that change), then v and
 * s are appended to d_check_model_vars / d_check_model_subs.
 *
 * @return false if v already has a substitution, or if the bound test
 *         below rejects s; true otherwise
 */
bool NlModel::addCheckModelSubstitution(TNode v, TNode s)
{
  Trace("nl-ext-model") << "* check model substitution : " << v << " -> " << s
                        << std::endl;
  // should not substitute the same variable twice /
  // should not set exact bound more than once
  const bool alreadySubstituted =
      std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
      != d_check_model_vars.end();
  if (alreadySubstituted)
  {
    Trace("nl-ext-model") << "...ERROR: already has value." << std::endl;
    // this should never happen since substitutions should be applied eagerly
    Assert(false);
    return false;
  }
  // if we previously had an approximate bound, the exact bound should be in its
  // range
  std::map<Node, std::pair<Node, Node>>::iterator boundPos =
      d_check_model_bounds.find(v);
  if (boundPos != d_check_model_bounds.end())
  {
    const Node& lower = boundPos->second.first;
    const Node& upper = boundPos->second.second;
    // NOTE(review): for lower <= upper this disjunction holds for every
    // constant s, so a variable that has a bound is always rejected here.
    // The message below suggests `<` / `>` may have been intended. Confirm
    // before changing: solveEqualitySimple records an equality as solved
    // when it sets a bound, and checkModel skips solved equalities, so
    // accepting an in-range value would leave that equality unchecked.
    if (s.getConst<Rational>() >= lower.getConst<Rational>()
        || s.getConst<Rational>() <= upper.getConst<Rational>())
    {
      Trace("nl-ext-model")
          << "...ERROR: already has bound which is out of range." << std::endl;
      return false;
    }
  }
  Assert(d_check_model_witnesses.find(v) == d_check_model_witnesses.end())
      << "We tried to add a substitution where we already had a witness term."
      << std::endl;
  // apply { v -> s } to the existing substitution
  std::vector<Node> varsTmp;
  varsTmp.push_back(v);
  std::vector<Node> subsTmp;
  subsTmp.push_back(s);
  for (Node& existing : d_check_model_subs)
  {
    Node updated = arithSubstitute(existing, varsTmp, subsTmp);
    if (updated != existing)
    {
      // only rewrite when the substitution had an effect
      updated = Rewriter::rewrite(updated);
    }
    existing = updated;
  }
  d_check_model_vars.push_back(v);
  d_check_model_subs.push_back(s);
  return true;
}
372
373
425
/**
 * Records the interval [l, u] as the check-model bound of v.
 *
 * If l and u are the same node the bound is exact and is added as a
 * substitution instead. Setting a bound for a variable that already has a
 * substitution is an error and yields false.
 *
 * @return the result of addCheckModelSubstitution(v, l) when l == u,
 *         false when v already has an exact value, true otherwise
 */
bool NlModel::addCheckModelBound(TNode v, TNode l, TNode u)
{
  Trace("nl-ext-model") << "* check model bound : " << v << " -> [" << l << " "
                        << u << "]" << std::endl;
  if (l == u)
  {
    // bound is exact, can add as substitution
    return addCheckModelSubstitution(v, l);
  }
  // should not set a bound for a value that is exact
  const bool hasExactValue =
      std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
      != d_check_model_vars.end();
  if (hasExactValue)
  {
    Trace("nl-ext-model")
        << "...ERROR: setting bound for variable that already has exact value."
        << std::endl;
    Assert(false);
    return false;
  }
  Assert(l.isConst());
  Assert(u.isConst());
  Assert(l.getConst<Rational>() <= u.getConst<Rational>());
  std::pair<Node, Node> range(l, u);
  d_check_model_bounds[v] = range;
  if (Trace.isOn("nl-ext-cm"))
  {
    Trace("nl-ext-cm") << "check-model-bound : approximate : ";
    printRationalApprox("nl-ext-cm", l);
    Trace("nl-ext-cm") << " <= " << v << " <= ";
    printRationalApprox("nl-ext-cm", u);
    Trace("nl-ext-cm") << std::endl;
  }
  return true;
}
406
407
/**
 * Records w as the witness term of v in d_check_model_witnesses.
 *
 * @return false (after a failed assertion) if v already has a
 *         substitution in d_check_model_vars, true otherwise
 */
bool NlModel::addCheckModelWitness(TNode v, TNode w)
{
  Trace("nl-ext-model") << "* check model witness : " << v << " -> " << w
                        << std::endl;
  // should not set a witness for a value that is already set
  const bool hasExactValue =
      std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
      != d_check_model_vars.end();
  if (!hasExactValue)
  {
    d_check_model_witnesses.emplace(v, w);
    return true;
  }
  Trace("nl-ext-model") << "...ERROR: setting witness for variable that "
                           "already has a constant value."
                        << std::endl;
  Assert(false);
  return false;
}
424
425
576
/**
 * Returns true iff v has a bound in d_check_model_bounds, a witness in
 * d_check_model_witnesses, or a substitution in d_check_model_vars.
 */
bool NlModel::hasCheckModelAssignment(Node v) const
{
  // checked in this order: bound, witness, exact substitution
  return d_check_model_bounds.find(v) != d_check_model_bounds.end()
         || d_check_model_witnesses.find(v) != d_check_model_witnesses.end()
         || std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
                != d_check_model_vars.end();
}
438
439
271
/** Sets the flag d_used_approx, which usedApproximate() reports. */
void NlModel::setUsedApproximate() { d_used_approx = true; }
440
441
33
/** Returns the current value of the flag d_used_approx. */
bool NlModel::usedApproximate() const { return d_used_approx; }
442
443
440
bool NlModel::solveEqualitySimple(Node eq,
444
                                  unsigned d,
445
                                  std::vector<NlLemma>& lemmas)
446
{
447
880
  Node seq = eq;
448
440
  if (!d_check_model_vars.empty())
449
  {
450
320
    seq = arithSubstitute(eq, d_check_model_vars, d_check_model_subs);
451
320
    seq = Rewriter::rewrite(seq);
452
320
    if (seq.isConst())
453
    {
454
90
      if (seq.getConst<bool>())
455
      {
456
        // already true
457
90
        d_check_model_solved[eq] = Node::null();
458
90
        return true;
459
      }
460
      return false;
461
    }
462
  }
463
350
  Trace("nl-ext-cms") << "simple solve equality " << seq << "..." << std::endl;
464
350
  Assert(seq.getKind() == EQUAL);
465
700
  std::map<Node, Node> msum;
466
350
  if (!ArithMSum::getMonomialSumLit(seq, msum))
467
  {
468
    Trace("nl-ext-cms") << "...fail, could not determine monomial sum."
469
                        << std::endl;
470
    return false;
471
  }
472
350
  bool is_valid = true;
473
  // the variable we will solve a quadratic equation for
474
700
  Node var;
475
700
  Node a = d_zero;
476
700
  Node b = d_zero;
477
700
  Node c = d_zero;
478
350
  NodeManager* nm = NodeManager::currentNM();
479
  // the list of variables that occur as a monomial in msum, and whose value
480
  // is so far unconstrained in the model.
481
700
  std::unordered_set<Node, NodeHashFunction> unc_vars;
482
  // the list of variables that occur as a factor in a monomial, and whose
483
  // value is so far unconstrained in the model.
484
700
  std::unordered_set<Node, NodeHashFunction> unc_vars_factor;
485
991
  for (std::pair<const Node, Node>& m : msum)
486
  {
487
1282
    Node v = m.first;
488
1282
    Node coeff = m.second.isNull() ? d_one : m.second;
489
641
    if (v.isNull())
490
    {
491
133
      c = coeff;
492
    }
493
508
    else if (v.getKind() == NONLINEAR_MULT)
494
    {
495
346
      if (v.getNumChildren() == 2 && v[0].isVar() && v[0] == v[1]
496
277
          && (var.isNull() || var == v[0]))
497
      {
498
        // may solve quadratic
499
40
        a = coeff;
500
40
        var = v[0];
501
      }
502
      else
503
      {
504
76
        is_valid = false;
505
152
        Trace("nl-ext-cms-debug")
506
76
            << "...invalid due to non-linear monomial " << v << std::endl;
507
        // may wish to set an exact bound for a factor and repeat
508
230
        for (const Node& vc : v)
509
        {
510
154
          unc_vars_factor.insert(vc);
511
        }
512
      }
513
    }
514
392
    else if (!v.isVar() || (!var.isNull() && var != v))
515
    {
516
478
      Trace("nl-ext-cms-debug")
517
239
          << "...invalid due to factor " << v << std::endl;
518
      // cannot solve multivariate
519
239
      if (is_valid)
520
      {
521
157
        is_valid = false;
522
        // if b is non-zero, then var is also an unconstrained variable
523
157
        if (b != d_zero)
524
        {
525
47
          unc_vars.insert(var);
526
47
          unc_vars_factor.insert(var);
527
        }
528
      }
529
      // if v is unconstrained, we may turn this equality into a substitution
530
239
      unc_vars.insert(v);
531
239
      unc_vars_factor.insert(v);
532
    }
533
    else
534
    {
535
      // set the variable to solve for
536
153
      b = coeff;
537
153
      var = v;
538
    }
539
  }
540
350
  if (!is_valid)
541
  {
542
    // see if we can solve for a variable?
543
445
    for (const Node& uv : unc_vars)
544
    {
545
263
      Trace("nl-ext-cm-debug") << "check subs var : " << uv << std::endl;
546
      // cannot already have a bound
547
263
      if (uv.isVar() && !hasCheckModelAssignment(uv))
548
      {
549
47
        Node slv;
550
47
        Node veqc;
551
47
        if (ArithMSum::isolate(uv, msum, veqc, slv, EQUAL) != 0)
552
        {
553
47
          Assert(!slv.isNull());
554
          // Currently do not support substitution-with-coefficients.
555
          // We also ensure types are correct here, which avoids substituting
556
          // a term of non-integer type for a variable of integer type.
557
141
          if (veqc.isNull() && !expr::hasSubterm(slv, uv)
558
141
              && slv.getType().isSubtypeOf(uv.getType()))
559
          {
560
94
            Trace("nl-ext-cm")
561
47
                << "check-model-subs : " << uv << " -> " << slv << std::endl;
562
47
            bool ret = addCheckModelSubstitution(uv, slv);
563
47
            if (ret)
564
            {
565
94
              Trace("nl-ext-cms") << "...success, model substitution " << uv
566
47
                                  << " -> " << slv << std::endl;
567
47
              d_check_model_solved[eq] = uv;
568
            }
569
47
            return ret;
570
          }
571
        }
572
      }
573
    }
574
    // see if we can assign a variable to a constant
575
380
    for (const Node& uvf : unc_vars_factor)
576
    {
577
270
      Trace("nl-ext-cm-debug") << "check set var : " << uvf << std::endl;
578
      // cannot already have a bound
579
270
      if (uvf.isVar() && !hasCheckModelAssignment(uvf))
580
      {
581
144
        Node uvfv = computeConcreteModelValue(uvf);
582
72
        if (Trace.isOn("nl-ext-cm"))
583
        {
584
          Trace("nl-ext-cm") << "check-model-bound : exact : " << uvf << " = ";
585
          printRationalApprox("nl-ext-cm", uvfv);
586
          Trace("nl-ext-cm") << std::endl;
587
        }
588
72
        bool ret = addCheckModelSubstitution(uvf, uvfv);
589
        // recurse
590
72
        return ret ? solveEqualitySimple(eq, d, lemmas) : false;
591
      }
592
    }
593
220
    Trace("nl-ext-cms") << "...fail due to constrained invalid terms."
594
110
                        << std::endl;
595
110
    return false;
596
  }
597
121
  else if (var.isNull() || var.getType().isInteger())
598
  {
599
    // cannot solve quadratic equations for integer variables
600
5
    Trace("nl-ext-cms") << "...fail due to variable to solve for." << std::endl;
601
5
    return false;
602
  }
603
604
  // we are linear, it is simple
605
116
  if (a == d_zero)
606
  {
607
79
    if (b == d_zero)
608
    {
609
      Trace("nl-ext-cms") << "...fail due to zero a/b." << std::endl;
610
      Assert(false);
611
      return false;
612
    }
613
158
    Node val = nm->mkConst(-c.getConst<Rational>() / b.getConst<Rational>());
614
79
    if (Trace.isOn("nl-ext-cm"))
615
    {
616
      Trace("nl-ext-cm") << "check-model-bound : exact : " << var << " = ";
617
      printRationalApprox("nl-ext-cm", val);
618
      Trace("nl-ext-cm") << std::endl;
619
    }
620
79
    bool ret = addCheckModelSubstitution(var, val);
621
79
    if (ret)
622
    {
623
79
      Trace("nl-ext-cms") << "...success, solved linear." << std::endl;
624
79
      d_check_model_solved[eq] = var;
625
    }
626
79
    return ret;
627
  }
628
37
  Trace("nl-ext-quad") << "Solve quadratic : " << seq << std::endl;
629
37
  Trace("nl-ext-quad") << "  a : " << a << std::endl;
630
37
  Trace("nl-ext-quad") << "  b : " << b << std::endl;
631
37
  Trace("nl-ext-quad") << "  c : " << c << std::endl;
632
74
  Node two_a = nm->mkNode(MULT, d_two, a);
633
37
  two_a = Rewriter::rewrite(two_a);
634
  Node sqrt_val = nm->mkNode(
635
74
      MINUS, nm->mkNode(MULT, b, b), nm->mkNode(MULT, d_two, two_a, c));
636
37
  sqrt_val = Rewriter::rewrite(sqrt_val);
637
37
  Trace("nl-ext-quad") << "Will approximate sqrt " << sqrt_val << std::endl;
638
37
  Assert(sqrt_val.isConst());
639
  // if it is negative, then we are in conflict
640
37
  if (sqrt_val.getConst<Rational>().sgn() == -1)
641
  {
642
    Node conf = seq.negate();
643
    Trace("nl-ext-lemma") << "NlModel::Lemma : quadratic no root : " << conf
644
                          << std::endl;
645
    lemmas.emplace_back(InferenceId::ARITH_NL_CM_QUADRATIC_EQ, conf);
646
    Trace("nl-ext-cms") << "...fail due to negative discriminant." << std::endl;
647
    return false;
648
  }
649
37
  if (hasCheckModelAssignment(var))
650
  {
651
    Trace("nl-ext-cms") << "...fail due to bounds on variable to solve for."
652
                        << std::endl;
653
    // two quadratic equations for same variable, give up
654
    return false;
655
  }
656
  // approximate the square root of sqrt_val
657
74
  Node l, u;
658
37
  if (!getApproximateSqrt(sqrt_val, l, u, 15 + d))
659
  {
660
    Trace("nl-ext-cms") << "...fail, could not approximate sqrt." << std::endl;
661
    return false;
662
  }
663
37
  d_used_approx = true;
664
74
  Trace("nl-ext-quad") << "...got " << l << " <= sqrt(" << sqrt_val
665
37
                       << ") <= " << u << std::endl;
666
74
  Node negb = nm->mkConst(-b.getConst<Rational>());
667
74
  Node coeffa = nm->mkConst(Rational(1) / two_a.getConst<Rational>());
668
  // two possible bound regions
669
74
  Node bounds[2][2];
670
74
  Node diff_bound[2];
671
74
  Node m_var = computeConcreteModelValue(var);
672
37
  Assert(m_var.isConst());
673
111
  for (unsigned r = 0; r < 2; r++)
674
  {
675
222
    for (unsigned b2 = 0; b2 < 2; b2++)
676
    {
677
296
      Node val = b2 == 0 ? l : u;
678
      // (-b +- approx_sqrt( b^2 - 4ac ))/2a
679
      Node approx = nm->mkNode(
680
296
          MULT, coeffa, nm->mkNode(r == 0 ? MINUS : PLUS, negb, val));
681
148
      approx = Rewriter::rewrite(approx);
682
148
      bounds[r][b2] = approx;
683
148
      Assert(approx.isConst());
684
    }
685
74
    if (bounds[r][0].getConst<Rational>() > bounds[r][1].getConst<Rational>())
686
    {
687
      // ensure bound is (lower, upper)
688
70
      Node tmp = bounds[r][0];
689
35
      bounds[r][0] = bounds[r][1];
690
35
      bounds[r][1] = tmp;
691
    }
692
    Node diff =
693
        nm->mkNode(MINUS,
694
                   m_var,
695
296
                   nm->mkNode(MULT,
696
148
                              nm->mkConst(Rational(1) / Rational(2)),
697
296
                              nm->mkNode(PLUS, bounds[r][0], bounds[r][1])));
698
74
    if (Trace.isOn("nl-ext-cm-debug"))
699
    {
700
      Trace("nl-ext-cm-debug") << "Bound option #" << r << " : ";
701
      printRationalApprox("nl-ext-cm-debug", bounds[r][0]);
702
      Trace("nl-ext-cm-debug") << "...";
703
      printRationalApprox("nl-ext-cm-debug", bounds[r][1]);
704
      Trace("nl-ext-cm-debug") << std::endl;
705
    }
706
74
    diff = Rewriter::rewrite(diff);
707
74
    Assert(diff.isConst());
708
74
    diff = nm->mkConst(diff.getConst<Rational>().abs());
709
74
    diff_bound[r] = diff;
710
74
    if (Trace.isOn("nl-ext-cm-debug"))
711
    {
712
      Trace("nl-ext-cm-debug") << "...diff from model value (";
713
      printRationalApprox("nl-ext-cm-debug", m_var);
714
      Trace("nl-ext-cm-debug") << ") is ";
715
      printRationalApprox("nl-ext-cm-debug", diff_bound[r]);
716
      Trace("nl-ext-cm-debug") << std::endl;
717
    }
718
  }
719
  // take the one that var is closer to in the model
720
74
  Node cmp = nm->mkNode(GEQ, diff_bound[0], diff_bound[1]);
721
37
  cmp = Rewriter::rewrite(cmp);
722
37
  Assert(cmp.isConst());
723
37
  unsigned r_use_index = cmp == d_true ? 1 : 0;
724
37
  if (Trace.isOn("nl-ext-cm"))
725
  {
726
    Trace("nl-ext-cm") << "check-model-bound : approximate (sqrt) : ";
727
    printRationalApprox("nl-ext-cm", bounds[r_use_index][0]);
728
    Trace("nl-ext-cm") << " <= " << var << " <= ";
729
    printRationalApprox("nl-ext-cm", bounds[r_use_index][1]);
730
    Trace("nl-ext-cm") << std::endl;
731
  }
732
  bool ret =
733
37
      addCheckModelBound(var, bounds[r_use_index][0], bounds[r_use_index][1]);
734
37
  if (ret)
735
  {
736
37
    d_check_model_solved[eq] = var;
737
37
    Trace("nl-ext-cms") << "...success, solved quadratic." << std::endl;
738
  }
739
37
  return ret;
740
}
741
742
3604
/**
 * Simple check of whether the literal lit holds under the current
 * check-model bounds.
 *
 * A constant literal returns its Boolean value. An (dis)equality is split
 * into two GEQ literals that are checked recursively. A (negated) GEQ atom
 * is first checked with simpleCheckModelMsum(); if that fails, each
 * bounded variable x that occurs only as a*x*x (+ b*x) is replaced by the
 * apex of the parabola or by one of its bounds, and the resulting literal
 * is checked recursively. Any other atom kind fails.
 *
 * @param lit the literal to check
 * @return true iff the check succeeded
 */
bool NlModel::simpleCheckModelLit(Node lit)
{
  Trace("nl-ext-cms") << "*** Simple check-model lit for " << lit << "..."
                      << std::endl;
  if (lit.isConst())
  {
    Trace("nl-ext-cms") << "  return constant." << std::endl;
    return lit.getConst<bool>();
  }
  NodeManager* nm = NodeManager::currentNM();
  const bool isNegated = lit.getKind() == kind::NOT;
  const bool pol = !isNegated;
  Node atom = isNegated ? lit[0] : lit;

  if (atom.getKind() == EQUAL)
  {
    // x = a is ( x >= a ^ x <= a )
    for (unsigned i = 0; i < 2; i++)
    {
      Node lit2 = nm->mkNode(GEQ, atom[i], atom[1 - i]);
      if (!pol)
      {
        lit2 = lit2.negate();
      }
      lit2 = Rewriter::rewrite(lit2);
      bool success = simpleCheckModelLit(lit2);
      if (success != pol)
      {
        // false != true -> one conjunct of equality is false, we fail
        // true != false -> one disjunct of disequality is true, we succeed
        return success;
      }
    }
    // both checks passed and polarity is true, or both checks failed and
    // polarity is false
    return pol;
  }
  if (atom.getKind() != GEQ)
  {
    Trace("nl-ext-cms") << "  failed due to unknown literal." << std::endl;
    return false;
  }
  // get the monomial sum
  std::map<Node, Node> msum;
  if (!ArithMSum::getMonomialSumLit(atom, msum))
  {
    Trace("nl-ext-cms") << "  failed due to get msum." << std::endl;
    return false;
  }
  // simple interval analysis
  if (simpleCheckModelMsum(msum, pol))
  {
    return true;
  }
  // can also try reasoning about univariate quadratic equations
  Trace("nl-ext-cms-debug")
      << "* Try univariate quadratic analysis..." << std::endl;
  // monomials that are neither a variable nor the square of a variable
  std::vector<Node> vs_invalid;
  // variables occurring linearly or squared
  std::unordered_set<Node, NodeHashFunction> vs;
  // coefficient of x*x, indexed by x
  std::map<Node, Node> v_a;
  // coefficient of x, indexed by x
  std::map<Node, Node> v_b;
  // get coefficients...
  for (const std::pair<const Node, Node>& m : msum)
  {
    Node v = m.first;
    if (v.isNull())
    {
      continue;
    }
    if (v.isVar())
    {
      v_b[v] = m.second.isNull() ? d_one : m.second;
      vs.insert(v);
    }
    else if (v.getKind() == NONLINEAR_MULT && v.getNumChildren() == 2
             && v[0] == v[1] && v[0].isVar())
    {
      v_a[v[0]] = m.second.isNull() ? d_one : m.second;
      vs.insert(v[0]);
    }
    else
    {
      vs_invalid.push_back(v);
    }
  }
  // solve the valid variables...
  Node invalid_vsum;
  if (vs_invalid.empty())
  {
    invalid_vsum = d_zero;
  }
  else if (vs_invalid.size() == 1)
  {
    invalid_vsum = vs_invalid[0];
  }
  else
  {
    invalid_vsum = nm->mkNode(PLUS, vs_invalid);
  }
  // substitution to try
  std::vector<Node> qvars;
  std::vector<Node> qsubs;
  for (const Node& v : vs)
  {
    // is it a valid variable?
    std::map<Node, std::pair<Node, Node>>::iterator bit =
        d_check_model_bounds.find(v);
    if (expr::hasSubterm(invalid_vsum, v) || bit == d_check_model_bounds.end())
    {
      continue;
    }
    std::map<Node, Node>::iterator ita = v_a.find(v);
    if (ita == v_a.end())
    {
      // no quadratic term for v
      continue;
    }
    Node a = ita->second;
    Assert(a.isConst());
    int asgn = a.getConst<Rational>().sgn();
    Assert(asgn != 0);
    // t is a*v*v (+ b*v)
    Node t = nm->mkNode(MULT, a, v, v);
    Node b = d_zero;
    std::map<Node, Node>::iterator itb = v_b.find(v);
    if (itb != v_b.end())
    {
      b = itb->second;
      t = nm->mkNode(PLUS, t, nm->mkNode(MULT, b, v));
    }
    t = Rewriter::rewrite(t);
    Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic "
                              << t << "..." << std::endl;
    Trace("nl-ext-cms-debug") << "    a = " << a << std::endl;
    Trace("nl-ext-cms-debug") << "    b = " << b << std::endl;
    // find maximal/minimal value on the interval
    Node apex = nm->mkNode(
        DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a));
    apex = Rewriter::rewrite(apex);
    Assert(apex.isConst());
    // for lower, upper, whether we are greater than the apex
    bool cmp[2];
    Node boundn[2];
    boundn[0] = bit->second.first;
    boundn[1] = bit->second.second;
    for (unsigned r = 0; r < 2; r++)
    {
      Node cmpn = nm->mkNode(GT, boundn[r], apex);
      cmpn = Rewriter::rewrite(cmpn);
      Assert(cmpn.isConst());
      cmp[r] = cmpn.getConst<bool>();
    }
    Trace("nl-ext-cms-debug") << "  apex " << apex << std::endl;
    Trace("nl-ext-cms-debug")
        << "  lower " << boundn[0] << ", cmp: " << cmp[0] << std::endl;
    Trace("nl-ext-cms-debug")
        << "  upper " << boundn[1] << ", cmp: " << cmp[1] << std::endl;
    Assert(boundn[0].getConst<Rational>()
           <= boundn[1].getConst<Rational>());
    // the value chosen for v
    Node s;
    qvars.push_back(v);
    if (cmp[0] == cmp[1])
    {
      // both to one side of the apex
      // we figure out which bound to use (lower or upper) based on
      // three factors:
      // (1) whether a's sign is positive,
      // (2) whether we are greater than the apex of the parabola,
      // (3) the polarity of the constraint, i.e. >= or <=.
      // there are 8 cases of these factors, which we test here.
      unsigned bindex_use = (((asgn == 1) == cmp[0]) == pol) ? 0 : 1;
      Trace("nl-ext-cms-debug")
          << "  ...set to " << (bindex_use == 1 ? "upper" : "lower")
          << std::endl;
      s = boundn[bindex_use];
    }
    else
    {
      // the interval contains the apex
      Assert(!cmp[0] && cmp[1]);
      // does the sign match the bound?
      if ((asgn == 1) == pol)
      {
        // the apex is the max/min value
        s = apex;
        Trace("nl-ext-cms-debug") << "  ...set to apex." << std::endl;
      }
      else
      {
        // it is one of the endpoints, plug in and compare
        Node tcmpn[2];
        for (unsigned r = 0; r < 2; r++)
        {
          // temporarily map v to this endpoint
          qsubs.push_back(boundn[r]);
          Node ts = arithSubstitute(t, qvars, qsubs);
          tcmpn[r] = Rewriter::rewrite(ts);
          qsubs.pop_back();
        }
        Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]);
        Trace("nl-ext-cms-debug")
            << "  ...both sides of apex, compare " << tcmp << std::endl;
        tcmp = Rewriter::rewrite(tcmp);
        Assert(tcmp.isConst());
        unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 1 : 0;
        Trace("nl-ext-cms-debug")
            << "  ...set to " << (bindex_use == 1 ? "upper" : "lower")
            << std::endl;
        s = boundn[bindex_use];
      }
    }
    Assert(!s.isNull());
    qsubs.push_back(s);
    Trace("nl-ext-cms") << "* set bound based on quadratic : " << v
                        << " -> " << s << std::endl;
  }
  if (qvars.empty())
  {
    return false;
  }
  Assert(qvars.size() == qsubs.size());
  Node slit = arithSubstitute(lit, qvars, qsubs);
  slit = Rewriter::rewrite(slit);
  return simpleCheckModelLit(slit);
}
949
950
2125
bool NlModel::simpleCheckModelMsum(const std::map<Node, Node>& msum, bool pol)
951
{
952
2125
  Trace("nl-ext-cms-debug") << "* Try simple interval analysis..." << std::endl;
953
2125
  NodeManager* nm = NodeManager::currentNM();
954
  // map from transcendental functions to whether they were set to lower
955
  // bound
956
2125
  bool simpleSuccess = true;
957
4250
  std::map<Node, bool> set_bound;
958
4250
  std::vector<Node> sum_bound;
959
6331
  for (const std::pair<const Node, Node>& m : msum)
960
  {
961
8465
    Node v = m.first;
962
4259
    if (v.isNull())
963
    {
964
1817
      sum_bound.push_back(m.second.isNull() ? d_one : m.second);
965
    }
966
    else
967
    {
968
2442
      Trace("nl-ext-cms-debug") << "- monomial : " << v << std::endl;
969
      // --- whether we should set a lower bound for this monomial
970
      bool set_lower =
971
2442
          (m.second.isNull() || m.second.getConst<Rational>().sgn() == 1)
972
2442
          == pol;
973
4884
      Trace("nl-ext-cms-debug")
974
2442
          << "set bound to " << (set_lower ? "lower" : "upper") << std::endl;
975
976
      // --- Collect variables and factors in v
977
4831
      std::vector<Node> vars;
978
4831
      std::vector<unsigned> factors;
979
2442
      if (v.getKind() == NONLINEAR_MULT)
980
      {
981
252
        unsigned last_start = 0;
982
756
        for (unsigned i = 0, nchildren = v.getNumChildren(); i < nchildren; i++)
983
        {
984
          // are we at the end?
985
504
          if (i + 1 == nchildren || v[i + 1] != v[i])
986
          {
987
252
            unsigned vfact = 1 + (i - last_start);
988
252
            last_start = (i + 1);
989
252
            vars.push_back(v[i]);
990
252
            factors.push_back(vfact);
991
          }
992
        }
993
      }
994
      else
995
      {
996
2190
        vars.push_back(v);
997
2190
        factors.push_back(1);
998
      }
999
1000
      // --- Get the lower and upper bounds and sign information.
1001
      // Whether we have an (odd) number of negative factors in vars, apart
1002
      // from the variable at choose_index.
1003
2442
      bool has_neg_factor = false;
1004
2442
      int choose_index = -1;
1005
4831
      std::vector<Node> ls;
1006
4831
      std::vector<Node> us;
1007
      // the relevant sign information for variables with odd exponents:
1008
      //   1: both signs of the interval of this variable are positive,
1009
      //  -1: both signs of the interval of this variable are negative.
1010
4831
      std::vector<int> signs;
1011
2442
      Trace("nl-ext-cms-debug") << "get sign information..." << std::endl;
1012
4884
      for (unsigned i = 0, size = vars.size(); i < size; i++)
1013
      {
1014
4884
        Node vc = vars[i];
1015
2442
        unsigned vcfact = factors[i];
1016
2442
        if (Trace.isOn("nl-ext-cms-debug"))
1017
        {
1018
          Trace("nl-ext-cms-debug") << "-- " << vc;
1019
          if (vcfact > 1)
1020
          {
1021
            Trace("nl-ext-cms-debug") << "^" << vcfact;
1022
          }
1023
          Trace("nl-ext-cms-debug") << " ";
1024
        }
1025
        std::map<Node, std::pair<Node, Node>>::iterator bit =
1026
2442
            d_check_model_bounds.find(vc);
1027
        // if there is a model bound for this term
1028
2442
        if (bit != d_check_model_bounds.end())
1029
        {
1030
4884
          Node l = bit->second.first;
1031
4884
          Node u = bit->second.second;
1032
2442
          ls.push_back(l);
1033
2442
          us.push_back(u);
1034
2442
          int vsign = 0;
1035
2442
          if (vcfact % 2 == 1)
1036
          {
1037
2190
            vsign = 1;
1038
2190
            int lsgn = l.getConst<Rational>().sgn();
1039
2190
            int usgn = u.getConst<Rational>().sgn();
1040
4380
            Trace("nl-ext-cms-debug")
1041
2190
                << "bound_sign(" << lsgn << "," << usgn << ") ";
1042
2190
            if (lsgn == -1)
1043
            {
1044
349
              if (usgn < 1)
1045
              {
1046
                // must have a negative factor
1047
349
                has_neg_factor = !has_neg_factor;
1048
349
                vsign = -1;
1049
              }
1050
              else if (choose_index == -1)
1051
              {
1052
                // set the choose index to this
1053
                choose_index = i;
1054
                vsign = 0;
1055
              }
1056
              else
1057
              {
1058
                // ambiguous, can't determine the bound
1059
                Trace("nl-ext-cms")
1060
                    << "  failed due to ambiguious monomial." << std::endl;
1061
                return false;
1062
              }
1063
            }
1064
          }
1065
2442
          Trace("nl-ext-cms-debug") << " -> " << vsign << std::endl;
1066
2442
          signs.push_back(vsign);
1067
        }
1068
        else
1069
        {
1070
          Assert(d_check_model_witnesses.find(vc)
1071
                 == d_check_model_witnesses.end())
1072
              << "No variable should be assigned a witness term if we get "
1073
                 "here. "
1074
              << vc << " is, though." << std::endl;
1075
          Trace("nl-ext-cms-debug") << std::endl;
1076
          Trace("nl-ext-cms")
1077
              << "  failed due to unknown bound for " << vc << std::endl;
1078
          // should either assign a model bound or eliminate the variable
1079
          // via substitution
1080
          Assert(false);
1081
          return false;
1082
        }
1083
      }
1084
      // whether we will try to minimize/maximize (-1/1) the absolute value
1085
2442
      int setAbs = (set_lower == has_neg_factor) ? 1 : -1;
1086
4884
      Trace("nl-ext-cms-debug")
1087
2442
          << "set absolute value to " << (setAbs == 1 ? "maximal" : "minimal")
1088
2442
          << std::endl;
1089
1090
4831
      std::vector<Node> vbs;
1091
2442
      Trace("nl-ext-cms-debug") << "set bounds..." << std::endl;
1092
4831
      for (unsigned i = 0, size = vars.size(); i < size; i++)
1093
      {
1094
4831
        Node vc = vars[i];
1095
2442
        unsigned vcfact = factors[i];
1096
4831
        Node l = ls[i];
1097
4831
        Node u = us[i];
1098
        bool vc_set_lower;
1099
2442
        int vcsign = signs[i];
1100
4884
        Trace("nl-ext-cms-debug")
1101
2442
            << "Bounds for " << vc << " : " << l << ", " << u
1102
2442
            << ", sign : " << vcsign << ", factor : " << vcfact << std::endl;
1103
2442
        if (l == u)
1104
        {
1105
          // by convention, always say it is lower if they are the same
1106
          vc_set_lower = true;
1107
          Trace("nl-ext-cms-debug")
1108
              << "..." << vc << " equal bound, set to lower" << std::endl;
1109
        }
1110
        else
1111
        {
1112
2442
          if (vcfact % 2 == 0)
1113
          {
1114
            // minimize or maximize its absolute value
1115
504
            Rational la = l.getConst<Rational>().abs();
1116
504
            Rational ua = u.getConst<Rational>().abs();
1117
252
            if (la == ua)
1118
            {
1119
              // by convention, always say it is lower if abs are the same
1120
              vc_set_lower = true;
1121
              Trace("nl-ext-cms-debug")
1122
                  << "..." << vc << " equal abs, set to lower" << std::endl;
1123
            }
1124
            else
1125
            {
1126
252
              vc_set_lower = (la > ua) == (setAbs == 1);
1127
            }
1128
          }
1129
2190
          else if (signs[i] == 0)
1130
          {
1131
            // we choose this index to match the overall set_lower
1132
            vc_set_lower = set_lower;
1133
          }
1134
          else
1135
          {
1136
2190
            vc_set_lower = (signs[i] != setAbs);
1137
          }
1138
4884
          Trace("nl-ext-cms-debug")
1139
2442
              << "..." << vc << " set to " << (vc_set_lower ? "lower" : "upper")
1140
2442
              << std::endl;
1141
        }
1142
        // check whether this is a conflicting bound
1143
2442
        std::map<Node, bool>::iterator itsb = set_bound.find(vc);
1144
2442
        if (itsb == set_bound.end())
1145
        {
1146
2379
          set_bound[vc] = vc_set_lower;
1147
        }
1148
63
        else if (itsb->second != vc_set_lower)
1149
        {
1150
106
          Trace("nl-ext-cms")
1151
53
              << "  failed due to conflicting bound for " << vc << std::endl;
1152
53
          return false;
1153
        }
1154
        // must over/under approximate based on vc_set_lower, computed above
1155
4778
        Node vb = vc_set_lower ? l : u;
1156
4977
        for (unsigned i2 = 0; i2 < vcfact; i2++)
1157
        {
1158
2588
          vbs.push_back(vb);
1159
        }
1160
      }
1161
2389
      if (!simpleSuccess)
1162
      {
1163
        break;
1164
      }
1165
4778
      Node vbound = vbs.size() == 1 ? vbs[0] : nm->mkNode(MULT, vbs);
1166
2389
      sum_bound.push_back(ArithMSum::mkCoeffTerm(m.second, vbound));
1167
    }
1168
  }
1169
  // if the exact bound was computed via simple analysis above
1170
  // make the bound
1171
4144
  Node bound;
1172
2072
  if (sum_bound.size() > 1)
1173
  {
1174
1917
    bound = nm->mkNode(kind::PLUS, sum_bound);
1175
  }
1176
155
  else if (sum_bound.size() == 1)
1177
  {
1178
155
    bound = sum_bound[0];
1179
  }
1180
  else
1181
  {
1182
    bound = d_zero;
1183
  }
1184
  // make the comparison
1185
4144
  Node comp = nm->mkNode(kind::GEQ, bound, d_zero);
1186
2072
  if (!pol)
1187
  {
1188
1469
    comp = comp.negate();
1189
  }
1190
2072
  Trace("nl-ext-cms") << "  comparison is : " << comp << std::endl;
1191
2072
  comp = Rewriter::rewrite(comp);
1192
2072
  Assert(comp.isConst());
1193
2072
  Trace("nl-ext-cms") << "  returned : " << comp << std::endl;
1194
2072
  return comp == d_true;
1195
}
1196
1197
37
/**
 * Approximate the square root of the constant c by bisection.
 *
 * On success, l and u are set to rational constants that bracket sqrt(c).
 * The search starts from the interval [min(c,1), max(c,1)] and halves it at
 * most iter times; if a midpoint squares exactly to c, both l and u are set
 * to that midpoint.
 *
 * @param c a constant node
 * @param l set to the lower bound on success
 * @param u set to the upper bound on success
 * @param iter the maximum number of bisection steps
 * @return true if l and u were set; false if c is negative, in which case
 * l and u are left untouched.
 */
bool NlModel::getApproximateSqrt(Node c, Node& l, Node& u, unsigned iter) const
{
  Assert(c.isConst());
  if (c == d_one || c == d_zero)
  {
    // the square root is exact
    l = c;
    u = c;
    return true;
  }
  Rational rc = c.getConst<Rational>();
  if (rc.sgn() < 0)
  {
    // A negative constant has no real square root. The bisection below
    // would otherwise produce bounds that do not enclose any root.
    return false;
  }

  // sqrt(c) lies between c and 1, whichever order they are in
  Rational rl = rc < Rational(1) ? rc : Rational(1);
  Rational ru = rc < Rational(1) ? Rational(1) : rc;
  Rational half = Rational(1) / Rational(2);
  for (unsigned count = 0; count < iter; count++)
  {
    Rational curr = half * (rl + ru);
    Rational curr_sq = curr * curr;
    if (curr_sq == rc)
    {
      // found the exact root
      rl = curr;
      ru = curr;
      break;
    }
    else if (curr_sq < rc)
    {
      rl = curr;
    }
    else
    {
      ru = curr;
    }
  }

  NodeManager* nm = NodeManager::currentNM();
  l = nm->mkConst(rl);
  u = nm->mkConst(ru);
  return true;
}
1238
1239
50038
void NlModel::printModelValue(const char* c, Node n, unsigned prec) const
1240
{
1241
50038
  if (Trace.isOn(c))
1242
  {
1243
    Trace(c) << "  " << n << " -> ";
1244
    for (int i = 1; i >= 0; --i)
1245
    {
1246
      std::map<Node, Node>::const_iterator it = d_mv[i].find(n);
1247
      Assert(it != d_mv[i].end());
1248
      if (it->second.isConst())
1249
      {
1250
        printRationalApprox(c, it->second, prec);
1251
      }
1252
      else
1253
      {
1254
        Trace(c) << "?";
1255
      }
1256
      Trace(c) << (i == 1 ? " [actual: " : " ]");
1257
    }
1258
    Trace(c) << std::endl;
1259
  }
1260
50038
}
1261
1262
256
void NlModel::getModelValueRepair(
1263
    std::map<Node, Node>& arithModel,
1264
    std::map<Node, std::pair<Node, Node>>& approximations,
1265
    std::map<Node, Node>& witnesses)
1266
{
1267
256
  Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl;
1268
  // If we extended the model with entries x -> 0 for unconstrained values,
1269
  // we first update the map to the extended one.
1270
256
  if (d_arithVal.size() > arithModel.size())
1271
  {
1272
4
    arithModel = d_arithVal;
1273
  }
1274
  // Record the approximations we used. This code calls the
1275
  // recordApproximation method of the model, which overrides the model
1276
  // values for variables that we solved for, using techniques specific to
1277
  // this class.
1278
256
  NodeManager* nm = NodeManager::currentNM();
1279
35
  for (const std::pair<const Node, std::pair<Node, Node>>& cb :
1280
256
       d_check_model_bounds)
1281
  {
1282
70
    Node l = cb.second.first;
1283
70
    Node u = cb.second.second;
1284
70
    Node pred;
1285
70
    Node v = cb.first;
1286
35
    if (l != u)
1287
    {
1288
35
      pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v));
1289
35
      Trace("nl-model") << v << " approximated as " << pred << std::endl;
1290
70
      Node witness;
1291
72
      if (options::modelWitnessValue())
1292
      {
1293
        // witness is the midpoint
1294
3
        witness = nm->mkNode(
1295
4
            MULT, nm->mkConst(Rational(1, 2)), nm->mkNode(PLUS, l, u));
1296
1
        witness = Rewriter::rewrite(witness);
1297
1
        Trace("nl-model") << v << " witness is " << witness << std::endl;
1298
      }
1299
35
      approximations[v] = std::pair<Node, Node>(pred, witness);
1300
    }
1301
    else
1302
    {
1303
      // overwrite
1304
      arithModel[v] = l;
1305
      Trace("nl-model") << v << " exact approximation is " << l << std::endl;
1306
    }
1307
  }
1308
256
  for (const auto& vw : d_check_model_witnesses)
1309
  {
1310
    Trace("nl-model") << vw.first << " witness is " << vw.second << std::endl;
1311
    witnesses.emplace(vw.first, vw.second);
1312
  }
1313
  // Also record the exact values we used. An exact value can be seen as a
1314
  // special kind approximation of the form (witness x. x = exact_value).
1315
  // Notice that the above term gets rewritten such that the choice function
1316
  // is eliminated.
1317
410
  for (size_t i = 0, num = d_check_model_vars.size(); i < num; i++)
1318
  {
1319
308
    Node v = d_check_model_vars[i];
1320
308
    Node s = d_check_model_subs[i];
1321
    // overwrite
1322
154
    arithModel[v] = s;
1323
154
    Trace("nl-model") << v << " solved is " << s << std::endl;
1324
  }
1325
1326
  // multiplication terms should not be given values; their values are
1327
  // implied by the monomials that they consist of
1328
512
  std::vector<Node> amErase;
1329
6709
  for (const std::pair<const Node, Node>& am : arithModel)
1330
  {
1331
6453
    if (am.first.getKind() == NONLINEAR_MULT)
1332
    {
1333
1283
      amErase.push_back(am.first);
1334
    }
1335
  }
1336
1539
  for (const Node& ae : amErase)
1337
  {
1338
1283
    arithModel.erase(ae);
1339
  }
1340
256
}
1341
1342
}  // namespace nl
1343
}  // namespace arith
1344
}  // namespace theory
1345
26713
}  // namespace CVC4