GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/sygus/sygus_unif_rl.cpp Lines: 419 636 65.9 %
Date: 2021-11-07 Branches: 689 2466 27.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Haniel Barbosa, Andrew Reynolds, Mathias Preiner
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of sygus_unif_rl.
14
 */
15
16
#include "theory/quantifiers/sygus/sygus_unif_rl.h"
17
18
#include "expr/skolem_manager.h"
19
#include "options/base_options.h"
20
#include "options/quantifiers_options.h"
21
#include "printer/printer.h"
22
#include "theory/quantifiers/sygus/synth_conjecture.h"
23
#include "theory/quantifiers/sygus/term_database_sygus.h"
24
#include "theory/rewriter.h"
25
#include "util/random.h"
26
27
#include <math.h>
28
29
using namespace cvc5::kind;
30
31
namespace cvc5 {
32
namespace theory {
33
namespace quantifiers {
34
35
1900
SygusUnifRl::SygusUnifRl(Env& env, SynthConjecture* p)
36
    : SygusUnif(env),
37
      d_parent(p),
38
      d_useCondPool(false),
39
1900
      d_useCondPoolIGain(false)
40
{
41
1900
}
42
1896
// Members release their own resources; nothing extra to do here.
SygusUnifRl::~SygusUnifRl() = default;
43
23
void SygusUnifRl::initializeCandidate(
44
    TermDbSygus* tds,
45
    Node f,
46
    std::vector<Node>& enums,
47
    std::map<Node, std::vector<Node>>& strategy_lemmas)
48
{
49
  // initialize
50
46
  std::vector<Node> all_enums;
51
23
  SygusUnif::initializeCandidate(tds, f, all_enums, strategy_lemmas);
52
  // based on the strategy inferred for each function, determine if we are
53
  // using a unification strategy that is compatible our approach.
54
46
  StrategyRestrictions restrictions;
55
23
  if (options::sygusBoolIteReturnConst())
56
  {
57
23
    restrictions.d_iteReturnBoolConst = true;
58
  }
59
  // register the strategy
60
23
  registerStrategy(f, enums, restrictions.d_unused_strategies);
61
23
  d_strategy.at(f).staticLearnRedundantOps(strategy_lemmas, restrictions);
62
  // Copy candidates and check whether CegisUnif for any of them
63
23
  if (d_unif_candidates.find(f) != d_unif_candidates.end())
64
  {
65
17
    d_hd_to_pt[f].clear();
66
17
    d_cand_to_eval_hds[f].clear();
67
17
    d_cand_to_hd_count[f] = 0;
68
  }
69
  // check whether we are using condition enumeration
70
23
  options::SygusUnifPiMode mode = options::sygusUnifPi();
71
23
  d_useCondPool = mode == options::SygusUnifPiMode::CENUM
72
23
                  || mode == options::SygusUnifPiMode::CENUM_IGAIN;
73
23
  d_useCondPoolIGain = mode == options::SygusUnifPiMode::CENUM_IGAIN;
74
23
}
75
76
/**
 * Not used by this utility. Enumerated values are never pushed here, and
 * reaching this method trips the Assert below.
 */
void SygusUnifRl::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
{
  Assert(false);
}
81
82
646
/**
 * Purifies the (sub)term n of a refinement lemma.
 *
 * Each application (DT_SYGUS_EVAL) of a unification function-to-synthesize is
 * rewritten into an application of a fresh skolem "evaluation head" to the
 * same purified arguments. New heads are appended to d_cand_to_eval_hds for
 * their candidate, and their argument tuples are recorded in d_hd_to_pt.
 * Purified applications are memoized in d_app_to_purified.
 *
 * When ensureConst is true, an application of a function-to-synthesize is
 * instead replaced by its value. That value is computed from the solution
 * already built for the function, if any, or else taken from the model. The
 * negated equality between the value and the purified application is
 * appended to model_guards.
 *
 * @param n the term to purify
 * @param ensureConst whether the result must be a constant
 * @param model_guards accumulates the negated model equalities introduced
 * @param cache memoizes the result for each (ensureConst, n) pair
 * @return the rewritten purified term
 */
Node SygusUnifRl::purifyLemma(Node n,
                              bool ensureConst,
                              std::vector<Node>& model_guards,
                              BoolNodePairMap& cache)
{
  Trace("sygus-unif-rl-purify") << "PurifyLemma : " << n << "\n";
  BoolNodePairMap::const_iterator it0 =
      cache.find(BoolNodePair(ensureConst, n));
  if (it0 != cache.end())
  {
    Trace("sygus-unif-rl-purify-debug") << "... already visited " << n << "\n";
    return it0->second;
  }
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();
  unsigned size = n.getNumChildren();
  Kind k = n.getKind();
  // We retrieve the model value now because the purified node may not have a
  // value
  Node nv = n;
  // Whether n is an application of a function-to-synthesize
  bool fapp = (k == DT_SYGUS_EVAL);
  // Whether n is an application of a unification (u_fapp) or a
  // non-unification (nu_fapp) function-to-synthesize; both stay false when
  // fapp is false
  bool u_fapp = false;
  bool nu_fapp = false;
  if (fapp)
  {
    Assert(std::find(d_candidates.begin(), d_candidates.end(), n[0])
           != d_candidates.end());
    // a single lookup suffices: for an application the flags are complementary
    u_fapp = usingUnif(n[0]);
    nu_fapp = !u_fapp;
    // get model value of non-top level applications of functions-to-synthesize
    // occurring under a unification function-to-synthesize
    if (ensureConst)
    {
      std::map<Node, Node>::iterator it1 = d_cand_to_sol.find(n[0]);
      // if function-to-synthesize, retrieve its built solution to replace in
      // the application before computing the model value
      AlwaysAssert(!u_fapp || it1 != d_cand_to_sol.end());
      if (it1 != d_cand_to_sol.end())
      {
        TNode cand = n[0];
        Node tmp = n.substitute(cand, it1->second);
        // should be concrete, can just use the rewriter
        nv = rewrite(tmp);
        Trace("sygus-unif-rl-purify")
            << "PurifyLemma : model value for " << tmp << " is " << nv << "\n";
      }
      else
      {
        nv = d_parent->getModelValue(n);
        Trace("sygus-unif-rl-purify")
            << "PurifyLemma : model value for " << n << " is " << nv << "\n";
      }
      Assert(n != nv);
    }
  }
  // Traverse the children to purify them
  bool childChanged = false;
  std::vector<Node> children;
  for (unsigned i = 0; i < size; ++i)
  {
    // the function-to-synthesize (first child of an evaluation) is kept as is
    if (i == 0 && fapp)
    {
      children.push_back(n[i]);
      continue;
    }
    // Arguments of non-unif functions do not need to be constant
    Node child = purifyLemma(
        n[i], !nu_fapp && (ensureConst || u_fapp), model_guards, cache);
    children.push_back(child);
    childChanged = childChanged || child != n[i];
  }
  Node nb;
  if (childChanged)
  {
    if (n.getMetaKind() == metakind::PARAMETERIZED)
    {
      Trace("sygus-unif-rl-purify-debug")
          << "Node " << n << " is parameterized\n";
      children.insert(children.begin(), n.getOperator());
    }
    if (Trace.isOn("sygus-unif-rl-purify-debug"))
    {
      Trace("sygus-unif-rl-purify-debug")
          << "...rebuilding " << n << " with kind " << k << " and children:\n";
      for (const Node& child : children)
      {
        Trace("sygus-unif-rl-purify-debug") << "...... " << child << "\n";
      }
    }
    nb = nm->mkNode(k, children);
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : transformed " << n << " into " << nb << "\n";
  }
  else
  {
    nb = n;
  }
  // Map to point enumerator every unification function-to-synthesize
  if (u_fapp)
  {
    Node np;
    std::map<Node, Node>::const_iterator it2 = d_app_to_purified.find(nb);
    if (it2 == d_app_to_purified.end())
    {
      // Build purified head with fresh skolem and recreate node
      std::stringstream ss;
      ss << nb[0] << "_" << d_cand_to_hd_count[nb[0]]++;
      Node new_f = sm->mkDummySkolem(
          ss.str(), nb[0].getType(), "head of unif evaluation point");
      // Adds new enumerator to map from candidate
      Trace("sygus-unif-rl-purify")
          << "...new enum " << new_f << " for candidate " << nb[0] << "\n";
      d_cand_to_eval_hds[nb[0]].push_back(new_f);
      // Maps new enumerator to its respective tuple of arguments
      d_hd_to_pt[new_f] =
          std::vector<Node>(children.begin() + 1, children.end());
      if (Trace.isOn("sygus-unif-rl-purify-debug"))
      {
        Trace("sygus-unif-rl-purify-debug") << "...[" << new_f << "] --> ( ";
        for (const Node& pt_i : d_hd_to_pt[new_f])
        {
          Trace("sygus-unif-rl-purify-debug") << pt_i << " ";
        }
        Trace("sygus-unif-rl-purify-debug") << ")\n";
      }
      // replace first child and rebuild node
      Assert(children.size() > 0);
      children[0] = new_f;
      Trace("sygus-unif-rl-purify-debug")
          << "Make sygus eval app " << children << std::endl;
      np = nm->mkNode(DT_SYGUS_EVAL, children);
      d_app_to_purified[nb] = np;
    }
    else
    {
      np = it2->second;
    }
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : purified head and transformed " << nb << " into "
        << np << "\n";
    nb = np;
  }
  // Add equality between purified fapp and model value
  if (ensureConst && fapp)
  {
    model_guards.push_back(nm->mkNode(EQUAL, nv, nb).negate());
    nb = nv;
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : adding model eq " << model_guards.back() << "\n";
  }
  nb = rewrite(nb);
  // every non-top level application of function-to-synthesize must be reduced
  // to a concrete constant
  Assert(!ensureConst || nb.isConst());
  Trace("sygus-unif-rl-purify-debug")
      << "... caching [" << n << "] = " << nb << "\n";
  cache[BoolNodePair(ensureConst, n)] = nb;
  return nb;
}
244
245
26
/**
 * Registers the refinement lemma `lemma` with this utility.
 *
 * The lemma is purified (see purifyLemma). If purification introduced model
 * guards, the result is the disjunction of those guards with the purified
 * lemma. Every evaluation head created during purification is appended to
 * eval_hds under its candidate. It is also registered with the decision trees
 * of all strategy points that use one of the candidate's conditional
 * enumerators.
 *
 * @return the rewritten purified lemma
 */
Node SygusUnifRl::addRefLemma(Node lemma,
                              std::map<Node, std::vector<Node>>& eval_hds)
{
  Trace("sygus-unif-rl-purify")
      << "Registering lemma at SygusUnif : " << lemma << "\n";
  // remember how many heads each candidate had before purification, so the
  // new ones can be identified afterwards
  std::map<Node, unsigned> oldHeadCount;
  for (const auto& entry : d_cand_to_eval_hds)
  {
    oldHeadCount[entry.first] = entry.second.size();
  }

  // Make the purified lemma which will guide the unification utility.
  std::vector<Node> guards;
  BoolNodePairMap visited;
  Node purified = purifyLemma(lemma, false, guards, visited);
  if (!guards.empty())
  {
    guards.push_back(purified);
    purified = NodeManager::currentNM()->mkNode(OR, guards);
  }
  purified = rewrite(purified);
  Trace("sygus-unif-rl-purify") << "Purified lemma : " << purified << "\n";

  Trace("sygus-unif-rl-purify") << "Collect new evaluation points...\n";
  for (const auto& entry : d_cand_to_eval_hds)
  {
    const Node& cand = entry.first;
    const std::vector<Node>& heads = entry.second;
    auto itOld = oldHeadCount.find(cand);
    unsigned first = (itOld == oldHeadCount.end()) ? 0 : itOld->second;
    for (unsigned j = first, nheads = heads.size(); j < nheads; j++)
    {
      Node hd = heads[j];
      eval_hds[cand].push_back(hd);
      // Add the new point to the decision trees that involve this candidate
      Assert(d_cand_cenums.find(cand) != d_cand_cenums.end());
      for (const Node& cenum : d_cand_cenums[cand])
      {
        Assert(d_cenum_to_stratpt.find(cenum) != d_cenum_to_stratpt.end());
        for (const Node& stratpt : d_cenum_to_stratpt[cenum])
        {
          Assert(d_stratpt_to_dt.find(stratpt) != d_stratpt_to_dt.end());
          Trace("sygus-unif-rl-dt") << "Register point with head " << hd
                                    << " to strategy point " << stratpt << "\n";
          d_stratpt_to_dt[stratpt].d_hds.push_back(hd);
        }
      }
    }
  }
  return purified;
}
302
303
144
void SygusUnifRl::initializeConstructSol()
{
  // Hook called at the start of constructSolution; nothing to set up here.
}
304
144
void SygusUnifRl::initializeConstructSolFor(Node f)
{
  // Hook called by constructSolution for each unification candidate; no-op.
}
305
144
bool SygusUnifRl::constructSolution(std::vector<Node>& sols,
306
                                    std::vector<Node>& lemmas)
307
{
308
144
  initializeConstructSol();
309
144
  bool successful = true;
310
300
  for (const Node& c : d_candidates)
311
  {
312
156
    if (!usingUnif(c))
313
    {
314
24
      Node v = d_parent->getModelValue(c);
315
12
      sols.push_back(v);
316
12
      continue;
317
    }
318
144
    initializeConstructSolFor(c);
319
    Node v = constructSol(
320
184
        c, d_strategy.at(c).getRootEnumerator(), role_equal, 0, lemmas);
321
248
    if (v.isNull())
322
    {
323
      // we continue trying to build solutions to accumulate potentitial
324
      // separation conditions from other decision trees
325
104
      successful = false;
326
104
      continue;
327
    }
328
40
    sols.push_back(v);
329
40
    d_cand_to_sol[c] = v;
330
  }
331
144
  return successful;
332
}
333
334
144
/**
 * Builds a solution for function-to-synthesize f at strategy point e.
 *
 * Only the role_equal role at strategy points that own a decision tree is
 * handled; in any other case the null node is returned. If f has no
 * evaluation heads yet, the model value of e is returned. Otherwise the
 * decision tree builds the solution, possibly adding lemmas to `lemmas`.
 */
Node SygusUnifRl::constructSol(
    Node f, Node e, NodeRole nrole, int ind, std::vector<Node>& lemmas)
{
  indent("sygus-unif-sol", ind);
  Trace("sygus-unif-sol") << "ConstructSol: SygusRL : " << e << std::endl;
  // strategy information for e in the requested role
  StrategyNode& snode =
      d_strategy.at(f).getEnumTypeInfo(e.getType()).getStrategyNode(nrole);
  if (nrole != role_equal)
  {
    return Node::null();
  }
  // for now only the simple case of a sole "ITE(cond, e, e)" strategy, which
  // is realized by a decision tree, is considered
  std::map<Node, DecisionTreeInfo>::iterator itd = d_stratpt_to_dt.find(e);
  if (itd == d_stratpt_to_dt.end())
  {
    return Node::null();
  }
  DecisionTreeInfo& dt = itd->second;
  indent("sygus-unif-sol", ind);
  Trace("sygus-unif-sol") << "...it has a decision tree strategy.\n";
  if (d_cand_to_eval_hds[f].empty())
  {
    // no evaluation points: return the model value of e
    Trace("sygus-unif-sol") << "...... no points, return root enum value "
                            << d_parent->getModelValue(e) << "\n";
    return d_parent->getModelValue(e);
  }
  EnumTypeInfoStrat* etis = snode.d_strats[dt.getStrategyIndex()];
  Node sol = dt.buildSol(etis->d_cons, lemmas);
  Assert(d_useCondPool || !sol.isNull() || !lemmas.empty());
  return sol;
}
368
369
311
// Whether f is a unification candidate, i.e. has been recorded in
// d_unif_candidates (see registerConditionalEnumerator).
bool SygusUnifRl::usingUnif(Node f) const
{
  const auto it = d_unif_candidates.find(f);
  return it != d_unif_candidates.end();
}
373
374
17
// Returns the conditional enumerator of the decision tree owned by strategy
// point e; e must own a decision tree.
Node SygusUnifRl::getConditionForEvaluationPoint(Node e) const
{
  auto it = d_stratpt_to_dt.find(e);
  Assert(it != d_stratpt_to_dt.end());
  const DecisionTreeInfo& dt = it->second;
  return dt.getConditionEnumerator();
}
380
381
144
void SygusUnifRl::setConditions(Node e,
382
                                Node guard,
383
                                const std::vector<Node>& enums,
384
                                const std::vector<Node>& conds)
385
{
386
144
  std::map<Node, DecisionTreeInfo>::iterator it = d_stratpt_to_dt.find(e);
387
144
  Assert(it != d_stratpt_to_dt.end());
388
  // set the conditions for the appropriate tree
389
144
  it->second.setConditions(guard, enums, conds);
390
144
}
391
392
482
// Returns (a copy of) the evaluation heads of candidate c, or an empty vector
// if c has none recorded.
std::vector<Node> SygusUnifRl::getEvalPointHeads(Node c)
{
  auto it = d_cand_to_eval_hds.find(c);
  return it == d_cand_to_eval_hds.end() ? std::vector<Node>() : it->second;
}
401
402
1154
// Whether condition values are pooled (set from the sygusUnifPi option in
// initializeCandidate).
bool SygusUnifRl::usingConditionPool() const
{
  return d_useCondPool;
}
403
26
bool SygusUnifRl::usingConditionPoolInfoGain() const
404
{
405
26
  return d_useCondPoolIGain;
406
}
407
23
void SygusUnifRl::registerStrategy(
408
    Node f,
409
    std::vector<Node>& enums,
410
    std::map<Node, std::unordered_set<unsigned>>& unused_strats)
411
{
412
23
  if (Trace.isOn("sygus-unif-rl-strat"))
413
  {
414
    Trace("sygus-unif-rl-strat")
415
        << "Strategy for " << f << " is : " << std::endl;
416
    d_strategy.at(f).debugPrint("sygus-unif-rl-strat");
417
  }
418
23
  Trace("sygus-unif-rl-strat") << "Register..." << std::endl;
419
46
  Node e = d_strategy.at(f).getRootEnumerator();
420
46
  std::map<Node, std::map<NodeRole, bool>> visited;
421
23
  registerStrategyNode(f, e, role_equal, visited, enums, unused_strats);
422
23
}
423
424
23
void SygusUnifRl::registerStrategyNode(
425
    Node f,
426
    Node e,
427
    NodeRole nrole,
428
    std::map<Node, std::map<NodeRole, bool>>& visited,
429
    std::vector<Node>& enums,
430
    std::map<Node, std::unordered_set<unsigned>>& unused_strats)
431
{
432
23
  Trace("sygus-unif-rl-strat") << "  register node " << e << std::endl;
433
23
  if (visited[e].find(nrole) != visited[e].end())
434
  {
435
    return;
436
  }
437
23
  visited[e][nrole] = true;
438
46
  TypeNode etn = e.getType();
439
23
  EnumTypeInfo& tinfo = d_strategy.at(f).getEnumTypeInfo(etn);
440
23
  StrategyNode& snode = tinfo.getStrategyNode(nrole);
441
43
  for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
442
  {
443
20
    EnumTypeInfoStrat* etis = snode.d_strats[j];
444
20
    StrategyType strat = etis->d_this;
445
    // is this a simple recursive ITE strategy?
446
20
    bool success = false;
447
20
    if (strat == strat_ITE && nrole == role_equal)
448
    {
449
17
      success = true;
450
51
      for (unsigned c = 1; c <= 2; c++)
451
      {
452
68
        std::pair<Node, NodeRole> child = etis->d_cenum[c];
453
34
        if (child.first != e || child.second != nrole)
454
        {
455
          success = false;
456
          break;
457
        }
458
      }
459
17
      if (success)
460
      {
461
34
        Node cond = etis->d_cenum[0].first;
462
17
        Assert(etis->d_cenum[0].second == role_ite_condition);
463
34
        Trace("sygus-unif-rl-strat")
464
17
            << "  ...detected recursive ITE strategy, condition enumerator : "
465
17
            << cond << std::endl;
466
        // indicate that we will be enumerating values for cond
467
17
        registerConditionalEnumerator(f, e, cond, j);
468
        // we will be using a strategy for e
469
17
        enums.push_back(e);
470
      }
471
    }
472
20
    if (!success)
473
    {
474
3
      unused_strats[e].insert(j);
475
    }
476
    // TODO: recurse? for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
477
  }
478
}
479
480
17
void SygusUnifRl::registerConditionalEnumerator(Node f,
481
                                                Node e,
482
                                                Node cond,
483
                                                unsigned strategy_index)
484
{
485
  // only allow one decision tree per strategy point
486
17
  if (d_stratpt_to_dt.find(e) != d_stratpt_to_dt.end())
487
  {
488
    return;
489
  }
490
  // we will do unification for this candidate
491
17
  d_unif_candidates.insert(f);
492
  // add to the list of all conditional enumerators
493
51
  if (std::find(d_cond_enums.begin(), d_cond_enums.end(), cond)
494
51
      == d_cond_enums.end())
495
  {
496
17
    d_cond_enums.push_back(cond);
497
17
    d_cand_cenums[f].push_back(cond);
498
17
    d_cenum_to_stratpt[cond].clear();
499
  }
500
  // register that this strategy node has a decision tree construction
501
17
  d_stratpt_to_dt[e].initialize(cond, this, &d_strategy.at(f), strategy_index);
502
  // associate conditional enumerator with strategy node
503
17
  d_cenum_to_stratpt[cond].push_back(e);
504
}
505
506
17
void SygusUnifRl::DecisionTreeInfo::initialize(Node cond_enum,
507
                                               SygusUnifRl* unif,
508
                                               SygusUnifStrategy* strategy,
509
                                               unsigned strategy_index)
510
{
511
17
  d_cond_enum = cond_enum;
512
17
  d_unif = unif;
513
17
  d_strategy = strategy;
514
17
  d_strategy_index = strategy_index;
515
17
  d_true = NodeManager::currentNM()->mkConst(true);
516
17
  d_false = NodeManager::currentNM()->mkConst(false);
517
  // Retrieve template
518
17
  EnumInfo& eiv = d_strategy->getEnumInfo(d_cond_enum);
519
17
  d_template = NodePair(eiv.d_template, eiv.d_template_arg);
520
  // Initialize classifier
521
17
  d_pt_sep.initialize(this);
522
17
}
523
524
144
void SygusUnifRl::DecisionTreeInfo::setConditions(
525
    Node guard, const std::vector<Node>& enums, const std::vector<Node>& conds)
526
{
527
144
  Assert(enums.size() == conds.size());
528
  // set the guard
529
144
  d_guard = guard;
530
  // clear old condition values
531
144
  d_enums.clear();
532
144
  d_conds.clear();
533
  // set new condition values
534
144
  d_enums.insert(d_enums.end(), enums.begin(), enums.end());
535
144
  d_conds.insert(d_conds.end(), conds.begin(), conds.end());
536
  // add to condition pool
537
144
  if (d_unif->usingConditionPool())
538
  {
539
    d_cond_mvs.insert(conds.begin(), conds.end());
540
    if (Trace.isOn("sygus-unif-cond-pool"))
541
    {
542
      for (const Node& condv : conds)
543
      {
544
        if (d_cond_mvs.find(condv) == d_cond_mvs.end())
545
        {
546
          Trace("sygus-unif-cond-pool")
547
              << "  ...adding to condition pool : "
548
              << d_unif->d_tds->sygusToBuiltin(condv, condv.getType()) << "\n";
549
        }
550
      }
551
    }
552
  }
553
144
}
554
555
130
unsigned SygusUnifRl::DecisionTreeInfo::getStrategyIndex() const
556
{
557
130
  return d_strategy_index;
558
}
559
560
130
/**
 * Builds a solution from this decision tree, starting from an empty trie.
 *
 * Conditions with a template are unsupported, and the null node is returned
 * for them. Otherwise the work is delegated to buildSolAllCond when a
 * condition pool is used, or to buildSolMinCond when it is not.
 */
Node SygusUnifRl::DecisionTreeInfo::buildSol(Node cons,
                                             std::vector<Node>& lemmas)
{
  if (!d_template.first.isNull())
  {
    Trace("sygus-unif-sol") << "...templated conditions unsupported\n";
    return Node::null();
  }
  Trace("sygus-unif-sol") << "Decision::buildSol with " << d_hds.size()
                          << " evaluation heads and " << d_conds.size()
                          << " conditions..." << std::endl;
  // reset the trie
  d_pt_sep.d_trie.clear();
  if (d_unif->usingConditionPool())
  {
    return buildSolAllCond(cons, lemmas);
  }
  return buildSolMinCond(cons, lemmas);
}
576
577
/**
 * Builds a decision-tree solution using every condition value in the pool.
 *
 * All pooled condition values become the current conditions and are
 * installed as classifiers, optionally shuffled first. Every evaluation head
 * is then added to the trie. If a head lands in the separation class of a
 * representative with a different model value, no solution is built. When
 * the sygusUnifCondIndNoRepeatSol option is set, a solution this tree has
 * already produced is rejected.
 *
 * @return the solution, or null on failure (lemmas is not modified)
 */
Node SygusUnifRl::DecisionTreeInfo::buildSolAllCond(Node cons,
                                                    std::vector<Node>& lemmas)
{
  // use the whole condition pool as the current conditions
  d_conds.assign(d_cond_mvs.begin(), d_cond_mvs.end());
  // Shuffling conditions before building the DT does not affect whether a
  // solution can be built. It can affect the size of the solution and which
  // conditions appear in it. That in turn influences the "quality" of the
  // solution on cases not covered by the current data points.
  if (options::sygusUnifShuffleCond())
  {
    std::shuffle(d_conds.begin(), d_conds.end(), Random::getRandom());
  }
  const unsigned numConds = d_conds.size();
  for (unsigned i = 0; i < numConds; ++i)
  {
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, i);
  }
  // model values of the evaluation heads
  std::map<Node, Node> hdModelValues;
  for (const Node& hd : d_hds)
  {
    hdModelValues[hd] = d_unif->d_parent->getModelValue(hd);
    Node rep = d_pt_sep.d_trie.add(hd, &d_pt_sep, numConds);
    // a new separation class is never a conflict
    if (rep == hd)
    {
      continue;
    }
    Assert(hdModelValues.find(rep) != hdModelValues.end());
    // joining a class whose representative has the same value is fine too
    if (hdModelValues[hd] == hdModelValues[rep])
    {
      continue;
    }
    Trace("sygus-unif-sol")
        << "  ...can't separate " << hd << " from " << rep << std::endl;
    return Node::null();
  }
  Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
  Node sol = extractSol(cons, hdModelValues);
  // optionally refuse to return a repeated solution
  if (options::sygusUnifCondIndNoRepeatSol()
      && d_sols.find(sol) != d_sols.end())
  {
    return Node::null();
  }
  d_sols.insert(sol);
  return sol;
}
635
636
130
Node SygusUnifRl::DecisionTreeInfo::buildSolMinCond(Node cons,
637
                                                    std::vector<Node>& lemmas)
638
{
639
130
  NodeManager* nm = NodeManager::currentNM();
640
  // model values for evaluation heads
641
260
  std::map<Node, Node> hd_mv;
642
  // the current explanation of why there has not yet been a separation conflict
643
260
  std::vector<Node> exp;
644
  // is the above explanation ready to be sent out as a lemma?
645
130
  bool exp_conflict = false;
646
  // the index of the head we are considering
647
130
  unsigned hd_counter = 0;
648
  // the index of the condition we are considering
649
130
  unsigned c_counter = 0;
650
  // do we need to resolve a separation conflict?
651
130
  bool needs_sep_resolve = false;
652
  // This loop simultaneously builds the solution in terms of a lazy trie
653
  // (LazyTrieMulti), and checks whether a separation conflict exists. We
654
  // enforce that the separation conflicts we encounter while building
655
  // this solution are resolved, in order, by the condition enumerators.
656
  // If not, then we add a (conflict) lemma stating that the current model
657
  // value of the condition enumerator must be different. We also call this
658
  // a "separation lemma".
659
  //
660
  // As a simple example, say we have:
661
  //   evalution heads: (eval e1 0 0), (eval e2 1 2)
662
  //   conditions: c1
663
  // where M(e1) = x, M(e2) = y, and M(c1) = x>1. After adding e1 and e2, we are
664
  // in conflict since { e1, e2 } form a separation class, M(e1)!=M(e2), and
665
  // M(c1) does not separate e1 and e2 since:
666
  //   (x>1){x->0,y->0} = (x>1){x->1,y->2} = false
667
  // Hence, we would fail to build a solution in this case, and instead send a
668
  // separation lemma of the form:
669
  //   ~( e1 != e2 ^ c1 = [x<1] )
670
  //
671
  // Say we have:
672
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
673
  //   conditions: c1 c2
674
  // where M(e1) = x, M(e2) = y, M(e3) = x+1, M(c1) = x>0 and M(c2) = x<0.
675
  // After adding e1 and e2, { e1, e2 } form a separation class, M(e1)!=M(e2),
676
  // but M(c1) separates e1 and e2 since
677
  //   (x>0){x->0,y->0} = false, and
678
  //   (x>1){x->1,y->2} = true
679
  // Hence, we get new separation classes { e1 } and { e2 }, and afterwards
680
  // add e3. We then get { e2, e3 } as a separation class, which is also a
681
  // conflict since M(e2)!=M(e3). We check if M(c2) resolves this conflict.
682
  // It does not, since (x<1){x->0,y->0} = (x<1){x->1,y->2} = false. Hence,
683
  // we get a separation lemma:
684
  //  ~( c1 = [x>1] ^ e2 != e3 ^ c2 = [x<1] )
685
  //
686
  // Say we have:
687
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
688
  //   conditions: c1
689
  // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = x>0.
690
  // After adding e1 and e2, we have separation class { e1, e2 }. This is not a
691
  // conflict since M(e1)=M(e2). We then add e3, obtaining separation class
692
  // { e1, e2, e3 }, which is in conflict since M(e3)!=M(e1), and the condition
693
  // c1 does not separate e3 and the representative of this class, e1. Hence we
694
  // get a separation lemma of the form:
695
  //  ~( e1 = e2 ^ e1 != e3 ^ c1 = [x>0] )
696
  //
697
  // It also may be the case that we exhaust the pool of condition enumerators.
698
  // Say we have:
699
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
700
  //   conditions: c1
701
  // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = y>0. After adding e1, e2,
702
  // and e3, we have a separation class { e1, e2, e3 } that is in conflict
703
  // since M(e3)!=M(e1). We add the condition c1, which separates into new
704
  // equivalence classes { e1 }, { e2, e3 }. We are still in separation conflict
705
  // since M(e3)!=M(e2). However, we do not have any further conditions to use
706
  // to resolve this conflict. Thus, we add the separation lemma:
707
  //  ~( e1 = e2 ^ e1 != e3 ^ e2 != e3 ^ c1 = [y>0] ^ G_1 )
708
  // where G_1 is a guard stating that we use at most 1 condition.
709
260
  Node e;
710
260
  Node er;
711
786
  while (hd_counter < d_hds.size() || needs_sep_resolve)
712
  {
713
432
    if (!needs_sep_resolve)
714
    {
715
      // add the head to the trie
716
388
      e = d_hds[hd_counter];
717
388
      hd_mv[e] = d_unif->d_parent->getModelValue(e);
718
388
      if (Trace.isOn("sygus-unif-sol"))
719
      {
720
        std::stringstream ss;
721
        TermDbSygus::toStreamSygus(ss, hd_mv[e]);
722
        Trace("sygus-unif-sol")
723
            << "  add evaluation head (" << hd_counter << "/" << d_hds.size()
724
            << "): " << e << " -> " << ss.str() << std::endl;
725
      }
726
388
      hd_counter++;
727
      // get the representative of the trie
728
388
      er = d_pt_sep.d_trie.add(e, &d_pt_sep, c_counter);
729
388
      Trace("sygus-unif-sol") << "  ...separation class " << er << std::endl;
730
      // are we in conflict?
731
388
      if (er == e)
732
      {
733
        // new separation class, no conflict
734
414
        continue;
735
      }
736
258
      Assert(hd_mv.find(er) != hd_mv.end());
737
362
      if (hd_mv[er] == hd_mv[e])
738
      {
739
        // merged into separation class with same model value, no conflict
740
        // add to explanation
741
        // this states that it mattered that (er = e) at the time that e was
742
        // added to the trie. Notice that er and e may become separated later,
743
        // but to ensure the overall invariant, this equality must persist in
744
        // the explanation.
745
104
        exp.push_back(er.eqNode(e));
746
104
        Trace("sygus-unif-sol") << "  ...equal model values " << std::endl;
747
208
        Trace("sygus-unif-sol")
748
104
            << "  ...add to explanation " << er.eqNode(e) << std::endl;
749
104
        continue;
750
      }
751
    }
752
    // must include in the explanation that we hit a conflict at this point in
753
    // the construction
754
198
    exp.push_back(e.eqNode(er).negate());
755
    // we are in separation conflict, does the next condition resolve this?
756
    // check whether we have have exhausted our condition pool. If so, we
757
    // are in conflict and this conflict depends on the guard.
758
198
    if (c_counter >= d_conds.size())
759
    {
760
      // truncated separation lemma
761
36
      Assert(!d_guard.isNull());
762
36
      exp.push_back(d_guard);
763
36
      exp_conflict = true;
764
140
      break;
765
    }
766
162
    Assert(c_counter < d_conds.size());
767
206
    Node ce = d_enums[c_counter];
768
206
    Node cv = d_conds[c_counter];
769
162
    Assert(ce.getType() == cv.getType());
770
162
    if (Trace.isOn("sygus-unif-sol"))
771
    {
772
      std::stringstream ss;
773
      TermDbSygus::toStreamSygus(ss, cv);
774
      Trace("sygus-unif-sol")
775
          << "  add condition (" << c_counter << "/" << d_conds.size()
776
          << "): " << ce << " -> " << ss.str() << std::endl;
777
    }
778
    // cache the separation class
779
206
    std::vector<Node> prev_sep_c = d_pt_sep.d_trie.d_rep_to_class[er];
780
    // add new classifier
781
162
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, c_counter);
782
162
    c_counter++;
783
    // add to explanation
784
    // c_exp is a conjunction of testers applied to shared selector chains
785
206
    Node c_exp = d_unif->d_tds->getExplain()->getExplanationForEquality(ce, cv);
786
162
    exp.push_back(c_exp);
787
    std::map<Node, std::vector<Node>>::iterator itr =
788
162
        d_pt_sep.d_trie.d_rep_to_class.find(e);
789
    // since e is last in its separation class, if it becomes a representative,
790
    // then it is separated from all values in prev_sep_c
791
212
    if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
792
    {
793
100
      Trace("sygus-unif-sol")
794
50
          << "  ...resolves separation conflict with all" << std::endl;
795
50
      needs_sep_resolve = false;
796
50
      continue;
797
    }
798
112
    itr = d_pt_sep.d_trie.d_rep_to_class.find(er);
799
    // since er is first in its separation class, it remains a representative
800
112
    Assert(itr != d_pt_sep.d_trie.d_rep_to_class.end());
801
    // is e still in the separation class of er?
802
336
    if (std::find(itr->second.begin(), itr->second.end(), e)
803
336
        != itr->second.end())
804
    {
805
136
      Trace("sygus-unif-sol")
806
68
          << "  ...does not resolve separation conflict with current"
807
68
          << std::endl;
808
      // the condition does not separate e and er
809
      // this violates the invariant that the i^th conditional enumerator
810
      // resolves the i^th separation conflict
811
68
      exp_conflict = true;
812
68
      SygusTypeInfo& ti = d_unif->d_tds->getTypeInfo(ce.getType());
813
      // The reasoning below is only necessary if we use symbolic constructors.
814
68
      if (!ti.hasSubtermSymbolicCons())
815
      {
816
68
        break;
817
      }
818
      // Since the explanation of the condition (c_exp above) does not account
819
      // for builtin subterms, we additionally require that the valuation of
820
      // the condition is indeed different on the two points.
821
      // For example, say ce has model value equal to the SyGuS datatype term:
822
      //   C_leq_xy( 0, 1 )
823
      // where C_leq_xy is a SyGuS datatype constructor taking two integer
824
      // constants c_x and c_y, and whose builtin version is:
825
      //   (0*x + 1*y >= 0)
826
      // Then, c_exp above is:
827
      //   is-C_leq_xy( ce )
828
      // which is added to our explanation of the conflict, which does not
829
      // account for the values of the arguments of C_leq_xy.
830
      // Now, say that we are in a separation conflict due to f(1,2) and f(2,3)
831
      // being assigned different values; the value of ce does not separate
832
      // these two terms since:
833
      //   (y>=0) { x -> 1, y -> 2 } = (y>=0) { x -> 2, y -> 3 } = true
834
      // The code below adds a constraint that states that the above values are
835
      // the same, which is part of the reason for the conflict. In the above
836
      // example, we generate:
837
      //   (DT_SYGUS_EVAL ce 1 2) == (DT_SYGUS_EVAL ce 2 3) { ce -> M(ce) }
838
      // which unfolds via the SygusEvalUnfold utility to:
839
      //   ( (c_x ce)*1 + (c_y ce)*2 >= 0 ) == ( (c_x ce)*2 + (c_y ce)*3 >= 0 )
840
      // where c_x and c_y are the selectors of the subfields of C_leq_xy.
841
      Trace("sygus-unif-sol-sym")
842
          << "Explain symbolic separation conflict" << std::endl;
843
      std::map<Node, std::vector<Node>>::iterator ith;
844
      Node ceApp[2];
845
      SygusEvalUnfold* eunf = d_unif->d_tds->getEvalUnfold();
846
      std::map<Node, Node> vtm;
847
      vtm[ce] = cv;
848
      Trace("sygus-unif-sol-sym")
849
          << "Model value for " << ce << " is " << cv << std::endl;
850
      for (unsigned r = 0; r < 2; r++)
851
      {
852
        std::vector<Node> cechildren;
853
        cechildren.push_back(ce);
854
        Node ecurr = r == 0 ? e : er;
855
        ith = d_unif->d_hd_to_pt.find(ecurr);
856
        AlwaysAssert(ith != d_unif->d_hd_to_pt.end());
857
        cechildren.insert(
858
            cechildren.end(), ith->second.begin(), ith->second.end());
859
        Node cea = nm->mkNode(DT_SYGUS_EVAL, cechildren);
860
        Trace("sygus-unif-sol-sym")
861
            << "Sep conflict app #" << r << " : " << cea << std::endl;
862
        std::vector<Node> tmpExp;
863
        cea = eunf->unfold(cea, vtm, tmpExp, true, true);
864
        Trace("sygus-unif-sol-sym") << "Unfolded to : " << cea << std::endl;
865
        ceApp[r] = cea;
866
      }
867
      Node ceAppEq = ceApp[0].eqNode(ceApp[1]);
868
      Trace("sygus-unif-sol-sym")
869
          << "Sep conflict app explanation is : " << ceAppEq << std::endl;
870
      exp.push_back(ceAppEq);
871
      break;
872
    }
873
88
    Trace("sygus-unif-sol")
874
44
        << "  ...resolves separation conflict with current, but not all"
875
44
        << std::endl;
876
    // find the new term to resolve a separation
877
88
    Node new_er = Node::null();
878
    // scan the previous list and find the representative of the class that e is
879
    // now in
880
88
    for (unsigned i = 0, size = prev_sep_c.size(); i < size; i++)
881
    {
882
132
      Node check_er = prev_sep_c[i];
883
88
      if (check_er != er && check_er != e)
884
      {
885
44
        itr = d_pt_sep.d_trie.d_rep_to_class.find(check_er);
886
44
        if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
887
        {
888
132
          if (std::find(itr->second.begin(), itr->second.end(), e)
889
132
              != itr->second.end())
890
          {
891
44
            new_er = check_er;
892
44
            break;
893
          }
894
        }
895
      }
896
    }
897
    // should find exactly one
898
44
    Assert(!new_er.isNull());
899
44
    er = new_er;
900
44
    needs_sep_resolve = true;
901
  }
902
130
  if (exp_conflict)
903
  {
904
208
    Node lemma = exp.size() == 1 ? exp[0] : nm->mkNode(AND, exp);
905
104
    lemma = lemma.negate();
906
104
    Trace("sygus-unif-sol") << "  ......conflict is " << lemma << std::endl;
907
104
    lemmas.push_back(lemma);
908
104
    return Node::null();
909
  }
910
911
26
  Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
912
26
  return extractSol(cons, hd_mv);
913
}
914
915
26
Node SygusUnifRl::DecisionTreeInfo::extractSol(Node cons,
                                               std::map<Node, Node>& hd_mv)
{
  // When the information-gain condition pool is in use, the decision tree is
  // first rebuilt heuristically from the available conditions.
  const bool useInfoGain = d_unif->usingConditionPoolInfoGain();
  if (useInfoGain)
  {
    recomputeSolHeuristically(hd_mv);
  }
  // Read the solution off the (possibly rebuilt) point separator trie.
  Node sol = d_pt_sep.extractSol(cons, hd_mv);
  return sol;
}
925
926
26
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::extractSol(
    Node cons, std::map<Node, Node>& hd_mv)
{
  // Builds an ITE term (via constructor cons) from the separation trie using
  // an explicit-stack post-order traversal keyed by (depth, trie node):
  // - first visit of a leaf: its result is the model value in hd_mv of the
  //   leaf's lazy child;
  // - first visit of an inner node: record a null placeholder, re-queue the
  //   node, then queue its children one level deeper;
  // - second visit of an inner node: combine the children's results, guarded
  //   by the condition d_dt->d_conds[depth].
  NodeManager* nm = NodeManager::currentNM();
  std::map<IndTriePair, Node> built;
  const IndTriePair root(0, &d_trie.d_trie);
  std::vector<IndTriePair> pending;
  pending.push_back(root);
  while (!pending.empty())
  {
    const IndTriePair cur = pending.back();
    pending.pop_back();
    const unsigned depth = cur.first;
    LazyTrie* node = cur.second;
    std::map<IndTriePair, Node>::iterator entry = built.find(cur);
    if (entry == built.end())
    {
      if (node->d_children.empty())
      {
        // leaf: take the model value of the point stored here
        Assert(hd_mv.find(node->d_lazy_child) != hd_mv.end());
        built[cur] = hd_mv[node->d_lazy_child];
        Trace("sygus-unif-sol-debug")
            << "......leaf, build "
            << d_dt->d_unif->d_tds->sygusToBuiltin(built[cur],
                                                   built[cur].getType())
            << "\n";
      }
      else
      {
        // inner node: revisit after all children have been built
        built[cur] = Node::null();
        pending.push_back(cur);
        for (std::pair<const Node, LazyTrie>& child : node->d_children)
        {
          pending.push_back(IndTriePair(depth + 1, &child.second));
        }
      }
      continue;
    }
    // Second visit of an inner node: all children results are available.
    Assert(entry->second.isNull());
    Assert(node->d_children.size() == 1 || node->d_children.size() == 2);
    // arguments of the constructor application: cons, condition, then
    // branch, else branch
    std::vector<Node> children(4);
    children[0] = cons;
    children[1] = d_dt->d_conds[depth];
    unsigned slot = 0;
    for (std::pair<const Node, LazyTrie>& child : node->d_children)
    {
      // the true edge fills the "then" slot, the false edge the "else" slot
      slot = child.first.getConst<bool>() ? 2 : 3;
      const IndTriePair childKey(depth + 1, &child.second);
      Assert(built.find(childKey) != built.end());
      children[slot] = built[childKey];
      Assert(!children[slot].isNull());
    }
    if (node->d_children.size() == 1 || children[2] == children[3])
    {
      // condition is useless or result children are equal, no need for ITE
      built[cur] = children[slot];
      Trace("sygus-unif-sol-debug")
          << "......no need for cond "
          << d_dt->d_unif->d_tds->sygusToBuiltin(d_dt->d_conds[depth],
                                                 d_dt->d_conds[depth].getType())
          << ", build "
          << d_dt->d_unif->d_tds->sygusToBuiltin(built[cur],
                                                 built[cur].getType())
          << "\n";
    }
    else
    {
      Assert(node->d_children.size() == 2);
      built[cur] = nm->mkNode(APPLY_CONSTRUCTOR, children);
      Trace("sygus-unif-sol-debug")
          << "......build node "
          << d_dt->d_unif->d_tds->sygusToBuiltin(built[cur],
                                                 built[cur].getType())
          << "\n";
    }
  }
  Assert(built.find(root) != built.end());
  Assert(!built.find(root)->second.isNull());
  return built[root];
}
1006
1007
void SygusUnifRl::DecisionTreeInfo::recomputeSolHeuristically(
1008
    std::map<Node, Node>& hd_mv)
1009
{
1010
  // reset the trie
1011
  d_pt_sep.d_trie.clear();
1012
  // TODO workaround and not really sure this is the last condition, since I put
1013
  // a set here. Maybe make d_cond_mvs into a vector
1014
  Node backup_last_cond = d_conds.back();
1015
  d_conds.clear();
1016
  for (const Node& e : d_hds)
1017
  {
1018
    d_pt_sep.d_trie.add(e, &d_pt_sep, 0);
1019
  }
1020
  // init vector of conds
1021
  std::vector<Node> conds;
1022
  conds.insert(conds.end(), d_cond_mvs.begin(), d_cond_mvs.end());
1023
1024
  // recursively build trie by picking best condition for respective points
1025
  buildDtInfoGain(d_hds, conds, hd_mv, 1);
1026
  // if no condition was added (i.e. points are already classified at root
1027
  // level), use last condition as candidate
1028
  if (d_conds.empty())
1029
  {
1030
    Trace("sygus-unif-dt") << "......using last condition "
1031
                           << d_unif->d_tds->sygusToBuiltin(
1032
                                  backup_last_cond, backup_last_cond.getType())
1033
                           << " as candidate\n";
1034
    d_conds.push_back(backup_last_cond);
1035
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
1036
  }
1037
}
1038
1039
void SygusUnifRl::DecisionTreeInfo::buildDtInfoGain(std::vector<Node>& hds,
1040
                                                    std::vector<Node> conds,
1041
                                                    std::map<Node, Node>& hd_mv,
1042
                                                    int ind)
1043
{
1044
  // test if fully classified
1045
  if (hds.size() < 2)
1046
  {
1047
    indent("sygus-unif-dt", ind);
1048
    Trace("sygus-unif-dt") << "..set fully classified: "
1049
                           << (hds.empty() ? "empty" : "unary") << "\n";
1050
    return;
1051
  }
1052
  Node v1 = hd_mv[hds[0]];
1053
  unsigned i = 1, size = hds.size();
1054
  for (; i < size; ++i)
1055
  {
1056
    if (hd_mv[hds[i]] != v1)
1057
    {
1058
      break;
1059
    }
1060
  }
1061
  if (i == size)
1062
  {
1063
    indent("sygus-unif-dt", ind);
1064
    Trace("sygus-unif-dt") << "..set fully classified: " << hds.size() << " "
1065
                           << (d_unif->d_tds->sygusToBuiltin(v1, v1.getType())
1066
                                       == d_true
1067
                                   ? "good"
1068
                                   : "bad")
1069
                           << " points\n";
1070
    return;
1071
  }
1072
  // pick condition to further classify
1073
  double maxgain = -1;
1074
  unsigned picked_cond = 0;
1075
  std::vector<std::pair<std::vector<Node>, std::vector<Node>>> splits;
1076
  double current_set_entropy = getEntropy(hds, hd_mv, ind);
1077
  for (unsigned j = 0, conds_size = conds.size(); j < conds_size; ++j)
1078
  {
1079
    std::pair<std::vector<Node>, std::vector<Node>> split =
1080
        evaluateCond(hds, conds[j]);
1081
    splits.push_back(split);
1082
    Assert(hds.size() == split.first.size() + split.second.size());
1083
    double gain =
1084
        current_set_entropy
1085
        - (split.first.size() * getEntropy(split.first, hd_mv, ind)
1086
           + split.second.size() * getEntropy(split.second, hd_mv, ind))
1087
              / hds.size();
1088
    indent("sygus-unif-dt-debug", ind);
1089
    Trace("sygus-unif-dt-debug")
1090
        << "..gain of "
1091
        << d_unif->d_tds->sygusToBuiltin(conds[j], conds[j].getType()) << " is "
1092
        << gain << "\n";
1093
    if (gain > maxgain)
1094
    {
1095
      maxgain = gain;
1096
      picked_cond = j;
1097
    }
1098
  }
1099
  // add picked condition
1100
  indent("sygus-unif-dt", ind);
1101
  Trace("sygus-unif-dt") << "..picked condition "
1102
                         << d_unif->d_tds->sygusToBuiltin(
1103
                                conds[picked_cond],
1104
                                conds[picked_cond].getType())
1105
                         << "\n";
1106
  d_conds.push_back(conds[picked_cond]);
1107
  conds.erase(conds.begin() + picked_cond);
1108
  d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
1109
  // recurse
1110
  buildDtInfoGain(splits[picked_cond].first, conds, hd_mv, ind + 1);
1111
  buildDtInfoGain(splits[picked_cond].second, conds, hd_mv, ind + 1);
1112
}
1113
1114
std::pair<std::vector<Node>, std::vector<Node>>
1115
SygusUnifRl::DecisionTreeInfo::evaluateCond(std::vector<Node>& pts, Node cond)
1116
{
1117
  std::vector<Node> good, bad;
1118
  for (const Node& pt : pts)
1119
  {
1120
    if (d_pt_sep.computeCond(cond, pt) == d_true)
1121
    {
1122
      good.push_back(pt);
1123
      continue;
1124
    }
1125
    Assert(d_pt_sep.computeCond(cond, pt) == d_false);
1126
    bad.push_back(pt);
1127
  }
1128
  return std::pair<std::vector<Node>, std::vector<Node>>(good, bad);
1129
}
1130
1131
double SygusUnifRl::DecisionTreeInfo::getEntropy(const std::vector<Node>& hds,
1132
                                                 std::map<Node, Node>& hd_mv,
1133
                                                 int ind)
1134
{
1135
  double p = 0, n = 0;
1136
  TermDbSygus* tds = d_unif->d_tds;
1137
  // get number of points evaluated positively and negatively with feature
1138
  for (const Node& e : hds)
1139
  {
1140
    if (tds->sygusToBuiltin(hd_mv[e]) == d_true)
1141
    {
1142
      p++;
1143
      continue;
1144
    }
1145
    Assert(tds->sygusToBuiltin(hd_mv[e]) == d_false);
1146
    n++;
1147
  }
1148
  // compute entropy
1149
  return p == 0 || n == 0 ? 0
1150
                          : ((-p / (p + n)) * log2(p / (p + n)))
1151
                                - ((n / (p + n)) * log2(n / (p + n)));
1152
}
1153
1154
17
void SygusUnifRl::DecisionTreeInfo::PointSeparator::initialize(
1155
    DecisionTreeInfo* dt)
1156
{
1157
17
  d_dt = dt;
1158
17
}
1159
1160
506
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::evaluate(Node n,
                                                             unsigned index)
{
  // Classifier number index is the index-th condition of the decision tree;
  // evaluate it on the point associated with head n.
  Assert(index < d_dt->d_conds.size());
  return computeCond(d_dt->d_conds[index], n);
}
1168
1169
506
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::computeCond(Node cond,
                                                                Node hd)
{
  // Evaluates the sygus condition cond on the point associated with head hd
  // and returns the resulting constant. Results are memoized per (cond, hd).
  std::pair<Node, Node> cond_hd = std::pair<Node, Node>(cond, hd);
  std::map<std::pair<Node, Node>, Node>::iterator it =
      d_eval_cond_hd.find(cond_hd);
  if (it != d_eval_cond_hd.end())
  {
    return it->second;
  }
  TypeNode tn = cond.getType();
  Node builtin_cond = d_dt->d_unif->d_tds->sygusToBuiltin(cond, tn);
  // Retrieve evaluation point. Bind it by reference instead of copying the
  // vector on every uncached evaluation; references into a std::map stay
  // valid across later insertions.
  Assert(d_dt->d_unif->d_hd_to_pt.find(hd) != d_dt->d_unif->d_hd_to_pt.end());
  std::vector<Node>& pt = d_dt->d_unif->d_hd_to_pt[hd];
  // compute the result
  if (Trace.isOn("sygus-unif-rl-sep"))
  {
    Trace("sygus-unif-rl-sep")
        << "Evaluate cond " << builtin_cond << " on pt " << hd << " ( ";
    for (const Node& pti : pt)
    {
      Trace("sygus-unif-rl-sep") << pti << " ";
    }
    Trace("sygus-unif-rl-sep") << ")\n";
  }
  Node res = d_dt->d_unif->d_tds->evaluateBuiltin(tn, builtin_cond, pt);
  Trace("sygus-unif-rl-sep") << "...got res = " << res << "\n";
  // If condition is templated, recompute result accordingly
  Node templ = d_dt->d_template.first;
  TNode templ_var = d_dt->d_template.second;
  if (!templ.isNull())
  {
    res = templ.substitute(templ_var, res);
    res = Rewriter::rewrite(res);
    Trace("sygus-unif-rl-sep")
        << "...after template res = " << res << std::endl;
  }
  Assert(res.isConst());
  d_eval_cond_hd[cond_hd] = res;
  return res;
}
1211
1212
}  // namespace quantifiers
1213
}  // namespace theory
1214
31137
}  // namespace cvc5