GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/learned_rewrite.cpp Lines: 120 222 54.1 %
Date: 2021-09-15 Branches: 214 890 24.0 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Rewriting based on learned literals
14
 */
15
16
#include "preprocessing/passes/learned_rewrite.h"
17
18
#include "expr/skolem_manager.h"
19
#include "expr/term_context_stack.h"
20
#include "preprocessing/assertion_pipeline.h"
21
#include "smt/smt_statistics_registry.h"
22
#include "theory/arith/arith_msum.h"
23
#include "theory/rewriter.h"
24
#include "util/rational.h"
25
26
using namespace cvc5::theory;
27
using namespace cvc5::kind;
28
29
namespace cvc5 {
30
namespace preprocessing {
31
namespace passes {
32
33
/**
 * Map a learned rewrite identifier to its printable name.
 *
 * @param i The identifier to convert.
 * @return A static string naming the identifier, or "?LearnedRewriteId?" if
 * the value is not one of the enumerators handled below.
 */
const char* toString(LearnedRewriteId i)
{
  // start from the fallback and overwrite it for every known identifier
  const char* name = "?LearnedRewriteId?";
  switch (i)
  {
    case LearnedRewriteId::NON_ZERO_DEN: name = "NON_ZERO_DEN"; break;
    case LearnedRewriteId::INT_MOD_RANGE: name = "INT_MOD_RANGE"; break;
    case LearnedRewriteId::PRED_POS_LB: name = "PRED_POS_LB"; break;
    case LearnedRewriteId::PRED_ZERO_LB: name = "PRED_ZERO_LB"; break;
    case LearnedRewriteId::PRED_NEG_UB: name = "PRED_NEG_UB"; break;
    case LearnedRewriteId::NONE: name = "NONE"; break;
    default: break;
  }
  return name;
}
46
47
std::ostream& operator<<(std::ostream& out, LearnedRewriteId i)
48
{
49
  out << toString(i);
50
  return out;
51
}
52
53
9942
/**
 * Construct the pass under the name "learned-rewrite" and register the
 * histogram "LearnedRewrite::lrewCount", which counts applied rewrites by
 * their LearnedRewriteId.
 *
 * @param preprocContext The preprocessing context, forwarded to the base
 * class.
 */
LearnedRewrite::LearnedRewrite(PreprocessingPassContext* preprocContext)
    : PreprocessingPass(preprocContext, "learned-rewrite"),
      d_lrewCount(
          statisticsRegistry().registerHistogram<LearnedRewriteId>(
              "LearnedRewrite::lrewCount"))
{
  // nothing else to initialize
}
59
60
2
/**
 * Run the pass: infer arithmetic bounds from the learned literals, rewrite
 * the remaining learned literals and all assertions using those bounds, and
 * re-add the conjunction of the kept learned literals as a new assertion.
 *
 * @param assertionsToPreprocess The assertion pipeline, modified in place.
 * @return CONFLICT if a learned literal rewrites to false (false is then
 * appended to the pipeline), NO_CONFLICT otherwise. If there are no learned
 * literals, the pipeline is left untouched.
 */
PreprocessingPassResult LearnedRewrite::applyInternal(
    AssertionPipeline* assertionsToPreprocess)
{
  NodeManager* nm = NodeManager::currentNM();
  arith::BoundInference binfer;
  std::vector<Node> learnedLits = d_preprocContext->getLearnedLiterals();
  // the learned literals we keep; these are re-asserted at the end
  std::unordered_set<Node> llrw;
  // rewriting cache shared by all calls to rewriteLearnedRec below
  std::unordered_map<TNode, Node> visited;

  // nothing to do if nothing was learned
  if (learnedLits.empty())
  {
    Trace("learned-rewrite-ll") << "No learned literals" << std::endl;
    return PreprocessingPassResult::NO_CONFLICT;
  }

  Trace("learned-rewrite-ll") << "Learned literals:" << std::endl;
  // maps each literal given to bound inference to the learned literal it
  // originated from
  std::map<Node, Node> originLit;
  for (const Node& lit : learnedLits)
  {
    // maybe use the literal for bound inference?
    const bool pol = lit.getKind() != NOT;
    TNode atom = pol ? lit : lit[0];
    const Kind ak = atom.getKind();
    Assert(ak != LT && ak != GT && ak != LEQ);
    if ((ak == EQUAL && pol) || ak == GEQ)
    {
      // a negated >= is provided to bound inference as <
      Node boundLit = pol ? lit : nm->mkNode(LT, atom[0], atom[1]);
      originLit[boundLit] = lit;
      binfer.add(boundLit);
    }
    Trace("learned-rewrite-ll") << "- " << lit << std::endl;
  }

  // Collect the critical literals, i.e. those used in the derivation of some
  // lower or upper bound. These are kept as-is.
  auto markCritical = [&](const Node& origin) {
    if (origin.isNull())
    {
      return;
    }
    Assert(originLit.find(origin) != originLit.end());
    llrw.insert(originLit[origin]);
  };
  const std::map<Node, arith::Bounds>& allBounds = binfer.get();
  for (const std::pair<const Node, arith::Bounds>& entry : allBounds)
  {
    markCritical(entry.second.lower_origin);
    markCritical(entry.second.upper_origin);
  }

  // rewrite the non-critical learned literals, some may be redundant
  for (const Node& lit : learnedLits)
  {
    if (llrw.find(lit) != llrw.end())
    {
      // critical, already kept
      continue;
    }
    Node rewritten = rewriteLearnedRec(lit, binfer, llrw, visited);
    if (!rewritten.isConst())
    {
      llrw.insert(rewritten);
      continue;
    }
    if (!rewritten.getConst<bool>())
    {
      // conflict, we are done
      assertionsToPreprocess->push_back(rewritten);
      return PreprocessingPassResult::CONFLICT;
    }
    // otherwise the literal rewrote to true and is dropped
  }
  Trace("learned-rewrite-ll") << "end" << std::endl;

  // rewrite each of the assertions that were present on entry
  const size_t numAssertions = assertionsToPreprocess->size();
  for (size_t idx = 0; idx < numAssertions; ++idx)
  {
    Node before = (*assertionsToPreprocess)[idx];
    Trace("learned-rewrite-assert")
        << "LearnedRewrite: assert: " << before << std::endl;
    Node after = rewriteLearnedRec(before, binfer, llrw, visited);
    if (after == before)
    {
      continue;
    }
    Trace("learned-rewrite-assert")
        << ".......................: " << after << std::endl;
    assertionsToPreprocess->replace(idx, after);
  }

  // Add the conjunction of learned literals back to assertions. Notice that
  // in some cases we may add top-level assertions back to the assertion list
  // unchanged.
  if (!llrw.empty())
  {
    std::vector<Node> conjuncts(llrw.begin(), llrw.end());
    Node conj = nm->mkAnd(conjuncts);
    Trace("learned-rewrite-assert")
        << "Re-add rewritten learned conjunction: " << conj << std::endl;
    assertionsToPreprocess->push_back(conj);
  }

  return PreprocessingPassResult::NO_CONFLICT;
}
168
169
2
/**
 * Rewrite n bottom-up with an explicit stack: every subterm that occurs in
 * lems is replaced by true, and every other subterm is rebuilt from its
 * rewritten children and then passed to rewriteLearned.
 *
 * @param n The term to rewrite.
 * @param binfer The bound inference utility, forwarded to rewriteLearned.
 * @param lems The learned literals; occurrences are replaced by true.
 * @param visited Cache from terms to their rewritten form. It is both read
 * and extended by this call, so results are shared across calls.
 * @return The rewritten form of n.
 */
Node LearnedRewrite::rewriteLearnedRec(Node n,
                                       arith::BoundInference& binfer,
                                       std::unordered_set<Node>& lems,
                                       std::unordered_map<TNode, Node>& visited)
{
  NodeManager* nm = NodeManager::currentNM();
  std::vector<TNode> toVisit;
  toVisit.push_back(n);
  while (!toVisit.empty())
  {
    TNode cur = toVisit.back();
    toVisit.pop_back();

    if (lems.find(cur) != lems.end())
    {
      // cur is a learned literal: replace by true, not considered a rewrite
      // for statistics
      visited[cur] = nm->mkConst(true);
      continue;
    }

    std::unordered_map<TNode, Node>::iterator itCur = visited.find(cur);
    if (itCur == visited.end())
    {
      // pre-visit: mark with null, and revisit cur after all of its children
      // have been processed
      visited[cur] = Node::null();
      toVisit.push_back(cur);
      toVisit.insert(toVisit.end(), cur.begin(), cur.end());
      continue;
    }
    if (!itCur->second.isNull())
    {
      // already has a final result
      continue;
    }

    // post-visit: collect the rewritten children
    std::vector<Node> children;
    if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
    {
      children.push_back(cur.getOperator());
    }
    bool childChanged = false;
    for (const Node& child : cur)
    {
      std::unordered_map<TNode, Node>::iterator itChild = visited.find(child);
      Assert(itChild != visited.end());
      Assert(!itChild->second.isNull());
      if (child != itChild->second)
      {
        childChanged = true;
      }
      children.push_back(itChild->second);
    }
    // only reconstruct the term if some child was modified
    Node ret = cur;
    if (childChanged)
    {
      ret = nm->mkNode(cur.getKind(), children);
    }
    // rewrite here
    ret = rewriteLearned(ret, binfer, lems);
    visited[cur] = ret;
  }
  Assert(visited.find(n) != visited.end());
  Assert(!visited.find(n)->second.isNull());
  return visited[n];
}
229
230
20
/**
 * Rewrite a single term n (whose children are assumed to be already
 * processed by the caller) using the bounds stored in binfer.
 *
 * The term is first passed to rewrite(); all further reasoning is done on
 * that result, nr. The following learned rewrites are then attempted:
 * - NON_ZERO_DEN: a division or modulus whose denominator is known to be
 *   non-zero is replaced by its total variant.
 * - INT_MOD_RANGE: (mod x y) is replaced by x if 0 <= x < |y| is entailed by
 *   the bounds.
 * - PRED_POS_LB / PRED_ZERO_LB / PRED_NEG_UB: a GEQ literal, or an EQUAL
 *   literal over a real-typed term, is replaced by a Boolean constant if the
 *   bounds of its monomial sum decide it.
 *
 * @param n The term to rewrite.
 * @param binfer The bound inference utility providing lower/upper bounds.
 * @param lems The learned literals (currently unused by this method).
 * @return The rewritten term.
 */
Node LearnedRewrite::rewriteLearned(Node n,
                                    arith::BoundInference& binfer,
                                    std::unordered_set<Node>& lems)
{
  NodeManager* nm = NodeManager::currentNM();
  Trace("learned-rewrite-rr-debug") << "Rewrite " << n << std::endl;
  Node nr = rewrite(n);
  Kind k = nr.getKind();
  // true if the given bound value is known and strictly positive
  auto isPositive = [](const Node& v) {
    return !v.isNull() && v.getConst<Rational>().sgn() == 1;
  };
  // true if the given bound value is known and strictly negative
  auto isNegative = [](const Node& v) {
    return !v.isNull() && v.getConst<Rational>().sgn() == -1;
  };
  if (k == INTS_DIVISION || k == INTS_MODULUS || k == DIVISION)
  {
    // simpler if we know the divisor is non-zero
    // Note the children are taken from nr, not n: k is the kind of nr, and
    // n may have a different kind or different children than nr.
    Node num = nr[0];
    Node den = nr[1];
    bool isNonZeroDen = false;
    if (den.isConst())
    {
      isNonZeroDen = (den.getConst<Rational>().sgn() != 0);
    }
    else
    {
      arith::Bounds db = binfer.get(den);
      Trace("learned-rewrite-rr-debug")
          << "Bounds for " << den << " : " << db.lower_value << " "
          << db.upper_value << std::endl;
      // non-zero if strictly positive or strictly negative
      isNonZeroDen = isPositive(db.lower_value) || isNegative(db.upper_value);
    }
    if (isNonZeroDen)
    {
      Trace("learned-rewrite-rr-debug")
          << "...non-zero denominator" << std::endl;
      Kind nk = k;
      switch (k)
      {
        case INTS_DIVISION: nk = INTS_DIVISION_TOTAL; break;
        case INTS_MODULUS: nk = INTS_MODULUS_TOTAL; break;
        case DIVISION: nk = DIVISION_TOTAL; break;
        default: Assert(false); break;
      }
      Node ret = nm->mkNode(nk, num, den);
      nr = returnRewriteLearned(nr, ret, LearnedRewriteId::NON_ZERO_DEN);
      nr = rewrite(nr);
      k = nr.getKind();
    }
  }
  // constant int mod elimination by bound inference
  if (k == INTS_MODULUS_TOTAL)
  {
    // again use the children of nr, since k is the kind of nr
    Node num = nr[0];
    Node den = nr[1];
    arith::Bounds db = binfer.get(den);
    const bool posDen = isPositive(db.lower_value);
    const bool negDen = isNegative(db.upper_value);
    if (posDen || negDen)
    {
      // bden is a lower bound on |den|: LB(den) if den is known positive,
      // otherwise -UB(den) since den is then known negative.
      Rational bden = posDen ? db.lower_value.getConst<Rational>()
                             : db.upper_value.getConst<Rational>().abs();
      // (mod num den) = num holds if and only if 0 <= num < |den|. Hence we
      // require 0 <= LB(num) and UB(num) < bden. Checking the upper bound of
      // num alone is not sufficient, since num could still be negative.
      arith::Bounds nb = binfer.get(num);
      if (!nb.lower_value.isNull() && !nb.upper_value.isNull())
      {
        Rational lnum = nb.lower_value.getConst<Rational>();
        Rational unum = nb.upper_value.getConst<Rational>();
        if (lnum.sgn() != -1 && unum < bden)
        {
          nr = returnRewriteLearned(nr, nr[0], LearnedRewriteId::INT_MOD_RANGE);
        }
      }
      // could also do num + k*den checks
    }
  }
  else if (k == GEQ || (k == EQUAL && nr[0].getType().isReal()))
  {
    std::map<Node, Node> msum;
    if (ArithMSum::getMonomialSumLit(nr, msum))
    {
      // lb and ub accumulate a lower and upper bound on the monomial sum,
      // valid only while the corresponding success flag is true
      Rational lb(0);
      Rational ub(0);
      bool lbSuccess = true;
      bool ubSuccess = true;
      Rational one(1);
      if (Trace.isOn("learned-rewrite-arith-lit"))
      {
        Trace("learned-rewrite-arith-lit")
            << "Arithmetic lit: " << nr << std::endl;
        for (const std::pair<const Node, Node>& m : msum)
        {
          Trace("learned-rewrite-arith-lit")
              << "  " << m.first << ", " << m.second << std::endl;
        }
      }
      for (const std::pair<const Node, Node>& m : msum)
      {
        // a null coefficient is treated as the coefficient one
        bool isOneCoeff = m.second.isNull();
        Assert(isOneCoeff || m.second.isConst());
        if (m.first.isNull())
        {
          // the constant part of the sum contributes to both bounds
          lb = lb + (isOneCoeff ? one : m.second.getConst<Rational>());
          ub = ub + (isOneCoeff ? one : m.second.getConst<Rational>());
        }
        else
        {
          arith::Bounds b = binfer.get(m.first);
          bool isNeg = !isOneCoeff && m.second.getConst<Rational>().sgn() == -1;
          // flip lower/upper if negative coefficient
          TNode l = isNeg ? b.upper_value : b.lower_value;
          TNode u = isNeg ? b.lower_value : b.upper_value;
          if (lbSuccess && !l.isNull())
          {
            Rational lc = l.getConst<Rational>();
            lb = lb
                 + (isOneCoeff ? lc
                               : Rational(lc * m.second.getConst<Rational>()));
          }
          else
          {
            lbSuccess = false;
          }
          if (ubSuccess && !u.isNull())
          {
            Rational uc = u.getConst<Rational>();
            ub = ub
                 + (isOneCoeff ? uc
                               : Rational(uc * m.second.getConst<Rational>()));
          }
          else
          {
            ubSuccess = false;
          }
          if (!lbSuccess && !ubSuccess)
          {
            // neither bound can be established, give up early
            break;
          }
        }
      }
      if (lbSuccess)
      {
        if (lb.sgn() == 1)
        {
          // if positive lower bound, then GEQ is true, EQUAL is false
          Node ret = nm->mkConst(k == GEQ);
          nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_POS_LB);
          return nr;
        }
        else if (lb.sgn() == 0 && k == GEQ)
        {
          // zero lower bound, GEQ is true
          Node ret = nm->mkConst(true);
          nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_ZERO_LB);
          return nr;
        }
      }
      // The upper bound is checked independently of whether a lower bound
      // was found: if the lower bound did not decide the literal above, a
      // negative upper bound still may.
      if (ubSuccess && ub.sgn() == -1)
      {
        // if negative upper bound, then GEQ and EQUAL are false
        Node ret = nm->mkConst(false);
        nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_NEG_UB);
        return nr;
      }
    }
  }
  return nr;
}
407
408
6
/**
 * Record that n was rewritten to nr by the learned rewrite id: the rewrite
 * is counted in the d_lrewCount histogram and, if the "learned-rewrite"
 * trace is on, printed.
 *
 * @param n The term before the rewrite.
 * @param nr The term after the rewrite.
 * @param id The identifier of the rewrite that was applied.
 * @return nr, unchanged.
 */
Node LearnedRewrite::returnRewriteLearned(Node n, Node nr, LearnedRewriteId id)
{
  // count the rewrite in the statistics
  d_lrewCount << id;
  if (Trace.isOn("learned-rewrite"))
  {
    Trace("learned-rewrite") << "LearnedRewrite::Rewrite: (" << id << ") " << n
                             << " == " << nr << std::endl;
  }
  return nr;
}
418
419
}  // namespace passes
420
}  // namespace preprocessing
421
29577
}  // namespace cvc5