1 |
|
/****************************************************************************** |
2 |
|
* Top contributors (to current version): |
3 |
|
* Haniel Barbosa, Andrew Reynolds, Mathias Preiner |
4 |
|
* |
5 |
|
* This file is part of the cvc5 project. |
6 |
|
* |
7 |
|
* Copyright (c) 2009-2021 by the authors listed in the file AUTHORS |
8 |
|
* in the top-level source directory and their institutional affiliations. |
9 |
|
* All rights reserved. See the file COPYING in the top-level source |
10 |
|
* directory for licensing information. |
11 |
|
* **************************************************************************** |
12 |
|
* |
13 |
|
* Implementation of sygus_unif_rl. |
14 |
|
*/ |
15 |
|
|
16 |
|
#include "theory/quantifiers/sygus/sygus_unif_rl.h" |
17 |
|
|
18 |
|
#include "expr/skolem_manager.h" |
19 |
|
#include "options/base_options.h" |
20 |
|
#include "options/quantifiers_options.h" |
21 |
|
#include "printer/printer.h" |
22 |
|
#include "theory/quantifiers/sygus/synth_conjecture.h" |
23 |
|
#include "theory/quantifiers/sygus/term_database_sygus.h" |
24 |
|
#include "theory/rewriter.h" |
25 |
|
#include "util/random.h" |
26 |
|
|
27 |
|
#include <math.h> |
28 |
|
|
29 |
|
using namespace cvc5::kind; |
30 |
|
|
31 |
|
namespace cvc5 { |
32 |
|
namespace theory { |
33 |
|
namespace quantifiers { |
34 |
|
|
35 |
1191 |
// Records the parent conjecture. Both condition-pool flags start out disabled.
SygusUnifRl::SygusUnifRl(SynthConjecture* p)
    : d_parent(p),
      d_useCondPool(false),
      d_useCondPoolIGain(false)
{
}
39 |
1189 |
// Nothing to release explicitly; members clean up after themselves.
SygusUnifRl::~SygusUnifRl() = default;
40 |
14 |
// Prepares candidate f: runs the base-class initialization, registers the
// strategy for f (appending to `enums` the enumerators that get a strategy),
// resets the evaluation-point bookkeeping when f is a unification candidate,
// and caches which condition-enumeration mode is selected by the options.
void SygusUnifRl::initializeCandidate(
    TermDbSygus* tds,
    Node f,
    std::vector<Node>& enums,
    std::map<Node, std::vector<Node>>& strategy_lemmas)
{
  // The enumerator list filled by the base class stays local; `enums` is
  // populated by registerStrategy below instead.
  std::vector<Node> baseEnums;
  SygusUnif::initializeCandidate(tds, f, baseEnums, strategy_lemmas);

  // Based on the strategy inferred for f, determine whether we are using a
  // unification strategy that is compatible with our approach.
  StrategyRestrictions restrictions;
  if (options::sygusBoolIteReturnConst())
  {
    restrictions.d_iteReturnBoolConst = true;
  }
  registerStrategy(f, enums, restrictions.d_unused_strategies);
  d_strategy[f].staticLearnRedundantOps(strategy_lemmas, restrictions);

  // If registration marked f as a unification candidate, start from empty
  // evaluation-point data for it.
  if (usingUnif(f))
  {
    d_hd_to_pt[f].clear();
    d_cand_to_eval_hds[f].clear();
    d_cand_to_hd_count[f] = 0;
  }

  // Condition enumeration is used in both CENUM modes; information gain only
  // in CENUM_IGAIN.
  const options::SygusUnifPiMode mode = options::sygusUnifPi();
  const bool isCenum = (mode == options::SygusUnifPiMode::CENUM);
  const bool isCenumIGain = (mode == options::SygusUnifPiMode::CENUM_IGAIN);
  d_useCondPool = isCenum || isCenumIGain;
  d_useCondPoolIGain = isCenumIGain;
}
72 |
|
|
73 |
|
// Enumeration notifications are not part of this utility's workflow, so
// reaching this method is treated as an internal error.
void SygusUnifRl::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
{
  Assert(false);
}
78 |
|
|
79 |
1464 |
// Recursively purifies n.
//
// Applications (DT_SYGUS_EVAL) of unification candidates get their head
// replaced by a fresh skolem ("evaluation head"), which is recorded in
// d_cand_to_eval_hds together with its argument tuple in d_hd_to_pt. When
// ensureConst is true, an application is replaced by its model value and the
// negated equality between that value and the purified application is
// appended to model_guards. Results are memoized in cache per
// (ensureConst, n) pair. Returns the rewritten purified node.
Node SygusUnifRl::purifyLemma(Node n,
                              bool ensureConst,
                              std::vector<Node>& model_guards,
                              BoolNodePairMap& cache)
{
  Trace("sygus-unif-rl-purify") << "PurifyLemma : " << n << "\n";
  const BoolNodePair cacheKey(ensureConst, n);
  BoolNodePairMap::const_iterator cached = cache.find(cacheKey);
  if (cached != cache.end())
  {
    Trace("sygus-unif-rl-purify-debug") << "... already visited " << n << "\n";
    return cached->second;
  }

  const unsigned numChildren = n.getNumChildren();
  const Kind k = n.getKind();
  // The model value is retrieved now because the purified node may not have
  // a value.
  Node modelVal = n;
  // Whether n is an application of a function-to-synthesize, and if so
  // whether that function is (u_fapp) or is not (nu_fapp) a unification one.
  const bool fapp = (k == DT_SYGUS_EVAL);
  bool u_fapp = false;
  bool nu_fapp = false;
  if (fapp)
  {
    Assert(std::find(d_candidates.begin(), d_candidates.end(), n[0])
           != d_candidates.end());
    u_fapp = usingUnif(n[0]);
    nu_fapp = !u_fapp;
    // Get the model value of non-top level applications of
    // functions-to-synthesize occurring under a unification
    // function-to-synthesize.
    if (ensureConst)
    {
      std::map<Node, Node>::iterator solIt = d_cand_to_sol.find(n[0]);
      // A unification function must already have a built solution here.
      AlwaysAssert(!u_fapp || solIt != d_cand_to_sol.end());
      if (solIt == d_cand_to_sol.end())
      {
        modelVal = d_parent->getModelValue(n);
        Trace("sygus-unif-rl-purify")
            << "PurifyLemma : model value for " << n << " is " << modelVal
            << "\n";
      }
      else
      {
        // Plug the built solution into the application; the result should be
        // concrete, so the rewriter is enough to evaluate it.
        TNode cand = n[0];
        Node substituted = n.substitute(cand, solIt->second);
        modelVal = Rewriter::rewrite(substituted);
        Trace("sygus-unif-rl-purify")
            << "PurifyLemma : model value for " << substituted << " is "
            << modelVal << "\n";
      }
      Assert(n != modelVal);
    }
  }

  // Traverse to purify the children.
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();
  // Arguments of non-unif functions do not need to be constant.
  const bool childEnsureConst = !nu_fapp && (ensureConst || u_fapp);
  std::vector<Node> children;
  bool childChanged = false;
  for (unsigned i = 0; i < numChildren; ++i)
  {
    if (fapp && i == 0)
    {
      // The head of an application is kept as is at this stage.
      children.push_back(n[0]);
      continue;
    }
    Node child = purifyLemma(n[i], childEnsureConst, model_guards, cache);
    if (child != n[i])
    {
      childChanged = true;
    }
    children.push_back(child);
  }

  // Rebuild the node only if some child was modified.
  Node nb = n;
  if (childChanged)
  {
    if (n.getMetaKind() == metakind::PARAMETERIZED)
    {
      Trace("sygus-unif-rl-purify-debug")
          << "Node " << n << " is parameterized\n";
      children.insert(children.begin(), n.getOperator());
    }
    if (Trace.isOn("sygus-unif-rl-purify-debug"))
    {
      Trace("sygus-unif-rl-purify-debug")
          << "...rebuilding " << n << " with kind " << k << " and children:\n";
      for (const Node& child : children)
      {
        Trace("sygus-unif-rl-purify-debug") << "...... " << child << "\n";
      }
    }
    nb = nm->mkNode(k, children);
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : transformed " << n << " into " << nb << "\n";
  }

  // Map every unification function-to-synthesize application to a point
  // enumerator.
  if (u_fapp)
  {
    Node purified;
    std::map<Node, Node>::const_iterator knownIt = d_app_to_purified.find(nb);
    if (knownIt != d_app_to_purified.end())
    {
      purified = knownIt->second;
    }
    else
    {
      // Build the purified head as a fresh skolem named after the candidate
      // and its running head counter.
      std::stringstream ss;
      ss << nb[0] << "_" << d_cand_to_hd_count[nb[0]]++;
      Node new_f = sm->mkDummySkolem(ss.str(),
                                     nb[0].getType(),
                                     "head of unif evaluation point",
                                     NodeManager::SKOLEM_EXACT_NAME);
      // Add the new enumerator to the map from candidates.
      Trace("sygus-unif-rl-purify")
          << "...new enum " << new_f << " for candidate " << nb[0] << "\n";
      d_cand_to_eval_hds[nb[0]].push_back(new_f);
      // Map the new enumerator to its tuple of arguments.
      d_hd_to_pt[new_f] =
          std::vector<Node>(children.begin() + 1, children.end());
      if (Trace.isOn("sygus-unif-rl-purify-debug"))
      {
        Trace("sygus-unif-rl-purify-debug") << "...[" << new_f << "] --> ( ";
        for (const Node& pt_i : d_hd_to_pt[new_f])
        {
          Trace("sygus-unif-rl-purify-debug") << pt_i << " ";
        }
        Trace("sygus-unif-rl-purify-debug") << ")\n";
      }
      // Replace the first child and rebuild the node.
      Assert(children.size() > 0);
      children[0] = new_f;
      Trace("sygus-unif-rl-purify-debug")
          << "Make sygus eval app " << children << std::endl;
      purified = nm->mkNode(DT_SYGUS_EVAL, children);
      d_app_to_purified[nb] = purified;
    }
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : purified head and transformed " << nb << " into "
        << purified << "\n";
    nb = purified;
  }

  // Add the (negated) equality between the purified application and its
  // model value as a guard, and continue with the model value.
  if (fapp && ensureConst)
  {
    model_guards.push_back(nm->mkNode(EQUAL, modelVal, nb).negate());
    nb = modelVal;
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : adding model eq " << model_guards.back() << "\n";
  }
  nb = Rewriter::rewrite(nb);
  // Every non-top level application of a function-to-synthesize must be
  // reduced to a concrete constant.
  Assert(!ensureConst || nb.isConst());
  Trace("sygus-unif-rl-purify-debug")
      << "... caching [" << n << "] = " << nb << "\n";
  cache[cacheKey] = nb;
  return nb;
}
243 |
|
|
244 |
21 |
// Purifies a refinement lemma and registers the evaluation heads created
// while doing so. The new heads of each candidate are appended to eval_hds
// and to the decision tree of every strategy point associated with the
// candidate's condition enumerators. Returns the rewritten purified lemma
// (disjoined with the model guards, if any were produced).
Node SygusUnifRl::addRefLemma(Node lemma,
                              std::map<Node, std::vector<Node>>& eval_hds)
{
  Trace("sygus-unif-rl-purify")
      << "Registering lemma at SygusUnif : " << lemma << "\n";

  // Remember how many evaluation heads each candidate had beforehand, so the
  // ones added by purification can be identified afterwards.
  std::map<Node, unsigned> prevHeadCount;
  for (const auto& entry : d_cand_to_eval_hds)
  {
    prevHeadCount[entry.first] = entry.second.size();
  }

  // Make the purified lemma which will guide the unification utility.
  std::vector<Node> model_guards;
  BoolNodePairMap cache;
  Node plem = purifyLemma(lemma, false, model_guards, cache);
  if (!model_guards.empty())
  {
    model_guards.push_back(plem);
    plem = NodeManager::currentNM()->mkNode(OR, model_guards);
  }
  plem = Rewriter::rewrite(plem);
  Trace("sygus-unif-rl-purify") << "Purified lemma : " << plem << "\n";

  Trace("sygus-unif-rl-purify") << "Collect new evaluation points...\n";
  for (const auto& entry : d_cand_to_eval_hds)
  {
    const Node& c = entry.first;
    const std::vector<Node>& heads = entry.second;
    // Candidates without a previous record contribute all of their heads.
    std::map<Node, unsigned>::const_iterator prevIt = prevHeadCount.find(c);
    const unsigned firstNew =
        (prevIt == prevHeadCount.end()) ? 0 : prevIt->second;
    for (unsigned j = firstNew, size = heads.size(); j < size; j++)
    {
      const Node hd = heads[j];
      eval_hds[c].push_back(hd);
      // Add the new point to the respective decision trees.
      Assert(d_cand_cenums.find(c) != d_cand_cenums.end());
      for (const Node& cenum : d_cand_cenums[c])
      {
        Assert(d_cenum_to_stratpt.find(cenum) != d_cenum_to_stratpt.end());
        for (const Node& stratpt : d_cenum_to_stratpt[cenum])
        {
          Assert(d_stratpt_to_dt.find(stratpt) != d_stratpt_to_dt.end());
          Trace("sygus-unif-rl-dt")
              << "Register point with head " << hd
              << " to strategy point " << stratpt << "\n";
          d_stratpt_to_dt[stratpt].d_hds.push_back(hd);
        }
      }
    }
  }

  return plem;
}
301 |
|
|
302 |
226 |
// No global preparation is needed before constructing solutions.
void SygusUnifRl::initializeConstructSol()
{
}
303 |
226 |
// No per-candidate preparation is needed before constructing a solution.
void SygusUnifRl::initializeConstructSolFor(Node f)
{
}
304 |
226 |
bool SygusUnifRl::constructSolution(std::vector<Node>& sols, |
305 |
|
std::vector<Node>& lemmas) |
306 |
|
{ |
307 |
226 |
initializeConstructSol(); |
308 |
226 |
bool successful = true; |
309 |
458 |
for (const Node& c : d_candidates) |
310 |
|
{ |
311 |
232 |
if (!usingUnif(c)) |
312 |
|
{ |
313 |
12 |
Node v = d_parent->getModelValue(c); |
314 |
6 |
sols.push_back(v); |
315 |
6 |
continue; |
316 |
|
} |
317 |
226 |
initializeConstructSolFor(c); |
318 |
|
Node v = constructSol( |
319 |
255 |
c, d_strategy[c].getRootEnumerator(), role_equal, 0, lemmas); |
320 |
423 |
if (v.isNull()) |
321 |
|
{ |
322 |
|
// we continue trying to build solutions to accumulate potentitial |
323 |
|
// separation conditions from other decision trees |
324 |
197 |
successful = false; |
325 |
197 |
continue; |
326 |
|
} |
327 |
29 |
sols.push_back(v); |
328 |
29 |
d_cand_to_sol[c] = v; |
329 |
|
} |
330 |
226 |
return successful; |
331 |
|
} |
332 |
|
|
333 |
226 |
// Builds a solution for candidate f at enumerator e. Only role_equal with a
// registered decision tree is handled; any other case yields the null node.
// With no evaluation heads for f, the model value of e is returned directly.
// Otherwise the decision tree builds the solution, possibly adding lemmas.
Node SygusUnifRl::constructSol(
    Node f, Node e, NodeRole nrole, int ind, std::vector<Node>& lemmas)
{
  indent("sygus-unif-sol", ind);
  Trace("sygus-unif-sol") << "ConstructSol: SygusRL : " << e << std::endl;

  // Look up the strategy node for the type of e under the given role.
  TypeNode etn = e.getType();
  EnumTypeInfo& typeInfo = d_strategy[f].getEnumTypeInfo(etn);
  StrategyNode& stratNode = typeInfo.getStrategyNode(nrole);
  if (nrole != role_equal)
  {
    return Node::null();
  }

  // For now only the simple case of a sole "ITE(cond, e, e)" strategy is
  // considered, which requires a decision tree registered for e.
  std::map<Node, DecisionTreeInfo>::iterator dtIt = d_stratpt_to_dt.find(e);
  if (dtIt == d_stratpt_to_dt.end())
  {
    return Node::null();
  }
  indent("sygus-unif-sol", ind);
  Trace("sygus-unif-sol") << "...it has a decision tree strategy.\n";

  // Without evaluation points, the value of the root enumerator is used.
  if (d_cand_to_eval_hds[f].empty())
  {
    Trace("sygus-unif-sol") << "...... no points, return root enum value "
                            << d_parent->getModelValue(e) << "\n";
    return d_parent->getModelValue(e);
  }

  DecisionTreeInfo& dt = dtIt->second;
  EnumTypeInfoStrat* etis = stratNode.d_strats[dt.getStrategyIndex()];
  Node sol = dt.buildSol(etis->d_cons, lemmas);
  // Unless a condition pool is in use, failing to build a solution must come
  // with at least one lemma.
  Assert(d_useCondPool || !sol.isNull() || !lemmas.empty());
  return sol;
}
367 |
|
|
368 |
346 |
// True iff f has been recorded as a unification candidate.
bool SygusUnifRl::usingUnif(Node f) const
{
  return d_unif_candidates.count(f) != 0;
}
372 |
|
|
373 |
11 |
// Returns the condition enumerator of the decision tree registered for e.
// A decision tree for e is expected to exist.
Node SygusUnifRl::getConditionForEvaluationPoint(Node e) const
{
  const auto dtIt = d_stratpt_to_dt.find(e);
  Assert(dtIt != d_stratpt_to_dt.end());
  const DecisionTreeInfo& dt = dtIt->second;
  return dt.getConditionEnumerator();
}
379 |
|
|
380 |
226 |
// Forwards the guard, condition enumerators and condition values to the
// decision tree registered for e. A decision tree for e is expected to exist.
void SygusUnifRl::setConditions(Node e,
                                Node guard,
                                const std::vector<Node>& enums,
                                const std::vector<Node>& conds)
{
  auto dtIt = d_stratpt_to_dt.find(e);
  Assert(dtIt != d_stratpt_to_dt.end());
  DecisionTreeInfo& dt = dtIt->second;
  dt.setConditions(guard, enums, conds);
}
390 |
|
|
391 |
408 |
// Returns a copy of the evaluation heads recorded for candidate c, or an
// empty vector when c has no entry.
std::vector<Node> SygusUnifRl::getEvalPointHeads(Node c)
{
  auto headsIt = d_cand_to_eval_hds.find(c);
  if (headsIt != d_cand_to_eval_hds.end())
  {
    return headsIt->second;
  }
  return std::vector<Node>();
}
400 |
|
|
401 |
1359 |
// Whether a condition pool is in use (set in initializeCandidate).
bool SygusUnifRl::usingConditionPool() const
{
  return d_useCondPool;
}
402 |
21 |
bool SygusUnifRl::usingConditionPoolInfoGain() const |
403 |
|
{ |
404 |
21 |
return d_useCondPoolIGain; |
405 |
|
} |
406 |
14 |
// Registers the strategy of candidate f, starting from its root enumerator
// under role_equal. Enumerators that get a strategy are appended to enums;
// strategy indices that are not used are collected in unused_strats.
void SygusUnifRl::registerStrategy(
    Node f,
    std::vector<Node>& enums,
    std::map<Node, std::unordered_set<unsigned>>& unused_strats)
{
  if (Trace.isOn("sygus-unif-rl-strat"))
  {
    Trace("sygus-unif-rl-strat")
        << "Strategy for " << f << " is : " << std::endl;
    d_strategy[f].debugPrint("sygus-unif-rl-strat");
  }
  Trace("sygus-unif-rl-strat") << "Register..." << std::endl;

  // Fresh visited set for this traversal, keyed by enumerator and role.
  std::map<Node, std::map<NodeRole, bool>> visited;
  Node root = d_strategy[f].getRootEnumerator();
  registerStrategyNode(f, root, role_equal, visited, enums, unused_strats);
}
422 |
|
|
423 |
14 |
// Inspects every strategy of enumerator e under role nrole. A strategy is
// accepted when it is an ITE strategy under role_equal whose two branch
// children are again (e, nrole); in that case its condition enumerator is
// registered and e is appended to enums. The index of every other strategy
// is recorded in unused_strats[e]. Each (e, nrole) pair is processed once.
void SygusUnifRl::registerStrategyNode(
    Node f,
    Node e,
    NodeRole nrole,
    std::map<Node, std::map<NodeRole, bool>>& visited,
    std::vector<Node>& enums,
    std::map<Node, std::unordered_set<unsigned>>& unused_strats)
{
  Trace("sygus-unif-rl-strat") << " register node " << e << std::endl;
  std::map<NodeRole, bool>& seenRoles = visited[e];
  if (seenRoles.find(nrole) != seenRoles.end())
  {
    return;
  }
  seenRoles[nrole] = true;

  TypeNode etn = e.getType();
  EnumTypeInfo& typeInfo = d_strategy[f].getEnumTypeInfo(etn);
  StrategyNode& stratNode = typeInfo.getStrategyNode(nrole);
  for (unsigned j = 0, size = stratNode.d_strats.size(); j < size; j++)
  {
    EnumTypeInfoStrat* etis = stratNode.d_strats[j];
    StrategyType strat = etis->d_this;

    // Is this a simple recursive ITE strategy, i.e. one whose "then" and
    // "else" children are the enumerator itself under the same role?
    bool isRecursiveIte = (strat == strat_ITE && nrole == role_equal);
    if (isRecursiveIte)
    {
      for (unsigned c = 1; c <= 2; c++)
      {
        std::pair<Node, NodeRole> branch = etis->d_cenum[c];
        if (branch.first != e || branch.second != nrole)
        {
          isRecursiveIte = false;
          break;
        }
      }
    }

    if (!isRecursiveIte)
    {
      unused_strats[e].insert(j);
    }
    else
    {
      Node cond = etis->d_cenum[0].first;
      Assert(etis->d_cenum[0].second == role_ite_condition);
      Trace("sygus-unif-rl-strat")
          << " ...detected recursive ITE strategy, condition enumerator : "
          << cond << std::endl;
      // indicate that we will be enumerating values for cond
      registerConditionalEnumerator(f, e, cond, j);
      // we will be using a strategy for e
      enums.push_back(e);
    }
    // TODO: recurse? for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
  }
}
478 |
|
|
479 |
11 |
// Associates condition enumerator cond with strategy point e of candidate f
// and creates the decision tree for e. Does nothing if e already has a
// decision tree, since only one is allowed per strategy point.
void SygusUnifRl::registerConditionalEnumerator(Node f,
                                                Node e,
                                                Node cond,
                                                unsigned strategy_index)
{
  const bool alreadyHasDt = d_stratpt_to_dt.find(e) != d_stratpt_to_dt.end();
  if (alreadyHasDt)
  {
    return;
  }
  // we will do unification for this candidate
  d_unif_candidates.insert(f);

  // Record cond in the list of all conditional enumerators the first time
  // it is seen.
  const bool knownCond =
      std::find(d_cond_enums.begin(), d_cond_enums.end(), cond)
      != d_cond_enums.end();
  if (!knownCond)
  {
    d_cond_enums.push_back(cond);
    d_cand_cenums[f].push_back(cond);
    d_cenum_to_stratpt[cond].clear();
  }

  // register that this strategy node has a decision tree construction
  DecisionTreeInfo& dt = d_stratpt_to_dt[e];
  dt.initialize(cond, this, &d_strategy[f], strategy_index);
  // associate conditional enumerator with strategy node
  d_cenum_to_stratpt[cond].push_back(e);
}
504 |
|
|
505 |
11 |
// Stores the condition enumerator, owning unification utility, strategy and
// strategy index; caches the Boolean constants; reads the template of the
// condition enumerator from the strategy; and initializes the point
// separator with this decision tree.
void SygusUnifRl::DecisionTreeInfo::initialize(Node cond_enum,
                                               SygusUnifRl* unif,
                                               SygusUnifStrategy* strategy,
                                               unsigned strategy_index)
{
  d_cond_enum = cond_enum;
  d_unif = unif;
  d_strategy = strategy;
  d_strategy_index = strategy_index;

  NodeManager* nm = NodeManager::currentNM();
  d_true = nm->mkConst(true);
  d_false = nm->mkConst(false);

  // Retrieve template
  EnumInfo& condInfo = d_strategy->getEnumInfo(d_cond_enum);
  d_template = NodePair(condInfo.d_template, condInfo.d_template_arg);

  // Initialize classifier
  d_pt_sep.initialize(this);
}
522 |
|
|
523 |
226 |
// Installs the guard and the current (enumerator, value) pairs of the
// conditions for this decision tree, replacing the previous ones. enums and
// conds must have the same size. When the unification utility uses a
// condition pool, the condition values are also added to the pool
// (d_cond_mvs).
void SygusUnifRl::DecisionTreeInfo::setConditions(
    Node guard, const std::vector<Node>& enums, const std::vector<Node>& conds)
{
  Assert(enums.size() == conds.size());
  // set the guard
  d_guard = guard;
  // clear old condition values
  d_enums.clear();
  d_conds.clear();
  // set new condition values
  d_enums.insert(d_enums.end(), enums.begin(), enums.end());
  d_conds.insert(d_conds.end(), conds.begin(), conds.end());
  // add to condition pool
  if (d_unif->usingConditionPool())
  {
    // Insert one value at a time so that we can tell which ones are new.
    // Testing membership after a bulk insertion would always succeed and
    // the trace below would never be printed.
    const bool traceOn = Trace.isOn("sygus-unif-cond-pool");
    for (const Node& condv : conds)
    {
      const bool isNew = d_cond_mvs.insert(condv).second;
      if (isNew && traceOn)
      {
        Trace("sygus-unif-cond-pool")
            << " ...adding to condition pool : "
            << d_unif->d_tds->sygusToBuiltin(condv, condv.getType()) << "\n";
      }
    }
  }
}
553 |
|
|
554 |
218 |
unsigned SygusUnifRl::DecisionTreeInfo::getStrategyIndex() const |
555 |
|
{ |
556 |
218 |
return d_strategy_index; |
557 |
|
} |
558 |
|
|
559 |
218 |
// Builds a solution from this decision tree using constructor cons.
// Templated conditions are not supported and yield the null node. Otherwise
// the trie is reset and the construction is delegated to buildSolAllCond
// when a condition pool is in use, and to buildSolMinCond when it is not.
Node SygusUnifRl::DecisionTreeInfo::buildSol(Node cons,
                                             std::vector<Node>& lemmas)
{
  const bool hasTemplate = !d_template.first.isNull();
  if (hasTemplate)
  {
    Trace("sygus-unif-sol") << "...templated conditions unsupported\n";
    return Node::null();
  }
  Trace("sygus-unif-sol") << "Decision::buildSol with " << d_hds.size()
                          << " evaluation heads and " << d_conds.size()
                          << " conditions..." << std::endl;
  // reset the trie
  d_pt_sep.d_trie.clear();
  if (d_unif->usingConditionPool())
  {
    return buildSolAllCond(cons, lemmas);
  }
  return buildSolMinCond(cons, lemmas);
}
575 |
|
|
576 |
|
// Builds a solution using every condition in the pool as a classifier. All
// evaluation heads are added to the trie; if two heads with different model
// values end up in the same separation class, the null node is returned.
// The null node is also returned for a repeated solution when the
// corresponding option is enabled. The lemmas argument is not used here.
Node SygusUnifRl::DecisionTreeInfo::buildSolAllCond(Node cons,
                                                    std::vector<Node>& lemmas)
{
  // model values for evaluation heads
  std::map<Node, Node> hd_mv;

  // The conditions are the whole pool.
  d_conds.assign(d_cond_mvs.begin(), d_cond_mvs.end());
  // Shuffle the conditions before building the DT.
  //
  // This does not impact whether it is possible to build a solution, but it
  // does impact the potential size of the resulting solution (can make it
  // smaller, bigger, or have no impact) and which conditions will be present
  // in the DT, which influences the "quality" of the solution for cases not
  // covered in the current data points.
  if (options::sygusUnifShuffleCond())
  {
    std::shuffle(d_conds.begin(), d_conds.end(), Random::getRandom());
  }
  const unsigned num_conds = d_conds.size();
  for (unsigned i = 0; i < num_conds; ++i)
  {
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, i);
  }

  // add heads
  for (const Node& e : d_hds)
  {
    hd_mv[e] = d_unif->d_parent->getModelValue(e);
    Node er = d_pt_sep.d_trie.add(e, &d_pt_sep, num_conds);
    // A head that is its own representative starts a new separation class,
    // which cannot be a conflict.
    if (er != e)
    {
      Assert(hd_mv.find(er) != hd_mv.end());
      // Merging into a class is only fine if the model values agree.
      if (!(hd_mv[e] == hd_mv[er]))
      {
        // conflict. Explanation?
        Trace("sygus-unif-sol")
            << " ...can't separate " << e << " from " << er << std::endl;
        return Node::null();
      }
    }
  }

  Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
  Node sol = extractSol(cons, hd_mv);
  // repeated solution
  const bool repeated = options::sygusUnifCondIndNoRepeatSol()
                        && d_sols.find(sol) != d_sols.end();
  if (repeated)
  {
    return Node::null();
  }
  d_sols.insert(sol);
  return sol;
}
634 |
|
|
635 |
218 |
Node SygusUnifRl::DecisionTreeInfo::buildSolMinCond(Node cons, |
636 |
|
std::vector<Node>& lemmas) |
637 |
|
{ |
638 |
218 |
NodeManager* nm = NodeManager::currentNM(); |
639 |
|
// model values for evaluation heads |
640 |
436 |
std::map<Node, Node> hd_mv; |
641 |
|
// the current explanation of why there has not yet been a separation conflict |
642 |
436 |
std::vector<Node> exp; |
643 |
|
// is the above explanation ready to be sent out as a lemma? |
644 |
218 |
bool exp_conflict = false; |
645 |
|
// the index of the head we are considering |
646 |
218 |
unsigned hd_counter = 0; |
647 |
|
// the index of the condition we are considering |
648 |
218 |
unsigned c_counter = 0; |
649 |
|
// do we need to resolve a separation conflict? |
650 |
218 |
bool needs_sep_resolve = false; |
651 |
|
// This loop simultaneously builds the solution in terms of a lazy trie |
652 |
|
// (LazyTrieMulti), and checks whether a separation conflict exists. We |
653 |
|
// enforce that the separation conflicts we encounter while building |
654 |
|
// this solution are resolved, in order, by the condition enumerators. |
655 |
|
// If not, then we add a (conflict) lemma stating that the current model |
656 |
|
// value of the condition enumerator must be different. We also call this |
657 |
|
// a "separation lemma". |
658 |
|
// |
659 |
|
// As a simple example, say we have: |
660 |
|
// evaluation heads: (eval e1 0 0), (eval e2 1 2) |
661 |
|
// conditions: c1 |
662 |
|
// where M(e1) = x, M(e2) = y, and M(c1) = x>1. After adding e1 and e2, we are |
663 |
|
// in conflict since { e1, e2 } form a separation class, M(e1)!=M(e2), and |
664 |
|
// M(c1) does not separate e1 and e2 since: |
665 |
|
// (x>1){x->0,y->0} = (x>1){x->1,y->2} = false |
666 |
|
// Hence, we would fail to build a solution in this case, and instead send a |
667 |
|
// separation lemma of the form: |
668 |
|
// ~( e1 != e2 ^ c1 = [x>1] ) |
669 |
|
// |
670 |
|
// Say we have: |
671 |
|
// evaluation heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3) |
672 |
|
// conditions: c1 c2 |
673 |
|
// where M(e1) = x, M(e2) = y, M(e3) = x+1, M(c1) = x>0 and M(c2) = x<0. |
674 |
|
// After adding e1 and e2, { e1, e2 } form a separation class, M(e1)!=M(e2), |
675 |
|
// but M(c1) separates e1 and e2 since |
676 |
|
// (x>0){x->0,y->0} = false, and |
677 |
|
// (x>0){x->1,y->2} = true |
678 |
|
// Hence, we get new separation classes { e1 } and { e2 }, and afterwards |
679 |
|
// add e3. We then get { e2, e3 } as a separation class, which is also a |
680 |
|
// conflict since M(e2)!=M(e3). We check if M(c2) resolves this conflict. |
681 |
|
// It does not, since (x<0){x->1,y->2} = (x<0){x->1,y->3} = false. Hence, |
682 |
|
// we get a separation lemma: |
683 |
|
// ~( c1 = [x>0] ^ e2 != e3 ^ c2 = [x<0] ) |
684 |
|
// |
685 |
|
// Say we have: |
686 |
|
// evaluation heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3) |
687 |
|
// conditions: c1 |
688 |
|
// where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = x>0. |
689 |
|
// After adding e1 and e2, we have separation class { e1, e2 }. This is not a |
690 |
|
// conflict since M(e1)=M(e2). We then add e3, obtaining separation class |
691 |
|
// { e1, e2, e3 }, which is in conflict since M(e3)!=M(e1), and the condition |
692 |
|
// c1 does not separate e3 and the representative of this class, e1. Hence we |
693 |
|
// get a separation lemma of the form: |
694 |
|
// ~( e1 = e2 ^ e1 != e3 ^ c1 = [x>0] ) |
695 |
|
// |
696 |
|
// It also may be the case that we exhaust the pool of condition enumerators. |
697 |
|
// Say we have: |
698 |
|
// evaluation heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3) |
699 |
|
// conditions: c1 |
700 |
|
// where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = y>0. After adding e1, e2, |
701 |
|
// and e3, we have a separation class { e1, e2, e3 } that is in conflict |
702 |
|
// since M(e3)!=M(e1). We add the condition c1, which separates into new |
703 |
|
// equivalence classes { e1 }, { e2, e3 }. We are still in separation conflict |
704 |
|
// since M(e3)!=M(e2). However, we do not have any further conditions to use |
705 |
|
// to resolve this conflict. Thus, we add the separation lemma: |
706 |
|
// ~( e1 = e2 ^ e1 != e3 ^ e2 != e3 ^ c1 = [y>0] ^ G_1 ) |
707 |
|
// where G_1 is a guard stating that we use at most 1 condition. |
708 |
436 |
Node e; |
709 |
436 |
Node er; |
710 |
2642 |
while (hd_counter < d_hds.size() || needs_sep_resolve) |
711 |
|
{ |
712 |
1409 |
if (!needs_sep_resolve) |
713 |
|
{ |
714 |
|
// add the head to the trie |
715 |
1390 |
e = d_hds[hd_counter]; |
716 |
1390 |
hd_mv[e] = d_unif->d_parent->getModelValue(e); |
717 |
1390 |
if (Trace.isOn("sygus-unif-sol")) |
718 |
|
{ |
719 |
|
std::stringstream ss; |
720 |
|
TermDbSygus::toStreamSygus(ss, hd_mv[e]); |
721 |
|
Trace("sygus-unif-sol") |
722 |
|
<< " add evaluation head (" << hd_counter << "/" << d_hds.size() |
723 |
|
<< "): " << e << " -> " << ss.str() << std::endl; |
724 |
|
} |
725 |
1390 |
hd_counter++; |
726 |
|
// get the representative of the trie |
727 |
1390 |
er = d_pt_sep.d_trie.add(e, &d_pt_sep, c_counter); |
728 |
1390 |
Trace("sygus-unif-sol") << " ...separation class " << er << std::endl; |
729 |
|
// are we in conflict? |
730 |
1390 |
if (er == e) |
731 |
|
{ |
732 |
|
// new separation class, no conflict |
733 |
1411 |
continue; |
734 |
|
} |
735 |
1172 |
Assert(hd_mv.find(er) != hd_mv.end()); |
736 |
2015 |
if (hd_mv[er] == hd_mv[e]) |
737 |
|
{ |
738 |
|
// merged into separation class with same model value, no conflict |
739 |
|
// add to explanation |
740 |
|
// this states that it mattered that (er = e) at the time that e was |
741 |
|
// added to the trie. Notice that er and e may become separated later, |
742 |
|
// but to ensure the overall invariant, this equality must persist in |
743 |
|
// the explanation. |
744 |
843 |
exp.push_back(er.eqNode(e)); |
745 |
843 |
Trace("sygus-unif-sol") << " ...equal model values " << std::endl; |
746 |
1686 |
Trace("sygus-unif-sol") |
747 |
843 |
<< " ...add to explanation " << er.eqNode(e) << std::endl; |
748 |
843 |
continue; |
749 |
|
} |
750 |
|
} |
751 |
|
// must include in the explanation that we hit a conflict at this point in |
752 |
|
// the construction |
753 |
348 |
exp.push_back(e.eqNode(er).negate()); |
754 |
|
// we are in separation conflict, does the next condition resolve this? |
755 |
|
// check whether we have exhausted our condition pool. If so, we |
756 |
|
// are in conflict and this conflict depends on the guard. |
757 |
348 |
if (c_counter >= d_conds.size()) |
758 |
|
{ |
759 |
|
// truncated separation lemma |
760 |
123 |
Assert(!d_guard.isNull()); |
761 |
123 |
exp.push_back(d_guard); |
762 |
123 |
exp_conflict = true; |
763 |
320 |
break; |
764 |
|
} |
765 |
225 |
Assert(c_counter < d_conds.size()); |
766 |
244 |
Node ce = d_enums[c_counter]; |
767 |
244 |
Node cv = d_conds[c_counter]; |
768 |
225 |
Assert(ce.getType() == cv.getType()); |
769 |
225 |
if (Trace.isOn("sygus-unif-sol")) |
770 |
|
{ |
771 |
|
std::stringstream ss; |
772 |
|
TermDbSygus::toStreamSygus(ss, cv); |
773 |
|
Trace("sygus-unif-sol") |
774 |
|
<< " add condition (" << c_counter << "/" << d_conds.size() |
775 |
|
<< "): " << ce << " -> " << ss.str() << std::endl; |
776 |
|
} |
777 |
|
// cache the separation class |
778 |
244 |
std::vector<Node> prev_sep_c = d_pt_sep.d_trie.d_rep_to_class[er]; |
779 |
|
// add new classifier |
780 |
225 |
d_pt_sep.d_trie.addClassifier(&d_pt_sep, c_counter); |
781 |
225 |
c_counter++; |
782 |
|
// add to explanation |
783 |
|
// c_exp is a conjunction of testers applied to shared selector chains |
784 |
244 |
Node c_exp = d_unif->d_tds->getExplain()->getExplanationForEquality(ce, cv); |
785 |
225 |
exp.push_back(c_exp); |
786 |
|
std::map<Node, std::vector<Node>>::iterator itr = |
787 |
225 |
d_pt_sep.d_trie.d_rep_to_class.find(e); |
788 |
|
// since e is last in its separation class, if it becomes a representative, |
789 |
|
// then it is separated from all values in prev_sep_c |
790 |
357 |
if (itr != d_pt_sep.d_trie.d_rep_to_class.end()) |
791 |
|
{ |
792 |
264 |
Trace("sygus-unif-sol") |
793 |
132 |
<< " ...resolves separation conflict with all" << std::endl; |
794 |
132 |
needs_sep_resolve = false; |
795 |
132 |
continue; |
796 |
|
} |
797 |
93 |
itr = d_pt_sep.d_trie.d_rep_to_class.find(er); |
798 |
|
// since er is first in its separation class, it remains a representative |
799 |
93 |
Assert(itr != d_pt_sep.d_trie.d_rep_to_class.end()); |
800 |
|
// is e still in the separation class of er? |
801 |
279 |
if (std::find(itr->second.begin(), itr->second.end(), e) |
802 |
279 |
!= itr->second.end()) |
803 |
|
{ |
804 |
148 |
Trace("sygus-unif-sol") |
805 |
74 |
<< " ...does not resolve separation conflict with current" |
806 |
74 |
<< std::endl; |
807 |
|
// the condition does not separate e and er |
808 |
|
// this violates the invariant that the i^th conditional enumerator |
809 |
|
// resolves the i^th separation conflict |
810 |
74 |
exp_conflict = true; |
811 |
74 |
SygusTypeInfo& ti = d_unif->d_tds->getTypeInfo(ce.getType()); |
812 |
|
// The reasoning below is only necessary if we use symbolic constructors. |
813 |
74 |
if (!ti.hasSubtermSymbolicCons()) |
814 |
|
{ |
815 |
74 |
break; |
816 |
|
} |
817 |
|
// Since the explanation of the condition (c_exp above) does not account |
818 |
|
// for builtin subterms, we additionally require that the valuation of |
819 |
|
// the condition is indeed different on the two points. |
820 |
|
// For example, say ce has model value equal to the SyGuS datatype term: |
821 |
|
// C_leq_xy( 0, 1 ) |
822 |
|
// where C_leq_xy is a SyGuS datatype constructor taking two integer |
823 |
|
// constants c_x and c_y, and whose builtin version is: |
824 |
|
// (0*x + 1*y >= 0) |
825 |
|
// Then, c_exp above is: |
826 |
|
// is-C_leq_xy( ce ) |
827 |
|
// which is added to our explanation of the conflict, which does not |
828 |
|
// account for the values of the arguments of C_leq_xy. |
829 |
|
// Now, say that we are in a separation conflict due to f(1,2) and f(2,3) |
830 |
|
// being assigned different values; the value of ce does not separate |
831 |
|
// these two terms since: |
832 |
|
// (y>=0) { x -> 1, y -> 2 } = (y>=0) { x -> 2, y -> 3 } = true |
833 |
|
// The code below adds a constraint that states that the above values are |
834 |
|
// the same, which is part of the reason for the conflict. In the above |
835 |
|
// example, we generate: |
836 |
|
// (DT_SYGUS_EVAL ce 1 2) == (DT_SYGUS_EVAL ce 2 3) { ce -> M(ce) } |
837 |
|
// which unfolds via the SygusEvalUnfold utility to: |
838 |
|
// ( (c_x ce)*1 + (c_y ce)*2 >= 0 ) == ( (c_x ce)*2 + (c_y ce)*3 >= 0 ) |
839 |
|
// where c_x and c_y are the selectors of the subfields of C_leq_xy. |
840 |
|
Trace("sygus-unif-sol-sym") |
841 |
|
<< "Explain symbolic separation conflict" << std::endl; |
842 |
|
std::map<Node, std::vector<Node>>::iterator ith; |
843 |
|
Node ceApp[2]; |
844 |
|
SygusEvalUnfold* eunf = d_unif->d_tds->getEvalUnfold(); |
845 |
|
std::map<Node, Node> vtm; |
846 |
|
vtm[ce] = cv; |
847 |
|
Trace("sygus-unif-sol-sym") |
848 |
|
<< "Model value for " << ce << " is " << cv << std::endl; |
849 |
|
for (unsigned r = 0; r < 2; r++) |
850 |
|
{ |
851 |
|
std::vector<Node> cechildren; |
852 |
|
cechildren.push_back(ce); |
853 |
|
Node ecurr = r == 0 ? e : er; |
854 |
|
ith = d_unif->d_hd_to_pt.find(ecurr); |
855 |
|
AlwaysAssert(ith != d_unif->d_hd_to_pt.end()); |
856 |
|
cechildren.insert( |
857 |
|
cechildren.end(), ith->second.begin(), ith->second.end()); |
858 |
|
Node cea = nm->mkNode(DT_SYGUS_EVAL, cechildren); |
859 |
|
Trace("sygus-unif-sol-sym") |
860 |
|
<< "Sep conflict app #" << r << " : " << cea << std::endl; |
861 |
|
std::vector<Node> tmpExp; |
862 |
|
cea = eunf->unfold(cea, vtm, tmpExp, true, true); |
863 |
|
Trace("sygus-unif-sol-sym") << "Unfolded to : " << cea << std::endl; |
864 |
|
ceApp[r] = cea; |
865 |
|
} |
866 |
|
Node ceAppEq = ceApp[0].eqNode(ceApp[1]); |
867 |
|
Trace("sygus-unif-sol-sym") |
868 |
|
<< "Sep conflict app explanation is : " << ceAppEq << std::endl; |
869 |
|
exp.push_back(ceAppEq); |
870 |
|
break; |
871 |
|
} |
872 |
38 |
Trace("sygus-unif-sol") |
873 |
19 |
<< " ...resolves separation conflict with current, but not all" |
874 |
19 |
<< std::endl; |
875 |
|
// find the new term to resolve a separation |
876 |
38 |
Node new_er = Node::null(); |
877 |
|
// scan the previous list and find the representative of the class that e is |
878 |
|
// now in |
879 |
50 |
for (unsigned i = 0, size = prev_sep_c.size(); i < size; i++) |
880 |
|
{ |
881 |
81 |
Node check_er = prev_sep_c[i]; |
882 |
50 |
if (check_er != er && check_er != e) |
883 |
|
{ |
884 |
31 |
itr = d_pt_sep.d_trie.d_rep_to_class.find(check_er); |
885 |
31 |
if (itr != d_pt_sep.d_trie.d_rep_to_class.end()) |
886 |
|
{ |
887 |
57 |
if (std::find(itr->second.begin(), itr->second.end(), e) |
888 |
57 |
!= itr->second.end()) |
889 |
|
{ |
890 |
19 |
new_er = check_er; |
891 |
19 |
break; |
892 |
|
} |
893 |
|
} |
894 |
|
} |
895 |
|
} |
896 |
|
// should find exactly one |
897 |
19 |
Assert(!new_er.isNull()); |
898 |
19 |
er = new_er; |
899 |
19 |
needs_sep_resolve = true; |
900 |
|
} |
901 |
218 |
if (exp_conflict) |
902 |
|
{ |
903 |
394 |
Node lemma = exp.size() == 1 ? exp[0] : nm->mkNode(AND, exp); |
904 |
197 |
lemma = lemma.negate(); |
905 |
197 |
Trace("sygus-unif-sol") << " ......conflict is " << lemma << std::endl; |
906 |
197 |
lemmas.push_back(lemma); |
907 |
197 |
return Node::null(); |
908 |
|
} |
909 |
|
|
910 |
21 |
Trace("sygus-unif-sol") << "...ready to build solution from DT\n"; |
911 |
21 |
return extractSol(cons, hd_mv); |
912 |
|
} |
913 |
|
|
914 |
21 |
// Produces the solution term for this decision tree. When the owning
// unification utility reports that the information-gain condition pool is in
// use, the tree is first rebuilt heuristically; the term itself is then
// extracted by the point separator.
//   cons  : forwarded unchanged to the point separator's extractSol
//   hd_mv : map from evaluation heads to their model values
Node SygusUnifRl::DecisionTreeInfo::extractSol(Node cons,
                                               std::map<Node, Node>& hd_mv)
{
  const bool relearnTree = d_unif->usingConditionPoolInfoGain();
  if (relearnTree)
  {
    // rebuild decision tree using heuristic learning
    recomputeSolHeuristically(hd_mv);
  }
  return d_pt_sep.extractSol(cons, hd_mv);
}
924 |
|
|
925 |
21 |
// Walks the separation trie with an explicit stack and assembles the solution
// term bottom-up.
//   - A trie node without children is a leaf; its term is the model value
//     (from hd_mv) of the node's lazy child.
//   - An inner node is visited twice: the first visit schedules its children,
//     the second visit (recognised by a null placeholder in the cache)
//     combines the children's terms.
//   - An inner node at depth k uses condition d_dt->d_conds[k]. If it has a
//     single child, or both branches produced the same term, the branch term
//     is reused directly; otherwise the result is
//     APPLY_CONSTRUCTOR(cons, cond, <true branch>, <false branch>).
// Returns the term computed for the root of the trie.
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::extractSol(
    Node cons, std::map<Node, Node>& hd_mv)
{
  NodeManager* nm = NodeManager::currentNM();
  // term built so far for each (depth, trie node) pair
  std::map<IndTriePair, Node> built;
  std::vector<IndTriePair> pending;
  IndTriePair rootKey = IndTriePair(0, &d_trie.d_trie);
  pending.push_back(rootKey);
  while (!pending.empty())
  {
    const IndTriePair cur = pending.back();
    pending.pop_back();
    const unsigned depth = cur.first;
    LazyTrie* node = cur.second;
    std::map<IndTriePair, Node>::iterator found = built.find(cur);
    if (found == built.end())
    {
      // first visit of this trie node
      if (node->d_children.empty())
      {
        // leaf: the term is the model value of the stored head
        Assert(hd_mv.find(node->d_lazy_child) != hd_mv.end());
        built[cur] = hd_mv[node->d_lazy_child];
        Trace("sygus-unif-sol-debug")
            << "......leaf, build "
            << d_dt->d_unif->d_tds->sygusToBuiltin(built[cur],
                                                   built[cur].getType())
            << "\n";
      }
      else
      {
        // inner node: leave a null placeholder, revisit after the children
        built[cur] = Node::null();
        pending.push_back(cur);
        for (std::pair<const Node, LazyTrie>& branch : node->d_children)
        {
          pending.push_back(IndTriePair(depth + 1, &branch.second));
        }
      }
      continue;
    }
    // second visit: all children have been processed, combine their terms
    Assert(found->second.isNull());
    Assert(node->d_children.size() == 1 || node->d_children.size() == 2);
    // layout: [0] constructor, [1] condition, [2] true branch, [3] false
    // branch
    std::vector<Node> args(4);
    args[0] = cons;
    args[1] = d_dt->d_conds[depth];
    // slot filled by the most recently processed branch
    unsigned slot = 0;
    for (std::pair<const Node, LazyTrie>& branch : node->d_children)
    {
      const IndTriePair childKey = IndTriePair(depth + 1, &branch.second);
      slot = branch.first.getConst<bool>() ? 2 : 3;
      Assert(built.find(childKey) != built.end());
      args[slot] = built[childKey];
      Assert(!args[slot].isNull());
    }
    // condition is useless or both branches agree: no need for an ITE
    if (node->d_children.size() == 1 || args[2] == args[3])
    {
      built[cur] = args[slot];
      Trace("sygus-unif-sol-debug")
          << "......no need for cond "
          << d_dt->d_unif->d_tds->sygusToBuiltin(
                 d_dt->d_conds[depth], d_dt->d_conds[depth].getType())
          << ", build "
          << d_dt->d_unif->d_tds->sygusToBuiltin(built[cur],
                                                 built[cur].getType())
          << "\n";
      continue;
    }
    Assert(node->d_children.size() == 2);
    built[cur] = nm->mkNode(APPLY_CONSTRUCTOR, args);
    Trace("sygus-unif-sol-debug")
        << "......build node "
        << d_dt->d_unif->d_tds->sygusToBuiltin(built[cur],
                                               built[cur].getType())
        << "\n";
  }
  Assert(built.find(rootKey) != built.end());
  Assert(!built.find(rootKey)->second.isNull());
  return built[rootKey];
}
1005 |
|
|
1006 |
|
void SygusUnifRl::DecisionTreeInfo::recomputeSolHeuristically( |
1007 |
|
std::map<Node, Node>& hd_mv) |
1008 |
|
{ |
1009 |
|
// reset the trie |
1010 |
|
d_pt_sep.d_trie.clear(); |
1011 |
|
// TODO workaround and not really sure this is the last condition, since I put |
1012 |
|
// a set here. Maybe make d_cond_mvs into a vector |
1013 |
|
Node backup_last_cond = d_conds.back(); |
1014 |
|
d_conds.clear(); |
1015 |
|
for (const Node& e : d_hds) |
1016 |
|
{ |
1017 |
|
d_pt_sep.d_trie.add(e, &d_pt_sep, 0); |
1018 |
|
} |
1019 |
|
// init vector of conds |
1020 |
|
std::vector<Node> conds; |
1021 |
|
conds.insert(conds.end(), d_cond_mvs.begin(), d_cond_mvs.end()); |
1022 |
|
|
1023 |
|
// recursively build trie by picking best condition for respective points |
1024 |
|
buildDtInfoGain(d_hds, conds, hd_mv, 1); |
1025 |
|
// if no condition was added (i.e. points are already classified at root |
1026 |
|
// level), use last condition as candidate |
1027 |
|
if (d_conds.empty()) |
1028 |
|
{ |
1029 |
|
Trace("sygus-unif-dt") << "......using last condition " |
1030 |
|
<< d_unif->d_tds->sygusToBuiltin( |
1031 |
|
backup_last_cond, backup_last_cond.getType()) |
1032 |
|
<< " as candidate\n"; |
1033 |
|
d_conds.push_back(backup_last_cond); |
1034 |
|
d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1); |
1035 |
|
} |
1036 |
|
} |
1037 |
|
|
1038 |
|
void SygusUnifRl::DecisionTreeInfo::buildDtInfoGain(std::vector<Node>& hds, |
1039 |
|
std::vector<Node> conds, |
1040 |
|
std::map<Node, Node>& hd_mv, |
1041 |
|
int ind) |
1042 |
|
{ |
1043 |
|
// test if fully classified |
1044 |
|
if (hds.size() < 2) |
1045 |
|
{ |
1046 |
|
indent("sygus-unif-dt", ind); |
1047 |
|
Trace("sygus-unif-dt") << "..set fully classified: " |
1048 |
|
<< (hds.empty() ? "empty" : "unary") << "\n"; |
1049 |
|
return; |
1050 |
|
} |
1051 |
|
Node v1 = hd_mv[hds[0]]; |
1052 |
|
unsigned i = 1, size = hds.size(); |
1053 |
|
for (; i < size; ++i) |
1054 |
|
{ |
1055 |
|
if (hd_mv[hds[i]] != v1) |
1056 |
|
{ |
1057 |
|
break; |
1058 |
|
} |
1059 |
|
} |
1060 |
|
if (i == size) |
1061 |
|
{ |
1062 |
|
indent("sygus-unif-dt", ind); |
1063 |
|
Trace("sygus-unif-dt") << "..set fully classified: " << hds.size() << " " |
1064 |
|
<< (d_unif->d_tds->sygusToBuiltin(v1, v1.getType()) |
1065 |
|
== d_true |
1066 |
|
? "good" |
1067 |
|
: "bad") |
1068 |
|
<< " points\n"; |
1069 |
|
return; |
1070 |
|
} |
1071 |
|
// pick condition to further classify |
1072 |
|
double maxgain = -1; |
1073 |
|
unsigned picked_cond = 0; |
1074 |
|
std::vector<std::pair<std::vector<Node>, std::vector<Node>>> splits; |
1075 |
|
double current_set_entropy = getEntropy(hds, hd_mv, ind); |
1076 |
|
for (unsigned j = 0, conds_size = conds.size(); j < conds_size; ++j) |
1077 |
|
{ |
1078 |
|
std::pair<std::vector<Node>, std::vector<Node>> split = |
1079 |
|
evaluateCond(hds, conds[j]); |
1080 |
|
splits.push_back(split); |
1081 |
|
Assert(hds.size() == split.first.size() + split.second.size()); |
1082 |
|
double gain = |
1083 |
|
current_set_entropy |
1084 |
|
- (split.first.size() * getEntropy(split.first, hd_mv, ind) |
1085 |
|
+ split.second.size() * getEntropy(split.second, hd_mv, ind)) |
1086 |
|
/ hds.size(); |
1087 |
|
indent("sygus-unif-dt-debug", ind); |
1088 |
|
Trace("sygus-unif-dt-debug") |
1089 |
|
<< "..gain of " |
1090 |
|
<< d_unif->d_tds->sygusToBuiltin(conds[j], conds[j].getType()) << " is " |
1091 |
|
<< gain << "\n"; |
1092 |
|
if (gain > maxgain) |
1093 |
|
{ |
1094 |
|
maxgain = gain; |
1095 |
|
picked_cond = j; |
1096 |
|
} |
1097 |
|
} |
1098 |
|
// add picked condition |
1099 |
|
indent("sygus-unif-dt", ind); |
1100 |
|
Trace("sygus-unif-dt") << "..picked condition " |
1101 |
|
<< d_unif->d_tds->sygusToBuiltin( |
1102 |
|
conds[picked_cond], |
1103 |
|
conds[picked_cond].getType()) |
1104 |
|
<< "\n"; |
1105 |
|
d_conds.push_back(conds[picked_cond]); |
1106 |
|
conds.erase(conds.begin() + picked_cond); |
1107 |
|
d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1); |
1108 |
|
// recurse |
1109 |
|
buildDtInfoGain(splits[picked_cond].first, conds, hd_mv, ind + 1); |
1110 |
|
buildDtInfoGain(splits[picked_cond].second, conds, hd_mv, ind + 1); |
1111 |
|
} |
1112 |
|
|
1113 |
|
std::pair<std::vector<Node>, std::vector<Node>> |
1114 |
|
SygusUnifRl::DecisionTreeInfo::evaluateCond(std::vector<Node>& pts, Node cond) |
1115 |
|
{ |
1116 |
|
std::vector<Node> good, bad; |
1117 |
|
for (const Node& pt : pts) |
1118 |
|
{ |
1119 |
|
if (d_pt_sep.computeCond(cond, pt) == d_true) |
1120 |
|
{ |
1121 |
|
good.push_back(pt); |
1122 |
|
continue; |
1123 |
|
} |
1124 |
|
Assert(d_pt_sep.computeCond(cond, pt) == d_false); |
1125 |
|
bad.push_back(pt); |
1126 |
|
} |
1127 |
|
return std::pair<std::vector<Node>, std::vector<Node>>(good, bad); |
1128 |
|
} |
1129 |
|
|
1130 |
|
double SygusUnifRl::DecisionTreeInfo::getEntropy(const std::vector<Node>& hds, |
1131 |
|
std::map<Node, Node>& hd_mv, |
1132 |
|
int ind) |
1133 |
|
{ |
1134 |
|
double p = 0, n = 0; |
1135 |
|
TermDbSygus* tds = d_unif->d_tds; |
1136 |
|
// get number of points evaluated positively and negatively with feature |
1137 |
|
for (const Node& e : hds) |
1138 |
|
{ |
1139 |
|
if (tds->sygusToBuiltin(hd_mv[e]) == d_true) |
1140 |
|
{ |
1141 |
|
p++; |
1142 |
|
continue; |
1143 |
|
} |
1144 |
|
Assert(tds->sygusToBuiltin(hd_mv[e]) == d_false); |
1145 |
|
n++; |
1146 |
|
} |
1147 |
|
// compute entropy |
1148 |
|
return p == 0 || n == 0 ? 0 |
1149 |
|
: ((-p / (p + n)) * log2(p / (p + n))) |
1150 |
|
- ((n / (p + n)) * log2(n / (p + n))); |
1151 |
|
} |
1152 |
|
|
1153 |
11 |
// Records the decision tree info object this separator belongs to. The
// pointer is stored as is; evaluate, computeCond and extractSol read the
// conditions, the template and the unification utility through it.
void SygusUnifRl::DecisionTreeInfo::PointSeparator::initialize(
    DecisionTreeInfo* dt)
{
  this->d_dt = dt;
}
1158 |
|
|
1159 |
1423 |
// Evaluates the index-th condition of the owning decision tree on the head n.
// index must be smaller than the number of conditions in d_dt->d_conds.
// Returns the result of computeCond for that condition and n.
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::evaluate(Node n,
                                                             unsigned index)
{
  Assert(index < d_dt->d_conds.size());
  // look up the condition and evaluate it on the head
  return computeCond(d_dt->d_conds[index], n);
}
1167 |
|
|
1168 |
1423 |
// Evaluates the condition cond on the evaluation point associated with the
// head hd, caching results per (cond, hd) pair in d_eval_cond_hd.
// On a cache miss:
//   1. cond is converted to its builtin version,
//   2. the point registered for hd in d_hd_to_pt is fetched,
//   3. the builtin condition is evaluated on that point,
//   4. if the decision tree has a non-null template, the result is
//      substituted for the template variable and the outcome is rewritten.
// The final value is asserted to be constant, stored in the cache and
// returned.
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::computeCond(Node cond,
                                                                Node hd)
{
  std::pair<Node, Node> key(cond, hd);
  std::map<std::pair<Node, Node>, Node>::iterator cached =
      d_eval_cond_hd.find(key);
  if (cached != d_eval_cond_hd.end())
  {
    // already evaluated for this pair
    return cached->second;
  }
  TermDbSygus* tds = d_dt->d_unif->d_tds;
  TypeNode tn = cond.getType();
  Node builtin_cond = tds->sygusToBuiltin(cond, tn);
  // Retrieve evaluation point
  Assert(d_dt->d_unif->d_hd_to_pt.find(hd) != d_dt->d_unif->d_hd_to_pt.end());
  std::vector<Node> point = d_dt->d_unif->d_hd_to_pt[hd];
  if (Trace.isOn("sygus-unif-rl-sep"))
  {
    Trace("sygus-unif-rl-sep")
        << "Evaluate cond " << builtin_cond << " on pt " << hd << " ( ";
    for (const Node& coord : point)
    {
      Trace("sygus-unif-rl-sep") << coord << " ";
    }
    Trace("sygus-unif-rl-sep") << ")\n";
  }
  // compute the result
  Node res = tds->evaluateBuiltin(tn, builtin_cond, point);
  Trace("sygus-unif-rl-sep") << "...got res = " << res << "\n";
  // If condition is templated, recompute result accordingly
  Node templ = d_dt->d_template.first;
  TNode templ_var = d_dt->d_template.second;
  if (!templ.isNull())
  {
    res = Rewriter::rewrite(templ.substitute(templ_var, res));
    Trace("sygus-unif-rl-sep")
        << "...after template res = " << res << std::endl;
  }
  Assert(res.isConst());
  d_eval_cond_hd[key] = res;
  return res;
}
1210 |
|
|
1211 |
|
} // namespace quantifiers |
1212 |
|
} // namespace theory |
1213 |
29502 |
} // namespace cvc5 |