1 |
|
/****************************************************************************** |
2 |
|
* Top contributors (to current version): |
3 |
|
* Andrew Reynolds, Andres Noetzli, Tianyi Liang |
4 |
|
* |
5 |
|
* This file is part of the cvc5 project. |
6 |
|
* |
7 |
|
* Copyright (c) 2009-2021 by the authors listed in the file AUTHORS |
8 |
|
* in the top-level source directory and their institutional affiliations. |
9 |
|
* All rights reserved. See the file COPYING in the top-level source |
10 |
|
* directory for licensing information. |
11 |
|
* **************************************************************************** |
12 |
|
* |
13 |
|
* Implementation of virtual term substitution term cache. |
14 |
|
*/ |
15 |
|
|
16 |
|
#include "theory/quantifiers/cegqi/vts_term_cache.h" |
17 |
|
|
18 |
|
#include "expr/node_algorithm.h" |
19 |
|
#include "expr/skolem_manager.h" |
20 |
|
#include "theory/arith/arith_msum.h" |
21 |
|
#include "theory/quantifiers/quantifiers_inference_manager.h" |
22 |
|
#include "theory/rewriter.h" |
23 |
|
|
24 |
|
using namespace cvc5::kind; |
25 |
|
|
26 |
|
namespace cvc5 { |
27 |
|
namespace theory { |
28 |
|
namespace quantifiers { |
29 |
|
|
30 |
6207 |
/**
 * Constructs the cache. Keeps a reference to the given inference manager and
 * precomputes the rational constant zero used when building lemmas and
 * rewritten literals.
 */
VtsTermCache::VtsTermCache(QuantifiersInferenceManager& qim) : d_qim(qim)
{
  NodeManager* nm = NodeManager::currentNM();
  d_zero = nm->mkConst(Rational(0));
}
34 |
|
|
35 |
6505 |
void VtsTermCache::getVtsTerms(std::vector<Node>& t, |
36 |
|
bool isFree, |
37 |
|
bool create, |
38 |
|
bool inc_delta) |
39 |
|
{ |
40 |
6505 |
if (inc_delta) |
41 |
|
{ |
42 |
12828 |
Node delta = getVtsDelta(isFree, create); |
43 |
6414 |
if (!delta.isNull()) |
44 |
|
{ |
45 |
1575 |
t.push_back(delta); |
46 |
|
} |
47 |
|
} |
48 |
6505 |
NodeManager* nm = NodeManager::currentNM(); |
49 |
19515 |
for (unsigned r = 0; r < 2; r++) |
50 |
|
{ |
51 |
26020 |
TypeNode tn = r == 0 ? nm->realType() : nm->integerType(); |
52 |
26020 |
Node inf = getVtsInfinity(tn, isFree, create); |
53 |
13010 |
if (!inf.isNull()) |
54 |
|
{ |
55 |
1292 |
t.push_back(inf); |
56 |
|
} |
57 |
|
} |
58 |
6505 |
} |
59 |
|
|
60 |
18477 |
/**
 * Returns the delta symbol used for virtual term substitution.
 *
 * @param isFree selects the free variant (d_vts_delta_free) instead of the
 *        virtual one (d_vts_delta)
 * @param create when true, both variants are created if they do not exist
 *        yet; when false, the cached value is returned as is and may be null
 *
 * Creating the free delta also sends the lemma (delta_free > 0) through the
 * inference manager. The virtual delta is tagged with
 * VirtualTermSkolemAttribute.
 */
Node VtsTermCache::getVtsDelta(bool isFree, bool create)
{
  if (!create)
  {
    // lookup only: hand back whatever is cached
    return isFree ? d_vts_delta_free : d_vts_delta;
  }
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();
  if (d_vts_delta_free.isNull())
  {
    d_vts_delta_free = sm->mkDummySkolem(
        "delta_free", nm->realType(), "free delta for virtual term substitution");
    // the free delta is constrained to be strictly positive
    Node positiveLemma = nm->mkNode(GT, d_vts_delta_free, d_zero);
    d_qim.lemma(positiveLemma, InferenceId::QUANTIFIERS_CEGQI_VTS_LB_DELTA);
  }
  if (d_vts_delta.isNull())
  {
    d_vts_delta = sm->mkDummySkolem(
        "delta", nm->realType(), "delta for virtual term substitution");
    // flag the symbol as a virtual term
    VirtualTermSkolemAttribute virtualMark;
    d_vts_delta.setAttribute(virtualMark, true);
  }
  if (isFree)
  {
    return d_vts_delta_free;
  }
  return d_vts_delta;
}
86 |
|
|
87 |
27387 |
/**
 * Returns the infinity symbol of type tn used for virtual term substitution.
 *
 * @param tn the type of the requested infinity
 * @param isFree selects the free variant (d_vts_inf_free) instead of the
 *        virtual one (d_vts_inf)
 * @param create when true, both variants for tn are created if missing; when
 *        false, the cached entry is returned as is and may be null
 *
 * The virtual infinity is tagged with VirtualTermSkolemAttribute.
 * NOTE(review): the caches are indexed with operator[]; assuming they are
 * std::map-like, a lookup for an unseen type inserts a null entry -- confirm
 * against the header.
 */
Node VtsTermCache::getVtsInfinity(TypeNode tn, bool isFree, bool create)
{
  if (create)
  {
    NodeManager* nm = NodeManager::currentNM();
    SkolemManager* sm = nm->getSkolemManager();
    Node& freeSlot = d_vts_inf_free[tn];
    if (freeSlot.isNull())
    {
      freeSlot = sm->mkDummySkolem(
          "inf_free", tn, "free infinity for virtual term substitution");
    }
    Node& virtualSlot = d_vts_inf[tn];
    if (virtualSlot.isNull())
    {
      virtualSlot = sm->mkDummySkolem(
          "inf", tn, "infinity for virtual term substitution");
      // flag the symbol as a virtual term
      VirtualTermSkolemAttribute virtualMark;
      virtualSlot.setAttribute(virtualMark, true);
    }
  }
  if (isFree)
  {
    return d_vts_inf_free[tn];
  }
  return d_vts_inf[tn];
}
109 |
|
|
110 |
228 |
/**
 * Replaces every existing virtual symbol in n by its free counterpart.
 *
 * Both symbol lists are collected without creating anything; they are
 * expected to have the same length. When no virtual symbol exists, n is
 * returned unchanged.
 *
 * @param n the term to process
 * @return n with virtual symbols substituted by the free ones
 */
Node VtsTermCache::substituteVtsFreeTerms(Node n)
{
  std::vector<Node> virtualSyms;
  getVtsTerms(virtualSyms, false, false);
  std::vector<Node> freeSyms;
  getVtsTerms(freeSyms, true, false);
  Assert(virtualSyms.size() == freeSyms.size());
  if (!virtualSyms.empty())
  {
    return n.substitute(virtualSyms.begin(),
                        virtualSyms.end(),
                        freeSyms.begin(),
                        freeSyms.end());
  }
  return n;
}
124 |
|
|
125 |
1911 |
/**
 * Rewrites the virtual symbols (infinity, delta) occurring in n.
 *
 * - FORALL: not traversed; virtual symbols are replaced by free ones.
 * - EQUAL / GEQ literal: if it mentions a virtual infinity, or otherwise the
 *   virtual delta, the symbol is isolated in the monomial sum of the literal
 *   and the literal is replaced by an equivalent one without the symbol.
 *   Infinity takes precedence over delta. If the symbol cannot be isolated,
 *   or the solved side still contains virtual symbols, the literal falls
 *   back to substituteVtsFreeTerms.
 * - any other kind: children are rewritten recursively and the node is
 *   rebuilt only when some child changed.
 *
 * @param n the formula or term to rewrite
 * @return the rewritten node (n itself when nothing applies)
 */
Node VtsTermCache::rewriteVtsSymbols(Node n)
{
  NodeManager* nm = NodeManager::currentNM();
  const auto topKind = n.getKind();

  if (topKind == FORALL)
  {
    // cannot traverse beneath quantifiers
    return substituteVtsFreeTerms(n);
  }

  if (topKind != EQUAL && topKind != GEQ)
  {
    // generic case: rewrite the children, rebuild only if needed
    std::vector<Node> newChildren;
    bool anyChildChanged = false;
    for (const Node& child : n)
    {
      Node rewritten = rewriteVtsSymbols(child);
      newChildren.push_back(rewritten);
      if (rewritten != child)
      {
        anyChildChanged = true;
      }
    }
    if (!anyChildChanged)
    {
      return n;
    }
    if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
    {
      newChildren.insert(newChildren.begin(), n.getOperator());
    }
    Node rebuilt = nm->mkNode(n.getKind(), newChildren);
    Trace("quant-vts-debug") << "...make node " << rebuilt << std::endl;
    return rebuilt;
  }

  // arithmetic literal case
  Trace("quant-vts-debug") << "VTS : process " << n << std::endl;
  Node infToRewrite;
  // rewriting infinity always takes precedence over rewriting delta
  const TypeNode arithTypes[] = {nm->realType(), nm->integerType()};
  for (const TypeNode& arithTn : arithTypes)
  {
    Node inf = getVtsInfinity(arithTn, false, false);
    if (inf.isNull() || !expr::hasSubterm(n, inf))
    {
      continue;
    }
    if (infToRewrite.isNull())
    {
      infToRewrite = inf;
      continue;
    }
    // for mixed int/real with multiple infinities
    Trace("quant-vts-debug") << "Multiple infinities...equate " << inf
                             << " = " << infToRewrite << std::endl;
    std::vector<Node> from{inf};
    std::vector<Node> to{infToRewrite};
    n = n.substitute(from.begin(), from.end(), to.begin(), to.end());
    n = Rewriter::rewrite(n);
    // the remaining infinity may have cancelled during rewriting
    if (!expr::hasSubterm(n, infToRewrite))
    {
      infToRewrite = Node::null();
    }
  }

  const bool handleInf = !infToRewrite.isNull();
  // delta is only considered when no infinity is left to rewrite
  const bool handleDelta = !handleInf && !d_vts_delta.isNull()
                           && expr::hasSubterm(n, d_vts_delta);
  if (!handleInf && !handleDelta)
  {
    return n;
  }

  std::map<Node, Node> msum;
  if (!ArithMSum::getMonomialSumLit(n, msum))
  {
    return n;
  }
  if (Trace.isOn("quant-vts-debug"))
  {
    Trace("quant-vts-debug") << "VTS got monomial sum : " << std::endl;
    ArithMSum::debugPrintMonomialSum(msum, "quant-vts-debug");
  }
  Node vts_sym = handleInf ? infToRewrite : d_vts_delta;
  Assert(!vts_sym.isNull());
  Node isolated;
  // n may have been rewritten above, hence its kind is read again here
  int res = ArithMSum::isolate(vts_sym, msum, isolated, n.getKind(), true);
  if (res == 0)
  {
    Trace("quant-vts-warn")
        << "Bad vts literal : " << n << ", contains " << vts_sym
        << " but could not isolate." << std::endl;
    // safe case: just convert to free symbols
    Node fallback = substituteVtsFreeTerms(n);
    Trace("quant-vts-debug") << "...return " << fallback << std::endl;
    return fallback;
  }

  Trace("quant-vts-debug") << "VTS isolated : -> " << isolated
                           << ", res = " << res << std::endl;
  // the side of the isolated literal that does not hold the symbol
  Node slv = isolated[res == 1 ? 1 : 0];
  // ensure the vts terms have been eliminated
  if (containsVtsTerm(slv))
  {
    Trace("quant-vts-warn")
        << "Bad vts literal : " << n << ", contains " << vts_sym
        << " but bad solved form " << slv << "." << std::endl;
    // safe case: just convert to free symbols
    Node fallback = substituteVtsFreeTerms(n);
    Trace("quant-vts-debug") << "...return " << fallback << std::endl;
    return fallback;
  }

  Node result;
  if (handleInf)
  {
    // only (infinity >= slv) becomes true; every other form becomes false
    result = nm->mkConst(n.getKind() == GEQ && res == 1);
  }
  else
  {
    Assert(isolated[res == 1 ? 0 : 1] == d_vts_delta);
    if (n.getKind() == EQUAL)
    {
      result = nm->mkConst(false);
    }
    else if (res == 1)
    {
      // (delta >= slv) becomes (0 >= slv)
      result = nm->mkNode(GEQ, d_zero, slv);
    }
    else
    {
      // (slv >= delta) becomes (slv > 0)
      result = nm->mkNode(GT, slv, d_zero);
    }
  }
  Trace("quant-vts-debug") << "Return " << result << std::endl;
  return result;
}
269 |
|
|
270 |
958 |
/**
 * Checks whether n has one of the existing vts symbols (delta or an
 * infinity) as a subterm.
 *
 * @param n the term to inspect
 * @param isFree whether the free variants of the symbols are searched for
 * @return the result of expr::hasSubterm on n and the collected symbols
 */
bool VtsTermCache::containsVtsTerm(Node n, bool isFree)
{
  std::vector<Node> vtsSyms;
  // collect without creating any symbol
  getVtsTerms(vtsSyms, isFree, false);
  const bool found = expr::hasSubterm(n, vtsSyms);
  return found;
}
276 |
|
|
277 |
5000 |
bool VtsTermCache::containsVtsTerm(std::vector<Node>& n, bool isFree) |
278 |
|
{ |
279 |
10000 |
std::vector<Node> t; |
280 |
5000 |
getVtsTerms(t, isFree, false); |
281 |
5000 |
if (!t.empty()) |
282 |
|
{ |
283 |
269 |
for (const Node& nc : n) |
284 |
|
{ |
285 |
187 |
if (expr::hasSubterm(nc, t)) |
286 |
|
{ |
287 |
87 |
return true; |
288 |
|
} |
289 |
|
} |
290 |
|
} |
291 |
4913 |
return false; |
292 |
|
} |
293 |
|
|
294 |
|
/**
 * Checks whether n has one of the existing vts infinity symbols as a
 * subterm. Delta is deliberately left out of the search.
 *
 * @param n the term to inspect
 * @param isFree whether the free variants of the infinities are searched for
 * @return the result of expr::hasSubterm on n and the collected infinities
 */
bool VtsTermCache::containsVtsInfinity(Node n, bool isFree)
{
  std::vector<Node> infSyms;
  // no creation, and inc_delta = false so only infinities are collected
  getVtsTerms(infSyms, isFree, false, false);
  const bool found = expr::hasSubterm(n, infSyms);
  return found;
}
300 |
|
|
301 |
|
} // namespace quantifiers |
302 |
|
} // namespace theory |
303 |
27735 |
} // namespace cvc5 |