GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/nl/icp/icp_solver.cpp Lines: 158 200 79.0 %
Date: 2021-05-22 Branches: 249 613 40.6 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Gereon Kremer, Andres Noetzli
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implements an ICP-based solver for nonlinear arithmetic.
14
 */
15
16
#include "theory/arith/nl/icp/icp_solver.h"
17
18
#include <iostream>
19
20
#include "base/check.h"
21
#include "base/output.h"
22
#include "expr/node_algorithm.h"
23
#include "theory/arith/arith_msum.h"
24
#include "theory/arith/inference_manager.h"
25
#include "theory/arith/nl/poly_conversion.h"
26
#include "theory/arith/normal_form.h"
27
#include "theory/rewriter.h"
28
#include "util/poly_util.h"
29
30
namespace cvc5 {
31
namespace theory {
32
namespace arith {
33
namespace nl {
34
namespace icp {
35
36
#ifdef CVC5_POLY_IMP
37
38
namespace {
39
/** A simple wrapper to nicely print an interval assignment. */
40
struct IAWrapper
41
{
42
  const poly::IntervalAssignment& ia;
43
  const VariableMapper& vm;
44
};
45
inline std::ostream& operator<<(std::ostream& os, const IAWrapper& iaw)
46
{
47
  os << "{ ";
48
  bool first = true;
49
  for (const auto& v : iaw.vm.mVarpolyCVC)
50
  {
51
    if (iaw.ia.has(v.first))
52
    {
53
      if (first)
54
      {
55
        first = false;
56
      }
57
      else
58
      {
59
        os << ", ";
60
      }
61
      os << v.first << " -> " << iaw.ia.get(v.first);
62
    }
63
  }
64
  return os << " }";
65
}
66
}  // namespace
67
68
18
std::vector<Node> ICPSolver::collectVariables(const Node& n) const
69
{
70
36
  std::unordered_set<TNode> tmp;
71
18
  expr::getVariables(n, tmp);
72
18
  std::vector<Node> res;
73
38
  for (const auto& t : tmp)
74
  {
75
20
    res.emplace_back(t);
76
  }
77
36
  return res;
78
}
79
80
80
std::vector<Candidate> ICPSolver::constructCandidates(const Node& n)
81
{
82
160
  Node tmp = Rewriter::rewrite(n);
83
80
  if (tmp.isConst())
84
  {
85
    return {};
86
  }
87
160
  auto comp = Comparison::parseNormalForm(tmp).decompose(false);
88
80
  Kind k = std::get<1>(comp);
89
80
  if (k == Kind::DISTINCT)
90
  {
91
54
    return {};
92
  }
93
52
  auto poly = std::get<0>(comp);
94
95
52
  std::vector<Candidate> result;
96
52
  std::unordered_set<TNode> vars;
97
26
  expr::getVariables(n, vars);
98
68
  for (const auto& v : vars)
99
  {
100
42
    Trace("nl-icp") << "\tChecking " << n << " for " << v << std::endl;
101
102
84
    std::map<Node, Node> msum;
103
42
    ArithMSum::getMonomialSum(poly.getNode(), msum);
104
105
84
    Node veq_c;
106
84
    Node val;
107
108
42
    int isolated = ArithMSum::isolate(v, msum, veq_c, val, k);
109
42
    if (isolated == 1)
110
    {
111
10
      poly::Variable lhs = d_mapper(v);
112
10
      poly::SignCondition rel = poly::SignCondition::EQ;
113
10
      switch (k)
114
      {
115
6
        case Kind::LT: rel = poly::SignCondition::LT; break;
116
        case Kind::LEQ: rel = poly::SignCondition::LE; break;
117
4
        case Kind::EQUAL: rel = poly::SignCondition::EQ; break;
118
        case Kind::DISTINCT: rel = poly::SignCondition::NE; break;
119
        case Kind::GT: rel = poly::SignCondition::GT; break;
120
        case Kind::GEQ: rel = poly::SignCondition::GE; break;
121
        default: Assert(false) << "Unexpected kind: " << k;
122
      }
123
20
      poly::Rational rhsmult;
124
20
      poly::Polynomial rhs = as_poly_polynomial(val, d_mapper, rhsmult);
125
10
      rhsmult = poly::Rational(1) / rhsmult;
126
      // only correct up to a constant (denominator is thrown away!)
127
10
      if (!veq_c.isNull())
128
      {
129
        rhsmult = poly_utils::toRational(veq_c.getConst<Rational>());
130
      }
131
20
      Candidate res{lhs, rel, rhs, rhsmult, n, collectVariables(val)};
132
10
      Trace("nl-icp") << "\tAdded " << res << " from " << n << std::endl;
133
10
      result.emplace_back(res);
134
    }
135
32
    else if (isolated == -1)
136
    {
137
8
      poly::Variable lhs = d_mapper(v);
138
8
      poly::SignCondition rel = poly::SignCondition::EQ;
139
8
      switch (k)
140
      {
141
8
        case Kind::LT: rel = poly::SignCondition::GT; break;
142
        case Kind::LEQ: rel = poly::SignCondition::GE; break;
143
        case Kind::EQUAL: rel = poly::SignCondition::EQ; break;
144
        case Kind::DISTINCT: rel = poly::SignCondition::NE; break;
145
        case Kind::GT: rel = poly::SignCondition::LT; break;
146
        case Kind::GEQ: rel = poly::SignCondition::LE; break;
147
        default: Assert(false) << "Unexpected kind: " << k;
148
      }
149
16
      poly::Rational rhsmult;
150
16
      poly::Polynomial rhs = as_poly_polynomial(val, d_mapper, rhsmult);
151
8
      rhsmult = poly::Rational(1) / rhsmult;
152
8
      if (!veq_c.isNull())
153
      {
154
        rhsmult = poly_utils::toRational(veq_c.getConst<Rational>());
155
      }
156
16
      Candidate res{lhs, rel, rhs, rhsmult, n, collectVariables(val)};
157
8
      Trace("nl-icp") << "\tAdded " << res << " from " << n << std::endl;
158
8
      result.emplace_back(res);
159
    }
160
  }
161
26
  return result;
162
}
163
164
522
void ICPSolver::addCandidate(const Node& n)
165
{
166
522
  auto it = d_candidateCache.find(n);
167
522
  if (it != d_candidateCache.end())
168
  {
169
528
    for (const auto& c : it->second)
170
    {
171
86
      d_state.d_candidates.emplace_back(c);
172
    }
173
  }
174
  else
175
  {
176
160
    auto cands = constructCandidates(n);
177
80
    d_candidateCache.emplace(n, cands);
178
98
    for (const auto& c : cands)
179
    {
180
18
      d_state.d_candidates.emplace_back(c);
181
36
      Trace("nl-icp") << "Bumping budget because of the new candidate"
182
18
                      << std::endl;
183
18
      d_budget += d_budgetIncrement;
184
    }
185
  }
186
522
}
187
188
32
void ICPSolver::initOrigins()
189
{
190
92
  for (const auto& vars : d_state.d_bounds.get())
191
  {
192
60
    const Bounds& i = vars.second;
193
120
    Trace("nl-icp") << "Adding initial " << vars.first << " -> " << i
194
60
                    << std::endl;
195
60
    if (!i.lower_origin.isNull())
196
    {
197
48
      Trace("nl-icp") << "\tAdding lower " << i.lower_origin << std::endl;
198
48
      d_state.d_origins.add(vars.first, i.lower_origin, {});
199
    }
200
60
    if (!i.upper_origin.isNull())
201
    {
202
44
      Trace("nl-icp") << "\tAdding upper " << i.upper_origin << std::endl;
203
44
      d_state.d_origins.add(vars.first, i.upper_origin, {});
204
    }
205
  }
206
32
}
207
208
36
/**
 * Runs one pass of every candidate over the current interval assignment.
 *
 * Each candidate consumes one unit of budget. A contraction records the
 * candidate as an origin of its lhs variable. A strong contraction refunds
 * d_budgetIncrement. The first conflict aborts the round and stores the
 * lhs variable's origins in d_state.d_conflict.
 *
 * Returns NOT_CHANGED if nothing contracted (or the budget was already
 * exhausted), CONTRACTED if anything contracted, or CONFLICT.
 */
PropagationResult ICPSolver::doPropagationRound()
{
  if (d_budget <= 0)
  {
    Trace("nl-icp") << "ICP budget exceeded" << std::endl;
    return PropagationResult::NOT_CHANGED;
  }
  d_state.d_conflict.clear();
  Trace("nl-icp") << "Starting propagation with "
                  << IAWrapper{d_state.d_assignment, d_mapper} << std::endl;
  Trace("nl-icp") << "Current budget: " << d_budget << std::endl;

  PropagationResult overall = PropagationResult::NOT_CHANGED;
  for (const auto& candidate : d_state.d_candidates)
  {
    --d_budget;
    const PropagationResult outcome =
        candidate.propagate(d_state.d_assignment, 100);

    if (outcome == PropagationResult::CONFLICT)
    {
      // Record this candidate as an origin, then report all origins of the
      // conflicting variable.
      d_state.d_origins.add(
          d_mapper(candidate.lhs), candidate.origin, candidate.rhsVariables);
      d_state.d_conflict =
          d_state.d_origins.getOrigins(d_mapper(candidate.lhs));
      return PropagationResult::CONFLICT;
    }

    const bool withCurrent = outcome == PropagationResult::CONTRACTED
                             || outcome == PropagationResult::CONTRACTED_STRONGLY;
    const bool withoutCurrent =
        outcome == PropagationResult::CONTRACTED_WITHOUT_CURRENT
        || outcome == PropagationResult::CONTRACTED_STRONGLY_WITHOUT_CURRENT;
    if (withCurrent)
    {
      d_state.d_origins.add(
          d_mapper(candidate.lhs), candidate.origin, candidate.rhsVariables);
      overall = PropagationResult::CONTRACTED;
    }
    else if (withoutCurrent)
    {
      // The *_WITHOUT_CURRENT results pass an explicit `false` to add().
      d_state.d_origins.add(d_mapper(candidate.lhs),
                            candidate.origin,
                            candidate.rhsVariables,
                            false);
      overall = PropagationResult::CONTRACTED;
    }

    const bool strong =
        outcome == PropagationResult::CONTRACTED_STRONGLY
        || outcome == PropagationResult::CONTRACTED_STRONGLY_WITHOUT_CURRENT;
    if (strong)
    {
      Trace("nl-icp") << "Bumping budget because of a strong contraction"
                      << std::endl;
      d_budget += d_budgetIncrement;
    }
  }
  return overall;
}
255
256
4
std::vector<Node> ICPSolver::generateLemmas() const
257
{
258
4
  auto nm = NodeManager::currentNM();
259
4
  std::vector<Node> lemmas;
260
261
12
  for (const auto& vars : d_mapper.mVarCVCpoly)
262
  {
263
8
    if (!d_state.d_assignment.has(vars.second)) continue;
264
16
    Node v = vars.first;
265
16
    poly::Interval i = d_state.d_assignment.get(vars.second);
266
8
    if (!is_minus_infinity(get_lower(i)))
267
    {
268
8
      Kind rel = get_lower_open(i) ? Kind::GT : Kind::GEQ;
269
16
      Node c = nm->mkNode(rel, v, value_to_node(get_lower(i), v));
270
8
      if (!d_state.d_origins.isInOrigins(v, c))
271
      {
272
16
        Node premise = nm->mkAnd(d_state.d_origins.getOrigins(v));
273
8
        Trace("nl-icp") << premise << " => " << c << std::endl;
274
16
        Node lemma = Rewriter::rewrite(nm->mkNode(Kind::IMPLIES, premise, c));
275
8
        if (lemma.isConst())
276
        {
277
          Assert(lemma == nm->mkConst<bool>(true));
278
        }
279
        else
280
        {
281
8
          Trace("nl-icp") << "Adding lemma " << lemma << std::endl;
282
8
          lemmas.emplace_back(lemma);
283
        }
284
      }
285
    }
286
8
    if (!is_plus_infinity(get_upper(i)))
287
    {
288
8
      Kind rel = get_upper_open(i) ? Kind::LT : Kind::LEQ;
289
16
      Node c = nm->mkNode(rel, v, value_to_node(get_upper(i), v));
290
8
      if (!d_state.d_origins.isInOrigins(v, c))
291
      {
292
16
        Node premise = nm->mkAnd(d_state.d_origins.getOrigins(v));
293
8
        Trace("nl-icp") << premise << " => " << c << std::endl;
294
16
        Node lemma = Rewriter::rewrite(nm->mkNode(Kind::IMPLIES, premise, c));
295
8
        if (lemma.isConst())
296
        {
297
          Assert(lemma == nm->mkConst<bool>(true));
298
        }
299
        else
300
        {
301
8
          Trace("nl-icp") << "Adding lemma " << lemma << std::endl;
302
8
          lemmas.emplace_back(lemma);
303
        }
304
      }
305
    }
306
  }
307
4
  return lemmas;
308
}
309
310
32
void ICPSolver::reset(const std::vector<Node>& assertions)
311
{
312
32
  d_state.reset();
313
646
  for (const auto& n : assertions)
314
  {
315
614
    Trace("nl-icp") << "Adding " << n << std::endl;
316
614
    if (n.getKind() != Kind::CONST_BOOLEAN)
317
    {
318
614
      if (!d_state.d_bounds.add(n))
319
      {
320
522
        addCandidate(n);
321
      }
322
    }
323
  }
324
32
}
325
326
32
void ICPSolver::check()
327
{
328
32
  initOrigins();
329
32
  d_state.d_assignment = getBounds(d_mapper, d_state.d_bounds);
330
32
  bool did_progress = false;
331
32
  bool progress = false;
332
36
  do
333
  {
334
36
    switch (doPropagationRound())
335
    {
336
32
      case icp::PropagationResult::NOT_CHANGED: progress = false; break;
337
4
      case icp::PropagationResult::CONTRACTED:
338
      case icp::PropagationResult::CONTRACTED_STRONGLY:
339
      case icp::PropagationResult::CONTRACTED_WITHOUT_CURRENT:
340
      case icp::PropagationResult::CONTRACTED_STRONGLY_WITHOUT_CURRENT:
341
4
        did_progress = true;
342
4
        progress = true;
343
4
        break;
344
      case icp::PropagationResult::CONFLICT:
345
        Trace("nl-icp") << "Found a conflict: " << d_state.d_conflict
346
                        << std::endl;
347
348
        std::vector<Node> mis;
349
        for (const auto& n : d_state.d_conflict)
350
        {
351
          mis.emplace_back(n.negate());
352
        }
353
        d_im.addPendingLemma(NodeManager::currentNM()->mkOr(mis),
354
                             InferenceId::ARITH_NL_ICP_CONFLICT);
355
        did_progress = true;
356
        progress = false;
357
        break;
358
    }
359
  } while (progress);
360
32
  if (did_progress)
361
  {
362
8
    std::vector<Node> lemmas = generateLemmas();
363
20
    for (const auto& l : lemmas)
364
    {
365
16
      d_im.addPendingLemma(l, InferenceId::ARITH_NL_ICP_PROPAGATION);
366
    }
367
  }
368
32
}
369
370
#else /* CVC5_POLY_IMP */
371
372
void ICPSolver::reset(const std::vector<Node>& assertions)
373
{
374
  Unimplemented() << "ICPSolver requires cvc5 to be configured with LibPoly";
375
}
376
377
void ICPSolver::check()
378
{
379
  Unimplemented() << "ICPSolver requires cvc5 to be configured with LibPoly";
380
}
381
382
#endif /* CVC5_POLY_IMP */
383
384
}  // namespace icp
385
}  // namespace nl
386
}  // namespace arith
387
}  // namespace theory
388
28191
}  // namespace cvc5