1 |
|
/****************************************************************************** |
2 |
|
* Top contributors (to current version): |
3 |
|
* Andrew Reynolds |
4 |
|
* |
5 |
|
* This file is part of the cvc5 project. |
6 |
|
* |
7 |
|
* Copyright (c) 2009-2021 by the authors listed in the file AUTHORS |
8 |
|
* in the top-level source directory and their institutional affiliations. |
9 |
|
* All rights reserved. See the file COPYING in the top-level source |
10 |
|
* directory for licensing information. |
11 |
|
* **************************************************************************** |
12 |
|
* |
13 |
|
* Rewriting based on learned literals |
14 |
|
*/ |
15 |
|
|
16 |
|
#include "preprocessing/passes/learned_rewrite.h" |
17 |
|
|
18 |
|
#include "expr/skolem_manager.h" |
19 |
|
#include "expr/term_context_stack.h" |
20 |
|
#include "preprocessing/assertion_pipeline.h" |
21 |
|
#include "smt/smt_statistics_registry.h" |
22 |
|
#include "theory/arith/arith_msum.h" |
23 |
|
#include "theory/rewriter.h" |
24 |
|
#include "util/rational.h" |
25 |
|
|
26 |
|
using namespace cvc5::theory; |
27 |
|
using namespace cvc5::kind; |
28 |
|
|
29 |
|
namespace cvc5 { |
30 |
|
namespace preprocessing { |
31 |
|
namespace passes { |
32 |
|
|
33 |
|
/**
 * Returns the printable name of a learned rewrite identifier. Identifiers
 * without a dedicated case map to the placeholder "?LearnedRewriteId?".
 */
const char* toString(LearnedRewriteId i)
{
  const char* name = "?LearnedRewriteId?";
  switch (i)
  {
    case LearnedRewriteId::NON_ZERO_DEN: name = "NON_ZERO_DEN"; break;
    case LearnedRewriteId::INT_MOD_RANGE: name = "INT_MOD_RANGE"; break;
    case LearnedRewriteId::PRED_POS_LB: name = "PRED_POS_LB"; break;
    case LearnedRewriteId::PRED_ZERO_LB: name = "PRED_ZERO_LB"; break;
    case LearnedRewriteId::PRED_NEG_UB: name = "PRED_NEG_UB"; break;
    case LearnedRewriteId::NONE: name = "NONE"; break;
    default: break;
  }
  return name;
}
46 |
|
|
47 |
|
std::ostream& operator<<(std::ostream& out, LearnedRewriteId i) |
48 |
|
{ |
49 |
|
out << toString(i); |
50 |
|
return out; |
51 |
|
} |
52 |
|
|
53 |
9853 |
/**
 * Constructs the pass, naming it "learned-rewrite" in the PreprocessingPass
 * base, and registers the histogram d_lrewCount that returnRewriteLearned
 * increments once per applied learned rewrite (keyed by LearnedRewriteId).
 */
LearnedRewrite::LearnedRewrite(PreprocessingPassContext* preprocContext)
    : PreprocessingPass(preprocContext, "learned-rewrite"),
      d_lrewCount(
          smtStatisticsRegistry()
              .registerHistogram<LearnedRewriteId>("LearnedRewrite::lrewCount"))
{
  // all state is set up in the initializer list
}
59 |
|
|
60 |
2 |
/**
 * Simplifies the assertions using the literals learned so far.
 *
 * 1. Positive equalities and (possibly negated) >= literals among the learned
 *    literals are given to a bound inference module. Learned literals that are
 *    the origin of some inferred bound are "critical": they are kept as-is,
 *    since they justify the bounds used in the rewrites below.
 * 2. Every non-critical learned literal is rewritten. If one rewrites to false
 *    the input is unsatisfiable (false is added and CONFLICT is returned). If
 *    one rewrites to true it is dropped. Otherwise its rewritten form is kept.
 * 3. Every assertion is rewritten; occurrences of kept learned literals are
 *    replaced by true. Finally the conjunction of kept literals is re-added.
 *
 * @param assertionsToPreprocess the assertions, modified in place
 * @return CONFLICT if a learned literal rewrote to false, NO_CONFLICT otherwise
 */
PreprocessingPassResult LearnedRewrite::applyInternal(
    AssertionPipeline* assertionsToPreprocess)
{
  NodeManager* nm = NodeManager::currentNM();
  arith::BoundInference binfer;
  std::vector<Node> learnedLits = d_preprocContext->getLearnedLiterals();
  // learned literals (critical or rewritten) that are re-asserted at the end
  std::unordered_set<Node> llrw;
  // Rewrite cache shared by all calls to rewriteLearnedRec below. Its keys are
  // non-owning TNodes, so every term whose subterms are used as keys must stay
  // alive while the cache is in use (see keepAlive below).
  std::unordered_map<TNode, Node> visited;
  if (learnedLits.empty())
  {
    Trace("learned-rewrite-ll") << "No learned literals" << std::endl;
    return PreprocessingPassResult::NO_CONFLICT;
  }
  Trace("learned-rewrite-ll") << "Learned literals:" << std::endl;
  // maps each atom given to binfer back to the learned literal it came from
  std::map<Node, Node> originLit;
  for (const Node& l : learnedLits)
  {
    // use positive equalities and (possibly negated) >= for bound inference
    bool pol = l.getKind() != NOT;
    TNode atom = pol ? l : l[0];
    Kind ak = atom.getKind();
    Assert(ak != LT && ak != GT && ak != LEQ);
    if ((ak == EQUAL && pol) || ak == GEQ)
    {
      // provide as < if negated >=
      Node atomu;
      if (!pol)
      {
        atomu = nm->mkNode(LT, atom[0], atom[1]);
        originLit[atomu] = l;
      }
      else
      {
        atomu = l;
        originLit[l] = l;
      }
      binfer.add(atomu);
    }
    Trace("learned-rewrite-ll") << "- " << l << std::endl;
  }
  const std::map<Node, arith::Bounds>& bs = binfer.get();
  // get the literals that were critical, i.e. used in the derivation of a
  // bound
  for (const std::pair<const Node, arith::Bounds>& b : bs)
  {
    for (size_t i = 0; i < 2; i++)
    {
      Node origin = i == 0 ? b.second.lower_origin : b.second.upper_origin;
      if (!origin.isNull())
      {
        Assert(originLit.find(origin) != originLit.end());
        llrw.insert(originLit[origin]);
      }
    }
  }
  // rewrite the non-critical learned literals, some may be redundant
  for (const Node& l : learnedLits)
  {
    if (llrw.find(l) != llrw.end())
    {
      continue;
    }
    Node e = rewriteLearnedRec(l, binfer, llrw, visited);
    if (e.isConst())
    {
      // ignore true
      if (e.getConst<bool>())
      {
        continue;
      }
      // conflict, we are done
      assertionsToPreprocess->push_back(e);
      return PreprocessingPassResult::CONFLICT;
    }
    llrw.insert(e);
  }
  Trace("learned-rewrite-ll") << "end" << std::endl;

  // Assertions replaced below. Their subterms are TNode keys of visited; once
  // replaced in the pipeline nothing else may own them, so without this they
  // could be reclaimed while still present in the cache, and later lookups or
  // rehashes of visited would touch freed nodes.
  std::vector<Node> keepAlive;
  size_t size = assertionsToPreprocess->size();
  for (size_t i = 0; i < size; ++i)
  {
    Node prev = (*assertionsToPreprocess)[i];
    Trace("learned-rewrite-assert")
        << "LearnedRewrite: assert: " << prev << std::endl;
    Node e = rewriteLearnedRec(prev, binfer, llrw, visited);
    if (e != prev)
    {
      Trace("learned-rewrite-assert")
          << ".......................: " << e << std::endl;
      keepAlive.push_back(prev);
      assertionsToPreprocess->replace(i, e);
    }
  }
  // Add the conjunction of learned literals back to assertions. Notice that
  // in some cases we may add top-level assertions back to the assertion list
  // unchanged.
  if (!llrw.empty())
  {
    std::vector<Node> llrvec(llrw.begin(), llrw.end());
    Node llc = nm->mkAnd(llrvec);
    Trace("learned-rewrite-assert")
        << "Re-add rewritten learned conjunction: " << llc << std::endl;
    assertionsToPreprocess->push_back(llc);
  }

  return PreprocessingPassResult::NO_CONFLICT;
}
168 |
|
|
169 |
2 |
/**
 * Rewrites n bottom-up with an explicit stack (no recursion): each subterm is
 * rebuilt from its already-processed children and then passed to
 * rewriteLearned. Any subterm contained in lems is replaced by true and its
 * children are not traversed.
 *
 * @param n the term to rewrite
 * @param binfer bounds passed through to rewriteLearned
 * @param lems the learned literals to be replaced by true
 * @param visited cache from subterms to their results; a null value marks a
 *        subterm that has been pre-visited but not yet post-visited. The cache
 *        may be shared across calls.
 *        NOTE(review): keys are non-owning TNodes, so callers must keep the
 *        traversed terms alive while visited is in use.
 * @return the rewritten form of n
 */
Node LearnedRewrite::rewriteLearnedRec(Node n,
                                       arith::BoundInference& binfer,
                                       std::unordered_set<Node>& lems,
                                       std::unordered_map<TNode, Node>& visited)
{
  NodeManager* nm = NodeManager::currentNM();
  std::unordered_map<TNode, Node>::iterator it;
  std::vector<TNode> visit;
  TNode cur;
  visit.push_back(n);
  do
  {
    cur = visit.back();
    visit.pop_back();
    it = visited.find(cur);
    // The lems check comes before the cache result is used, so a term that
    // became a learned literal after it was cached is still mapped to true.
    if (lems.find(cur) != lems.end())
    {
      // cur is a learned literal: replace by true, not considered a rewrite
      // for statistics
      visited[cur] = nm->mkConst(true);
      continue;
    }
    if (it == visited.end())
    {
      // mark pre-visited with null; will post-visit to construct final node
      // in the block below.
      visited[cur] = Node::null();
      visit.push_back(cur);
      visit.insert(visit.end(), cur.begin(), cur.end());
    }
    else if (it->second.isNull())
    {
      // post-visit: all children are processed, rebuild cur if any changed
      Node ret = cur;
      bool needsRcons = false;
      std::vector<Node> children;
      if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
      {
        children.push_back(cur.getOperator());
      }
      for (const Node& cn : cur)
      {
        it = visited.find(cn);
        Assert(it != visited.end());
        Assert(!it->second.isNull());
        needsRcons = needsRcons || cn != it->second;
        children.push_back(it->second);
      }
      if (needsRcons)
      {
        ret = nm->mkNode(cur.getKind(), children);
      }
      // rewrite here
      ret = rewriteLearned(ret, binfer, lems);
      visited[cur] = ret;
    }
    // otherwise cur was already fully processed (e.g. shared subterm)
  } while (!visit.empty());
  Assert(visited.find(n) != visited.end());
  Assert(!visited.find(n)->second.isNull());
  return visited[n];
}
229 |
|
|
230 |
20 |
/**
 * Rewrites a single term with the standard rewriter, then applies rewrites
 * that are justified by the bounds in binfer. rewriteLearnedRec calls this in
 * post-order, i.e. on terms whose children have already been processed.
 *
 * The learned rewrites are:
 * - NON_ZERO_DEN: div / mod / division whose denominator is a non-zero
 *   constant, or has a positive lower or negative upper bound, becomes its
 *   total variant.
 * - INT_MOD_RANGE: (mod_total x y) becomes x when 0 <= LB(x) and UB(x) is
 *   smaller than every value |y| may take.
 * - PRED_POS_LB / PRED_ZERO_LB / PRED_NEG_UB: an arithmetic >= or = literal
 *   whose monomial sum has a positive / zero lower bound or a negative upper
 *   bound is decided to a Boolean constant.
 *
 * @param n the term to rewrite
 * @param binfer the bounds inferred from the learned literals
 * @param lems the learned literals (not used by this method)
 * @return a term equivalent to n given the learned literals
 */
Node LearnedRewrite::rewriteLearned(Node n,
                                    arith::BoundInference& binfer,
                                    std::unordered_set<Node>& lems)
{
  NodeManager* nm = NodeManager::currentNM();
  Trace("learned-rewrite-rr-debug") << "Rewrite " << n << std::endl;
  Node nr = Rewriter::rewrite(n);
  Kind k = nr.getKind();
  if (k == INTS_DIVISION || k == INTS_MODULUS || k == DIVISION)
  {
    // Simpler if we know the divisor is non-zero. The children are taken from
    // nr, whose kind was tested above: the rewriter may have changed the
    // top-level symbol, so the children of n need not be num/den.
    Node den = nr[1];
    bool isNonZeroDen = false;
    if (den.isConst())
    {
      isNonZeroDen = (den.getConst<Rational>().sgn() != 0);
    }
    else
    {
      arith::Bounds db = binfer.get(den);
      Trace("learned-rewrite-rr-debug")
          << "Bounds for " << den << " : " << db.lower_value << " "
          << db.upper_value << std::endl;
      isNonZeroDen = (!db.lower_value.isNull()
                      && db.lower_value.getConst<Rational>().sgn() == 1)
                     || (!db.upper_value.isNull()
                         && db.upper_value.getConst<Rational>().sgn() == -1);
    }
    if (isNonZeroDen)
    {
      Trace("learned-rewrite-rr-debug")
          << "...non-zero denominator" << std::endl;
      Kind nk = k;
      switch (k)
      {
        case INTS_DIVISION: nk = INTS_DIVISION_TOTAL; break;
        case INTS_MODULUS: nk = INTS_MODULUS_TOTAL; break;
        case DIVISION: nk = DIVISION_TOTAL; break;
        default: Assert(false); break;
      }
      std::vector<Node> children;
      children.insert(children.end(), nr.begin(), nr.end());
      Node ret = nm->mkNode(nk, children);
      nr = returnRewriteLearned(nr, ret, LearnedRewriteId::NON_ZERO_DEN);
      nr = Rewriter::rewrite(nr);
      k = nr.getKind();
    }
  }
  // constant int mod elimination by bound inference
  if (k == INTS_MODULUS_TOTAL)
  {
    Node num = nr[0];
    Node den = nr[1];
    arith::Bounds db = binfer.get(den);
    bool denPos = !db.lower_value.isNull()
                  && db.lower_value.getConst<Rational>().sgn() == 1;
    bool denNeg = !db.upper_value.isNull()
                  && db.upper_value.getConst<Rational>().sgn() == -1;
    if (denPos || denNeg)
    {
      // The smallest value |den| may take: LB(den) if den is positive,
      // -UB(den) if den is negative. (|UB(den)| for a positive den would be
      // the largest value, e.g. 10 for den in [2,10].)
      Rational bden = denPos ? db.lower_value.getConst<Rational>()
                             : db.upper_value.getConst<Rational>().abs();
      // (mod x y) = x holds iff 0 <= x < |y|, so we need
      // 0 <= LB(num) and UB(num) < bden. A non-negative upper bound alone is
      // not enough: num = -1, den = 5 gives (mod -1 5) = 4.
      arith::Bounds nb = binfer.get(num);
      if (!nb.lower_value.isNull() && !nb.upper_value.isNull())
      {
        Rational lnum = nb.lower_value.getConst<Rational>();
        Rational unum = nb.upper_value.getConst<Rational>();
        if (lnum.sgn() != -1 && unum < bden)
        {
          nr = returnRewriteLearned(nr, num, LearnedRewriteId::INT_MOD_RANGE);
        }
      }
      // could also do num + k*den checks
    }
  }
  else if (k == GEQ || (k == EQUAL && nr[0].getType().isReal()))
  {
    std::map<Node, Node> msum;
    if (ArithMSum::getMonomialSumLit(nr, msum))
    {
      // Compute a lower bound lb and an upper bound ub of the monomial sum
      // from the bounds of its monomials; lbSuccess/ubSuccess record whether
      // every monomial had the required bound.
      Rational lb(0);
      Rational ub(0);
      bool lbSuccess = true;
      bool ubSuccess = true;
      Rational one(1);
      if (Trace.isOn("learned-rewrite-arith-lit"))
      {
        Trace("learned-rewrite-arith-lit")
            << "Arithmetic lit: " << nr << std::endl;
        for (const std::pair<const Node, Node>& m : msum)
        {
          Trace("learned-rewrite-arith-lit")
              << "  " << m.first << ", " << m.second << std::endl;
        }
      }
      for (const std::pair<const Node, Node>& m : msum)
      {
        bool isOneCoeff = m.second.isNull();
        Assert(isOneCoeff || m.second.isConst());
        if (m.first.isNull())
        {
          // constant monomial: contributes to both bounds
          lb = lb + (isOneCoeff ? one : m.second.getConst<Rational>());
          ub = ub + (isOneCoeff ? one : m.second.getConst<Rational>());
        }
        else
        {
          arith::Bounds b = binfer.get(m.first);
          bool isNeg = !isOneCoeff && m.second.getConst<Rational>().sgn() == -1;
          // flip lower/upper if negative coefficient
          TNode l = isNeg ? b.upper_value : b.lower_value;
          TNode u = isNeg ? b.lower_value : b.upper_value;
          if (lbSuccess && !l.isNull())
          {
            Rational lc = l.getConst<Rational>();
            lb = lb
                 + (isOneCoeff ? lc
                               : Rational(lc * m.second.getConst<Rational>()));
          }
          else
          {
            lbSuccess = false;
          }
          if (ubSuccess && !u.isNull())
          {
            Rational uc = u.getConst<Rational>();
            ub = ub
                 + (isOneCoeff ? uc
                               : Rational(uc * m.second.getConst<Rational>()));
          }
          else
          {
            ubSuccess = false;
          }
          if (!lbSuccess && !ubSuccess)
          {
            break;
          }
        }
      }
      if (lbSuccess && lb.sgn() == 1)
      {
        // if positive lower bound, then GEQ is true, EQUAL is false
        Node ret = nm->mkConst(k == GEQ);
        return returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_POS_LB);
      }
      if (lbSuccess && lb.sgn() == 0 && k == GEQ)
      {
        // zero lower bound, GEQ is true
        Node ret = nm->mkConst(true);
        return returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_ZERO_LB);
      }
      // Checked independently of the lower bound: a computed but inconclusive
      // lower bound (e.g. negative) must not prevent this sound rewrite.
      if (ubSuccess && ub.sgn() == -1)
      {
        // if negative upper bound, then GEQ and EQUAL are false
        Node ret = nm->mkConst(false);
        return returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_NEG_UB);
      }
    }
  }
  return nr;
}
407 |
|
|
408 |
6 |
/**
 * Bookkeeping hook for every learned rewrite n ---> nr: bumps the histogram
 * entry for id and, when the "learned-rewrite" trace tag is on, logs the step.
 *
 * @return nr unchanged, so callers can write `x = returnRewriteLearned(...)`
 */
Node LearnedRewrite::returnRewriteLearned(Node n, Node nr, LearnedRewriteId id)
{
  d_lrewCount << id;
  if (Trace.isOn("learned-rewrite"))
  {
    Trace("learned-rewrite")
        << "LearnedRewrite::Rewrite: (" << id << ") " << n << " == " << nr
        << std::endl;
  }
  return nr;
}
418 |
|
|
419 |
|
} // namespace passes |
420 |
|
} // namespace preprocessing |
421 |
29340 |
} // namespace cvc5 |