GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/nl/icp/icp_solver.cpp Lines: 168 200 84.0 %
Date: 2021-09-18 Branches: 267 611 43.7 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Gereon Kremer, Andres Noetzli
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implements an ICP-based solver for nonlinear arithmetic.
14
 */
15
16
#include "theory/arith/nl/icp/icp_solver.h"
17
18
#include <iostream>
19
20
#include "base/check.h"
21
#include "base/output.h"
22
#include "expr/node_algorithm.h"
23
#include "theory/arith/arith_msum.h"
24
#include "theory/arith/inference_manager.h"
25
#include "theory/arith/nl/poly_conversion.h"
26
#include "theory/arith/normal_form.h"
27
#include "theory/rewriter.h"
28
#include "util/poly_util.h"
29
30
namespace cvc5 {
31
namespace theory {
32
namespace arith {
33
namespace nl {
34
namespace icp {
35
36
#ifdef CVC5_POLY_IMP
37
38
namespace {
39
/** A simple wrapper to nicely print an interval assignment. */
40
struct IAWrapper
41
{
42
  const poly::IntervalAssignment& ia;
43
  const VariableMapper& vm;
44
};
45
inline std::ostream& operator<<(std::ostream& os, const IAWrapper& iaw)
46
{
47
  os << "{ ";
48
  bool first = true;
49
  for (const auto& v : iaw.vm.mVarpolyCVC)
50
  {
51
    if (iaw.ia.has(v.first))
52
    {
53
      if (first)
54
      {
55
        first = false;
56
      }
57
      else
58
      {
59
        os << ", ";
60
      }
61
      os << v.first << " -> " << iaw.ia.get(v.first);
62
    }
63
  }
64
  return os << " }";
65
}
66
}  // namespace
67
68
30
std::vector<Node> ICPSolver::collectVariables(const Node& n) const
69
{
70
60
  std::unordered_set<TNode> tmp;
71
30
  expr::getVariables(n, tmp);
72
30
  std::vector<Node> res;
73
62
  for (const auto& t : tmp)
74
  {
75
32
    res.emplace_back(t);
76
  }
77
60
  return res;
78
}
79
80
92
std::vector<Candidate> ICPSolver::constructCandidates(const Node& n)
81
{
82
184
  Node tmp = Rewriter::rewrite(n);
83
92
  if (tmp.isConst())
84
  {
85
    return {};
86
  }
87
184
  auto comp = Comparison::parseNormalForm(tmp).decompose(false);
88
92
  Kind k = std::get<1>(comp);
89
92
  if (k == Kind::DISTINCT)
90
  {
91
54
    return {};
92
  }
93
76
  auto poly = std::get<0>(comp);
94
95
76
  std::vector<Candidate> result;
96
76
  std::unordered_set<TNode> vars;
97
38
  expr::getVariables(n, vars);
98
104
  for (const auto& v : vars)
99
  {
100
66
    Trace("nl-icp") << "\tChecking " << n << " for " << v << std::endl;
101
102
132
    std::map<Node, Node> msum;
103
66
    ArithMSum::getMonomialSum(poly.getNode(), msum);
104
105
132
    Node veq_c;
106
132
    Node val;
107
108
66
    int isolated = ArithMSum::isolate(v, msum, veq_c, val, k);
109
66
    if (isolated == 1)
110
    {
111
20
      poly::Variable lhs = d_mapper(v);
112
20
      poly::SignCondition rel = poly::SignCondition::EQ;
113
20
      switch (k)
114
      {
115
6
        case Kind::LT: rel = poly::SignCondition::LT; break;
116
        case Kind::LEQ: rel = poly::SignCondition::LE; break;
117
14
        case Kind::EQUAL: rel = poly::SignCondition::EQ; break;
118
        case Kind::DISTINCT: rel = poly::SignCondition::NE; break;
119
        case Kind::GT: rel = poly::SignCondition::GT; break;
120
        case Kind::GEQ: rel = poly::SignCondition::GE; break;
121
        default: Assert(false) << "Unexpected kind: " << k;
122
      }
123
40
      poly::Rational rhsmult;
124
40
      poly::Polynomial rhs = as_poly_polynomial(val, d_mapper, rhsmult);
125
20
      rhsmult = poly::Rational(1) / rhsmult;
126
      // only correct up to a constant (denominator is thrown away!)
127
20
      if (!veq_c.isNull())
128
      {
129
        rhsmult = poly_utils::toRational(veq_c.getConst<Rational>());
130
      }
131
40
      Candidate res{lhs, rel, rhs, rhsmult, n, collectVariables(val)};
132
20
      Trace("nl-icp") << "\tAdded " << res << " from " << n << std::endl;
133
20
      result.emplace_back(res);
134
    }
135
46
    else if (isolated == -1)
136
    {
137
10
      poly::Variable lhs = d_mapper(v);
138
10
      poly::SignCondition rel = poly::SignCondition::EQ;
139
10
      switch (k)
140
      {
141
10
        case Kind::LT: rel = poly::SignCondition::GT; break;
142
        case Kind::LEQ: rel = poly::SignCondition::GE; break;
143
        case Kind::EQUAL: rel = poly::SignCondition::EQ; break;
144
        case Kind::DISTINCT: rel = poly::SignCondition::NE; break;
145
        case Kind::GT: rel = poly::SignCondition::LT; break;
146
        case Kind::GEQ: rel = poly::SignCondition::LE; break;
147
        default: Assert(false) << "Unexpected kind: " << k;
148
      }
149
20
      poly::Rational rhsmult;
150
20
      poly::Polynomial rhs = as_poly_polynomial(val, d_mapper, rhsmult);
151
10
      rhsmult = poly::Rational(1) / rhsmult;
152
10
      if (!veq_c.isNull())
153
      {
154
        rhsmult = poly_utils::toRational(veq_c.getConst<Rational>());
155
      }
156
20
      Candidate res{lhs, rel, rhs, rhsmult, n, collectVariables(val)};
157
10
      Trace("nl-icp") << "\tAdded " << res << " from " << n << std::endl;
158
10
      result.emplace_back(res);
159
    }
160
  }
161
38
  return result;
162
}
163
164
848
void ICPSolver::addCandidate(const Node& n)
165
{
166
848
  auto it = d_candidateCache.find(n);
167
848
  if (it != d_candidateCache.end())
168
  {
169
884
    for (const auto& c : it->second)
170
    {
171
128
      d_state.d_candidates.emplace_back(c);
172
    }
173
  }
174
  else
175
  {
176
184
    auto cands = constructCandidates(n);
177
92
    d_candidateCache.emplace(n, cands);
178
122
    for (const auto& c : cands)
179
    {
180
30
      d_state.d_candidates.emplace_back(c);
181
60
      Trace("nl-icp") << "Bumping budget because of the new candidate"
182
30
                      << std::endl;
183
30
      d_budget += d_budgetIncrement;
184
    }
185
  }
186
848
}
187
188
48
void ICPSolver::initOrigins()
189
{
190
140
  for (const auto& vars : d_state.d_bounds.get())
191
  {
192
92
    const Bounds& i = vars.second;
193
184
    Trace("nl-icp") << "Adding initial " << vars.first << " -> " << i
194
92
                    << std::endl;
195
92
    if (!i.lower_origin.isNull())
196
    {
197
80
      Trace("nl-icp") << "\tAdding lower " << i.lower_origin << std::endl;
198
80
      d_state.d_origins.add(vars.first, i.lower_origin, {});
199
    }
200
92
    if (!i.upper_origin.isNull())
201
    {
202
72
      Trace("nl-icp") << "\tAdding upper " << i.upper_origin << std::endl;
203
72
      d_state.d_origins.add(vars.first, i.upper_origin, {});
204
    }
205
  }
206
48
}
207
208
52
PropagationResult ICPSolver::doPropagationRound()
{
  if (d_budget <= 0)
  {
    Trace("nl-icp") << "ICP budget exceeded" << std::endl;
    return PropagationResult::NOT_CHANGED;
  }
  d_state.d_conflict.clear();
  Trace("nl-icp") << "Starting propagation with "
                  << IAWrapper{d_state.d_assignment, d_mapper} << std::endl;
  Trace("nl-icp") << "Current budget: " << d_budget << std::endl;

  // Stays NOT_CHANGED unless at least one candidate contracts an interval.
  PropagationResult overall = PropagationResult::NOT_CHANGED;
  for (const auto& cand : d_state.d_candidates)
  {
    // Each propagation attempt costs one unit of budget.
    --d_budget;
    const PropagationResult step = cand.propagate(d_state.d_assignment, 100);

    if (step == PropagationResult::CONFLICT)
    {
      d_state.d_origins.add(d_mapper(cand.lhs), cand.origin, cand.rhsVariables);
      d_state.d_conflict = d_state.d_origins.getOrigins(d_mapper(cand.lhs));
      return PropagationResult::CONFLICT;
    }
    if (step == PropagationResult::CONTRACTED
        || step == PropagationResult::CONTRACTED_STRONGLY)
    {
      d_state.d_origins.add(d_mapper(cand.lhs), cand.origin, cand.rhsVariables);
      overall = PropagationResult::CONTRACTED;
    }
    else if (step == PropagationResult::CONTRACTED_WITHOUT_CURRENT
             || step == PropagationResult::CONTRACTED_STRONGLY_WITHOUT_CURRENT)
    {
      // The trailing false is passed through to d_origins.add unchanged.
      d_state.d_origins.add(
          d_mapper(cand.lhs), cand.origin, cand.rhsVariables, false);
      overall = PropagationResult::CONTRACTED;
    }

    // Strong contractions earn additional budget.
    const bool strong =
        step == PropagationResult::CONTRACTED_STRONGLY
        || step == PropagationResult::CONTRACTED_STRONGLY_WITHOUT_CURRENT;
    if (strong)
    {
      Trace("nl-icp") << "Bumping budget because of a strong contraction"
                      << std::endl;
      d_budget += d_budgetIncrement;
    }
  }
  return overall;
}
255
256
6
std::vector<Node> ICPSolver::generateLemmas() const
257
{
258
6
  auto nm = NodeManager::currentNM();
259
6
  std::vector<Node> lemmas;
260
261
18
  for (const auto& vars : d_mapper.mVarCVCpoly)
262
  {
263
12
    if (!d_state.d_assignment.has(vars.second)) continue;
264
24
    Node v = vars.first;
265
24
    poly::Interval i = d_state.d_assignment.get(vars.second);
266
12
    if (!is_minus_infinity(get_lower(i)))
267
    {
268
12
      Kind rel = get_lower_open(i) ? Kind::GT : Kind::GEQ;
269
24
      Node c = nm->mkNode(rel, v, value_to_node(get_lower(i), v));
270
12
      if (!d_state.d_origins.isInOrigins(v, c))
271
      {
272
24
        Node premise = nm->mkAnd(d_state.d_origins.getOrigins(v));
273
12
        Trace("nl-icp") << premise << " => " << c << std::endl;
274
24
        Node lemma = Rewriter::rewrite(nm->mkNode(Kind::IMPLIES, premise, c));
275
12
        if (lemma.isConst())
276
        {
277
          Assert(lemma == nm->mkConst<bool>(true));
278
        }
279
        else
280
        {
281
12
          Trace("nl-icp") << "Adding lemma " << lemma << std::endl;
282
12
          lemmas.emplace_back(lemma);
283
        }
284
      }
285
    }
286
12
    if (!is_plus_infinity(get_upper(i)))
287
    {
288
10
      Kind rel = get_upper_open(i) ? Kind::LT : Kind::LEQ;
289
20
      Node c = nm->mkNode(rel, v, value_to_node(get_upper(i), v));
290
10
      if (!d_state.d_origins.isInOrigins(v, c))
291
      {
292
20
        Node premise = nm->mkAnd(d_state.d_origins.getOrigins(v));
293
10
        Trace("nl-icp") << premise << " => " << c << std::endl;
294
20
        Node lemma = Rewriter::rewrite(nm->mkNode(Kind::IMPLIES, premise, c));
295
10
        if (lemma.isConst())
296
        {
297
          Assert(lemma == nm->mkConst<bool>(true));
298
        }
299
        else
300
        {
301
10
          Trace("nl-icp") << "Adding lemma " << lemma << std::endl;
302
10
          lemmas.emplace_back(lemma);
303
        }
304
      }
305
    }
306
  }
307
6
  return lemmas;
308
}
309
310
48
void ICPSolver::reset(const std::vector<Node>& assertions)
311
{
312
48
  d_state.reset();
313
1038
  for (const auto& n : assertions)
314
  {
315
990
    Trace("nl-icp") << "Adding " << n << std::endl;
316
990
    if (n.getKind() != Kind::CONST_BOOLEAN)
317
    {
318
990
      if (!d_state.d_bounds.add(n))
319
      {
320
848
        addCandidate(n);
321
      }
322
    }
323
  }
324
48
}
325
326
48
void ICPSolver::check()
327
{
328
48
  initOrigins();
329
48
  d_state.d_assignment = getBounds(d_mapper, d_state.d_bounds);
330
48
  bool did_progress = false;
331
48
  bool progress = false;
332
52
  do
333
  {
334
52
    switch (doPropagationRound())
335
    {
336
46
      case icp::PropagationResult::NOT_CHANGED: progress = false; break;
337
4
      case icp::PropagationResult::CONTRACTED:
338
      case icp::PropagationResult::CONTRACTED_STRONGLY:
339
      case icp::PropagationResult::CONTRACTED_WITHOUT_CURRENT:
340
      case icp::PropagationResult::CONTRACTED_STRONGLY_WITHOUT_CURRENT:
341
4
        did_progress = true;
342
4
        progress = true;
343
4
        break;
344
2
      case icp::PropagationResult::CONFLICT:
345
4
        Trace("nl-icp") << "Found a conflict: " << d_state.d_conflict
346
2
                        << std::endl;
347
348
4
        std::vector<Node> mis;
349
8
        for (const auto& n : d_state.d_conflict)
350
        {
351
6
          mis.emplace_back(n.negate());
352
        }
353
2
        d_im.addPendingLemma(NodeManager::currentNM()->mkOr(mis),
354
                             InferenceId::ARITH_NL_ICP_CONFLICT);
355
2
        did_progress = true;
356
2
        progress = false;
357
2
        break;
358
    }
359
  } while (progress);
360
48
  if (did_progress)
361
  {
362
12
    std::vector<Node> lemmas = generateLemmas();
363
28
    for (const auto& l : lemmas)
364
    {
365
22
      d_im.addPendingLemma(l, InferenceId::ARITH_NL_ICP_PROPAGATION);
366
    }
367
  }
368
48
}
369
370
#else /* CVC5_POLY_IMP */
371
372
void ICPSolver::reset(const std::vector<Node>& assertions)
373
{
374
  Unimplemented() << "ICPSolver requires cvc5 to be configured with LibPoly";
375
}
376
377
void ICPSolver::check()
378
{
379
  Unimplemented() << "ICPSolver requires cvc5 to be configured with LibPoly";
380
}
381
382
#endif /* CVC5_POLY_IMP */
383
384
}  // namespace icp
385
}  // namespace nl
386
}  // namespace arith
387
}  // namespace theory
388
29574
}  // namespace cvc5