GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/sygus/sygus_unif_rl.cpp Lines: 421 638 66.0 %
Date: 2021-09-15 Branches: 689 2466 27.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Haniel Barbosa, Andrew Reynolds, Mathias Preiner
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of sygus_unif_rl.
14
 */
15
16
#include "theory/quantifiers/sygus/sygus_unif_rl.h"
17
18
#include "expr/skolem_manager.h"
19
#include "options/base_options.h"
20
#include "options/quantifiers_options.h"
21
#include "printer/printer.h"
22
#include "theory/quantifiers/sygus/synth_conjecture.h"
23
#include "theory/quantifiers/sygus/term_database_sygus.h"
24
#include "theory/rewriter.h"
25
#include "util/random.h"
26
27
#include <math.h>
28
29
using namespace cvc5::kind;
30
31
namespace cvc5 {
32
namespace theory {
33
namespace quantifiers {
34
35
1233
SygusUnifRl::SygusUnifRl(Env& env, SynthConjecture* p)
36
    : SygusUnif(env),
37
      d_parent(p),
38
      d_useCondPool(false),
39
1233
      d_useCondPoolIGain(false)
40
{
41
1233
}
42
1231
SygusUnifRl::~SygusUnifRl() {}
43
14
void SygusUnifRl::initializeCandidate(
44
    TermDbSygus* tds,
45
    Node f,
46
    std::vector<Node>& enums,
47
    std::map<Node, std::vector<Node>>& strategy_lemmas)
48
{
49
  // initialize
50
28
  std::vector<Node> all_enums;
51
14
  SygusUnif::initializeCandidate(tds, f, all_enums, strategy_lemmas);
52
  // based on the strategy inferred for each function, determine if we are
53
  // using a unification strategy that is compatible our approach.
54
28
  StrategyRestrictions restrictions;
55
14
  if (options::sygusBoolIteReturnConst())
56
  {
57
14
    restrictions.d_iteReturnBoolConst = true;
58
  }
59
  // register the strategy
60
14
  registerStrategy(f, enums, restrictions.d_unused_strategies);
61
14
  d_strategy.at(f).staticLearnRedundantOps(strategy_lemmas, restrictions);
62
  // Copy candidates and check whether CegisUnif for any of them
63
14
  if (d_unif_candidates.find(f) != d_unif_candidates.end())
64
  {
65
11
    d_hd_to_pt[f].clear();
66
11
    d_cand_to_eval_hds[f].clear();
67
11
    d_cand_to_hd_count[f] = 0;
68
  }
69
  // check whether we are using condition enumeration
70
14
  options::SygusUnifPiMode mode = options::sygusUnifPi();
71
14
  d_useCondPool = mode == options::SygusUnifPiMode::CENUM
72
14
                  || mode == options::SygusUnifPiMode::CENUM_IGAIN;
73
14
  d_useCondPoolIGain = mode == options::SygusUnifPiMode::CENUM_IGAIN;
74
14
}
75
76
void SygusUnifRl::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
77
{
78
  // we do not use notify enumeration
79
  Assert(false);
80
}
81
82
381
Node SygusUnifRl::purifyLemma(Node n,
83
                              bool ensureConst,
84
                              std::vector<Node>& model_guards,
85
                              BoolNodePairMap& cache)
86
{
87
381
  Trace("sygus-unif-rl-purify") << "PurifyLemma : " << n << "\n";
88
  BoolNodePairMap::const_iterator it0 =
89
381
      cache.find(BoolNodePair(ensureConst, n));
90
381
  if (it0 != cache.end())
91
  {
92
227
    Trace("sygus-unif-rl-purify-debug") << "... already visited " << n << "\n";
93
227
    return it0->second;
94
  }
95
  // Recurse
96
154
  unsigned size = n.getNumChildren();
97
154
  Kind k = n.getKind();
98
  // We retrive model value now because purified node may not have a value
99
308
  Node nv = n;
100
  // Whether application of a function-to-synthesize
101
154
  bool fapp = (n.getKind() == DT_SYGUS_EVAL);
102
154
  bool u_fapp = false;
103
154
  bool nu_fapp = false;
104
154
  if (fapp)
105
  {
106
38
    Assert(std::find(d_candidates.begin(), d_candidates.end(), n[0])
107
           != d_candidates.end());
108
    // Whether application of a (non-)unification function-to-synthesize
109
38
    u_fapp = usingUnif(n[0]);
110
38
    nu_fapp = !usingUnif(n[0]);
111
    // get model value of non-top level applications of functions-to-synthesize
112
    // occurring under a unification function-to-synthesize
113
38
    if (ensureConst)
114
    {
115
      std::map<Node, Node>::iterator it1 = d_cand_to_sol.find(n[0]);
116
      // if function-to-synthesize, retrieve its built solution to replace in
117
      // the application before computing the model value
118
      AlwaysAssert(!u_fapp || it1 != d_cand_to_sol.end());
119
      if (it1 != d_cand_to_sol.end())
120
      {
121
        TNode cand = n[0];
122
        Node tmp = n.substitute(cand, it1->second);
123
        // should be concrete, can just use the rewriter
124
        nv = rewrite(tmp);
125
        Trace("sygus-unif-rl-purify")
126
            << "PurifyLemma : model value for " << tmp << " is " << nv << "\n";
127
      }
128
      else
129
      {
130
        nv = d_parent->getModelValue(n);
131
        Trace("sygus-unif-rl-purify")
132
            << "PurifyLemma : model value for " << n << " is " << nv << "\n";
133
      }
134
      Assert(n != nv);
135
    }
136
  }
137
  // Travese to purify
138
154
  bool childChanged = false;
139
308
  std::vector<Node> children;
140
154
  NodeManager* nm = NodeManager::currentNM();
141
154
  SkolemManager* sm = nm->getSkolemManager();
142
558
  for (unsigned i = 0; i < size; ++i)
143
  {
144
442
    if (i == 0 && fapp)
145
    {
146
38
      children.push_back(n[i]);
147
38
      continue;
148
    }
149
    // Arguments of non-unif functions do not need to be constant
150
    Node child = purifyLemma(
151
732
        n[i], !nu_fapp && (ensureConst || u_fapp), model_guards, cache);
152
366
    children.push_back(child);
153
366
    childChanged = childChanged || child != n[i];
154
  }
155
308
  Node nb;
156
154
  if (childChanged)
157
  {
158
41
    if (n.getMetaKind() == metakind::PARAMETERIZED)
159
    {
160
      Trace("sygus-unif-rl-purify-debug")
161
          << "Node " << n << " is parameterized\n";
162
      children.insert(children.begin(), n.getOperator());
163
    }
164
41
    if (Trace.isOn("sygus-unif-rl-purify-debug"))
165
    {
166
      Trace("sygus-unif-rl-purify-debug")
167
          << "...rebuilding " << n << " with kind " << k << " and children:\n";
168
      for (const Node& child : children)
169
      {
170
        Trace("sygus-unif-rl-purify-debug") << "...... " << child << "\n";
171
      }
172
    }
173
41
    nb = NodeManager::currentNM()->mkNode(k, children);
174
82
    Trace("sygus-unif-rl-purify")
175
41
        << "PurifyLemma : transformed " << n << " into " << nb << "\n";
176
  }
177
  else
178
  {
179
113
    nb = n;
180
  }
181
  // Map to point enumerator every unification function-to-synthesize
182
154
  if (u_fapp)
183
  {
184
70
    Node np;
185
35
    std::map<Node, Node>::const_iterator it2 = d_app_to_purified.find(nb);
186
35
    if (it2 == d_app_to_purified.end())
187
    {
188
      // Build purified head with fresh skolem and recreate node
189
52
      std::stringstream ss;
190
26
      ss << nb[0] << "_" << d_cand_to_hd_count[nb[0]]++;
191
52
      Node new_f = sm->mkDummySkolem(ss.str(),
192
52
                                     nb[0].getType(),
193
                                     "head of unif evaluation point",
194
104
                                     NodeManager::SKOLEM_EXACT_NAME);
195
      // Adds new enumerator to map from candidate
196
52
      Trace("sygus-unif-rl-purify")
197
26
          << "...new enum " << new_f << " for candidate " << nb[0] << "\n";
198
26
      d_cand_to_eval_hds[nb[0]].push_back(new_f);
199
      // Maps new enumerator to its respective tuple of arguments
200
26
      d_hd_to_pt[new_f] =
201
52
          std::vector<Node>(children.begin() + 1, children.end());
202
26
      if (Trace.isOn("sygus-unif-rl-purify-debug"))
203
      {
204
        Trace("sygus-unif-rl-purify-debug") << "...[" << new_f << "] --> ( ";
205
        for (const Node& pt_i : d_hd_to_pt[new_f])
206
        {
207
          Trace("sygus-unif-rl-purify-debug") << pt_i << " ";
208
        }
209
        Trace("sygus-unif-rl-purify-debug") << ")\n";
210
      }
211
      // replace first child and rebulid node
212
26
      Assert(children.size() > 0);
213
26
      children[0] = new_f;
214
52
      Trace("sygus-unif-rl-purify-debug")
215
26
          << "Make sygus eval app " << children << std::endl;
216
26
      np = nm->mkNode(DT_SYGUS_EVAL, children);
217
26
      d_app_to_purified[nb] = np;
218
    }
219
    else
220
    {
221
9
      np = it2->second;
222
    }
223
70
    Trace("sygus-unif-rl-purify")
224
35
        << "PurifyLemma : purified head and transformed " << nb << " into "
225
35
        << np << "\n";
226
35
    nb = np;
227
  }
228
  // Add equality between purified fapp and model value
229
154
  if (ensureConst && fapp)
230
  {
231
    model_guards.push_back(
232
        NodeManager::currentNM()->mkNode(EQUAL, nv, nb).negate());
233
    nb = nv;
234
    Trace("sygus-unif-rl-purify")
235
        << "PurifyLemma : adding model eq " << model_guards.back() << "\n";
236
  }
237
154
  nb = rewrite(nb);
238
  // every non-top level application of function-to-synthesize must be reduced
239
  // to a concrete constant
240
154
  Assert(!ensureConst || nb.isConst());
241
308
  Trace("sygus-unif-rl-purify-debug")
242
154
      << "... caching [" << n << "] = " << nb << "\n";
243
154
  cache[BoolNodePair(ensureConst, n)] = nb;
244
154
  return nb;
245
}
246
247
15
Node SygusUnifRl::addRefLemma(Node lemma,
248
                              std::map<Node, std::vector<Node>>& eval_hds)
249
{
250
30
  Trace("sygus-unif-rl-purify")
251
15
      << "Registering lemma at SygusUnif : " << lemma << "\n";
252
30
  std::vector<Node> model_guards;
253
30
  BoolNodePairMap cache;
254
  // cache previous sizes
255
30
  std::map<Node, unsigned> prev_n_eval_hds;
256
30
  for (const std::pair<const Node, std::vector<Node>>& cp : d_cand_to_eval_hds)
257
  {
258
15
    prev_n_eval_hds[cp.first] = cp.second.size();
259
  }
260
261
  // Make the purified lemma which will guide the unification utility.
262
15
  Node plem = purifyLemma(lemma, false, model_guards, cache);
263
15
  if (!model_guards.empty())
264
  {
265
    model_guards.push_back(plem);
266
    plem = NodeManager::currentNM()->mkNode(OR, model_guards);
267
  }
268
15
  plem = rewrite(plem);
269
15
  Trace("sygus-unif-rl-purify") << "Purified lemma : " << plem << "\n";
270
271
15
  Trace("sygus-unif-rl-purify") << "Collect new evaluation points...\n";
272
30
  for (const std::pair<const Node, std::vector<Node>>& cp : d_cand_to_eval_hds)
273
  {
274
30
    Node c = cp.first;
275
15
    unsigned prevn = 0;
276
15
    std::map<Node, unsigned>::iterator itp = prev_n_eval_hds.find(c);
277
15
    if (itp != prev_n_eval_hds.end())
278
    {
279
15
      prevn = itp->second;
280
    }
281
41
    for (unsigned j = prevn, size = cp.second.size(); j < size; j++)
282
    {
283
26
      eval_hds[c].push_back(cp.second[j]);
284
      // Add new point to respective decision trees
285
26
      Assert(d_cand_cenums.find(c) != d_cand_cenums.end());
286
52
      for (const Node& cenum : d_cand_cenums[c])
287
      {
288
26
        Assert(d_cenum_to_stratpt.find(cenum) != d_cenum_to_stratpt.end());
289
52
        for (const Node& stratpt : d_cenum_to_stratpt[cenum])
290
        {
291
26
          Assert(d_stratpt_to_dt.find(stratpt) != d_stratpt_to_dt.end());
292
52
          Trace("sygus-unif-rl-dt")
293
26
              << "Register point with head " << cp.second[j]
294
26
              << " to strategy point " << stratpt << "\n";
295
          // Register new point from new head
296
26
          d_stratpt_to_dt[stratpt].d_hds.push_back(cp.second[j]);
297
        }
298
      }
299
    }
300
  }
301
302
30
  return plem;
303
}
304
305
147
void SygusUnifRl::initializeConstructSol() {}
306
147
void SygusUnifRl::initializeConstructSolFor(Node f) {}
307
147
bool SygusUnifRl::constructSolution(std::vector<Node>& sols,
308
                                    std::vector<Node>& lemmas)
309
{
310
147
  initializeConstructSol();
311
147
  bool successful = true;
312
300
  for (const Node& c : d_candidates)
313
  {
314
153
    if (!usingUnif(c))
315
    {
316
12
      Node v = d_parent->getModelValue(c);
317
6
      sols.push_back(v);
318
6
      continue;
319
    }
320
147
    initializeConstructSolFor(c);
321
    Node v = constructSol(
322
170
        c, d_strategy.at(c).getRootEnumerator(), role_equal, 0, lemmas);
323
271
    if (v.isNull())
324
    {
325
      // we continue trying to build solutions to accumulate potentitial
326
      // separation conditions from other decision trees
327
124
      successful = false;
328
124
      continue;
329
    }
330
23
    sols.push_back(v);
331
23
    d_cand_to_sol[c] = v;
332
  }
333
147
  return successful;
334
}
335
336
147
Node SygusUnifRl::constructSol(
337
    Node f, Node e, NodeRole nrole, int ind, std::vector<Node>& lemmas)
338
{
339
147
  indent("sygus-unif-sol", ind);
340
147
  Trace("sygus-unif-sol") << "ConstructSol: SygusRL : " << e << std::endl;
341
  // retrieve strategy information
342
294
  TypeNode etn = e.getType();
343
147
  EnumTypeInfo& tinfo = d_strategy.at(f).getEnumTypeInfo(etn);
344
147
  StrategyNode& snode = tinfo.getStrategyNode(nrole);
345
147
  if (nrole != role_equal)
346
  {
347
    return Node::null();
348
  }
349
  // is there a decision tree strategy?
350
147
  std::map<Node, DecisionTreeInfo>::iterator itd = d_stratpt_to_dt.find(e);
351
  // for now only considering simple case of sole "ITE(cond, e, e)" strategy
352
147
  if (itd == d_stratpt_to_dt.end())
353
  {
354
    return Node::null();
355
  }
356
147
  indent("sygus-unif-sol", ind);
357
147
  Trace("sygus-unif-sol") << "...it has a decision tree strategy.\n";
358
  // whether empty set of points
359
147
  if (d_cand_to_eval_hds[f].empty())
360
  {
361
16
    Trace("sygus-unif-sol") << "...... no points, return root enum value "
362
8
                            << d_parent->getModelValue(e) << "\n";
363
8
    return d_parent->getModelValue(e);
364
  }
365
139
  EnumTypeInfoStrat* etis = snode.d_strats[itd->second.getStrategyIndex()];
366
278
  Node sol = itd->second.buildSol(etis->d_cons, lemmas);
367
139
  Assert(d_useCondPool || !sol.isNull() || !lemmas.empty());
368
139
  return sol;
369
}
370
371
243
bool SygusUnifRl::usingUnif(Node f) const
372
{
373
243
  return d_unif_candidates.find(f) != d_unif_candidates.end();
374
}
375
376
11
Node SygusUnifRl::getConditionForEvaluationPoint(Node e) const
377
{
378
11
  std::map<Node, DecisionTreeInfo>::const_iterator it = d_stratpt_to_dt.find(e);
379
11
  Assert(it != d_stratpt_to_dt.end());
380
11
  return it->second.getConditionEnumerator();
381
}
382
383
147
void SygusUnifRl::setConditions(Node e,
384
                                Node guard,
385
                                const std::vector<Node>& enums,
386
                                const std::vector<Node>& conds)
387
{
388
147
  std::map<Node, DecisionTreeInfo>::iterator it = d_stratpt_to_dt.find(e);
389
147
  Assert(it != d_stratpt_to_dt.end());
390
  // set the conditions for the appropriate tree
391
147
  it->second.setConditions(guard, enums, conds);
392
147
}
393
394
323
std::vector<Node> SygusUnifRl::getEvalPointHeads(Node c)
395
{
396
323
  std::map<Node, std::vector<Node>>::iterator it = d_cand_to_eval_hds.find(c);
397
323
  if (it == d_cand_to_eval_hds.end())
398
  {
399
    return std::vector<Node>();
400
  }
401
323
  return it->second;
402
}
403
404
970
bool SygusUnifRl::usingConditionPool() const { return d_useCondPool; }
405
15
bool SygusUnifRl::usingConditionPoolInfoGain() const
406
{
407
15
  return d_useCondPoolIGain;
408
}
409
14
void SygusUnifRl::registerStrategy(
410
    Node f,
411
    std::vector<Node>& enums,
412
    std::map<Node, std::unordered_set<unsigned>>& unused_strats)
413
{
414
14
  if (Trace.isOn("sygus-unif-rl-strat"))
415
  {
416
    Trace("sygus-unif-rl-strat")
417
        << "Strategy for " << f << " is : " << std::endl;
418
    d_strategy.at(f).debugPrint("sygus-unif-rl-strat");
419
  }
420
14
  Trace("sygus-unif-rl-strat") << "Register..." << std::endl;
421
28
  Node e = d_strategy.at(f).getRootEnumerator();
422
28
  std::map<Node, std::map<NodeRole, bool>> visited;
423
14
  registerStrategyNode(f, e, role_equal, visited, enums, unused_strats);
424
14
}
425
426
14
void SygusUnifRl::registerStrategyNode(
427
    Node f,
428
    Node e,
429
    NodeRole nrole,
430
    std::map<Node, std::map<NodeRole, bool>>& visited,
431
    std::vector<Node>& enums,
432
    std::map<Node, std::unordered_set<unsigned>>& unused_strats)
433
{
434
14
  Trace("sygus-unif-rl-strat") << "  register node " << e << std::endl;
435
14
  if (visited[e].find(nrole) != visited[e].end())
436
  {
437
    return;
438
  }
439
14
  visited[e][nrole] = true;
440
28
  TypeNode etn = e.getType();
441
14
  EnumTypeInfo& tinfo = d_strategy.at(f).getEnumTypeInfo(etn);
442
14
  StrategyNode& snode = tinfo.getStrategyNode(nrole);
443
27
  for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
444
  {
445
13
    EnumTypeInfoStrat* etis = snode.d_strats[j];
446
13
    StrategyType strat = etis->d_this;
447
    // is this a simple recursive ITE strategy?
448
13
    bool success = false;
449
13
    if (strat == strat_ITE && nrole == role_equal)
450
    {
451
11
      success = true;
452
33
      for (unsigned c = 1; c <= 2; c++)
453
      {
454
44
        std::pair<Node, NodeRole> child = etis->d_cenum[c];
455
22
        if (child.first != e || child.second != nrole)
456
        {
457
          success = false;
458
          break;
459
        }
460
      }
461
11
      if (success)
462
      {
463
22
        Node cond = etis->d_cenum[0].first;
464
11
        Assert(etis->d_cenum[0].second == role_ite_condition);
465
22
        Trace("sygus-unif-rl-strat")
466
11
            << "  ...detected recursive ITE strategy, condition enumerator : "
467
11
            << cond << std::endl;
468
        // indicate that we will be enumerating values for cond
469
11
        registerConditionalEnumerator(f, e, cond, j);
470
        // we will be using a strategy for e
471
11
        enums.push_back(e);
472
      }
473
    }
474
13
    if (!success)
475
    {
476
2
      unused_strats[e].insert(j);
477
    }
478
    // TODO: recurse? for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
479
  }
480
}
481
482
11
void SygusUnifRl::registerConditionalEnumerator(Node f,
483
                                                Node e,
484
                                                Node cond,
485
                                                unsigned strategy_index)
486
{
487
  // only allow one decision tree per strategy point
488
11
  if (d_stratpt_to_dt.find(e) != d_stratpt_to_dt.end())
489
  {
490
    return;
491
  }
492
  // we will do unification for this candidate
493
11
  d_unif_candidates.insert(f);
494
  // add to the list of all conditional enumerators
495
33
  if (std::find(d_cond_enums.begin(), d_cond_enums.end(), cond)
496
33
      == d_cond_enums.end())
497
  {
498
11
    d_cond_enums.push_back(cond);
499
11
    d_cand_cenums[f].push_back(cond);
500
11
    d_cenum_to_stratpt[cond].clear();
501
  }
502
  // register that this strategy node has a decision tree construction
503
11
  d_stratpt_to_dt[e].initialize(cond, this, &d_strategy.at(f), strategy_index);
504
  // associate conditional enumerator with strategy node
505
11
  d_cenum_to_stratpt[cond].push_back(e);
506
}
507
508
11
void SygusUnifRl::DecisionTreeInfo::initialize(Node cond_enum,
509
                                               SygusUnifRl* unif,
510
                                               SygusUnifStrategy* strategy,
511
                                               unsigned strategy_index)
512
{
513
11
  d_cond_enum = cond_enum;
514
11
  d_unif = unif;
515
11
  d_strategy = strategy;
516
11
  d_strategy_index = strategy_index;
517
11
  d_true = NodeManager::currentNM()->mkConst(true);
518
11
  d_false = NodeManager::currentNM()->mkConst(false);
519
  // Retrieve template
520
11
  EnumInfo& eiv = d_strategy->getEnumInfo(d_cond_enum);
521
11
  d_template = NodePair(eiv.d_template, eiv.d_template_arg);
522
  // Initialize classifier
523
11
  d_pt_sep.initialize(this);
524
11
}
525
526
147
void SygusUnifRl::DecisionTreeInfo::setConditions(
527
    Node guard, const std::vector<Node>& enums, const std::vector<Node>& conds)
528
{
529
147
  Assert(enums.size() == conds.size());
530
  // set the guard
531
147
  d_guard = guard;
532
  // clear old condition values
533
147
  d_enums.clear();
534
147
  d_conds.clear();
535
  // set new condition values
536
147
  d_enums.insert(d_enums.end(), enums.begin(), enums.end());
537
147
  d_conds.insert(d_conds.end(), conds.begin(), conds.end());
538
  // add to condition pool
539
147
  if (d_unif->usingConditionPool())
540
  {
541
    d_cond_mvs.insert(conds.begin(), conds.end());
542
    if (Trace.isOn("sygus-unif-cond-pool"))
543
    {
544
      for (const Node& condv : conds)
545
      {
546
        if (d_cond_mvs.find(condv) == d_cond_mvs.end())
547
        {
548
          Trace("sygus-unif-cond-pool")
549
              << "  ...adding to condition pool : "
550
              << d_unif->d_tds->sygusToBuiltin(condv, condv.getType()) << "\n";
551
        }
552
      }
553
    }
554
  }
555
147
}
556
557
139
unsigned SygusUnifRl::DecisionTreeInfo::getStrategyIndex() const
558
{
559
139
  return d_strategy_index;
560
}
561
562
139
Node SygusUnifRl::DecisionTreeInfo::buildSol(Node cons,
563
                                             std::vector<Node>& lemmas)
564
{
565
139
  if (!d_template.first.isNull())
566
  {
567
    Trace("sygus-unif-sol") << "...templated conditions unsupported\n";
568
    return Node::null();
569
  }
570
278
  Trace("sygus-unif-sol") << "Decision::buildSol with " << d_hds.size()
571
278
                          << " evaluation heads and " << d_conds.size()
572
139
                          << " conditions..." << std::endl;
573
  // reset the trie
574
139
  d_pt_sep.d_trie.clear();
575
139
  return d_unif->usingConditionPool() ? buildSolAllCond(cons, lemmas)
576
139
                                      : buildSolMinCond(cons, lemmas);
577
}
578
579
Node SygusUnifRl::DecisionTreeInfo::buildSolAllCond(Node cons,
580
                                                    std::vector<Node>& lemmas)
581
{
582
  // model values for evaluation heads
583
  std::map<Node, Node> hd_mv;
584
  // add conditions
585
  d_conds.clear();
586
  d_conds.insert(d_conds.end(), d_cond_mvs.begin(), d_cond_mvs.end());
587
  // shuffle conditions before bulding DT
588
  //
589
  // this does not impact whether it's possible to build a solution, but it does
590
  // impact the potential size of the resulting solution (can make it smaller,
591
  // bigger, or have no impact) and which conditions will be present in the DT,
592
  // which influences the "quality" of the solution for cases not covered in the
593
  // current data points
594
  if (options::sygusUnifShuffleCond())
595
  {
596
    std::shuffle(d_conds.begin(), d_conds.end(), Random::getRandom());
597
  }
598
  unsigned num_conds = d_conds.size();
599
  for (unsigned i = 0; i < num_conds; ++i)
600
  {
601
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, i);
602
  }
603
  // add heads
604
  for (const Node& e : d_hds)
605
  {
606
    Node v = d_unif->d_parent->getModelValue(e);
607
    hd_mv[e] = v;
608
    Node er = d_pt_sep.d_trie.add(e, &d_pt_sep, num_conds);
609
    // are we in conflict?
610
    if (er == e)
611
    {
612
      // new separation class, no conflict
613
      continue;
614
    }
615
    Assert(hd_mv.find(er) != hd_mv.end());
616
    // merged into separation class with same model value, no conflict
617
    if (hd_mv[e] == hd_mv[er])
618
    {
619
      continue;
620
    }
621
    // conflict. Explanation?
622
    Trace("sygus-unif-sol")
623
        << "  ...can't separate " << e << " from " << er << std::endl;
624
    return Node::null();
625
  }
626
  Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
627
  Node sol = extractSol(cons, hd_mv);
628
  // repeated solution
629
  if (options::sygusUnifCondIndNoRepeatSol()
630
      && d_sols.find(sol) != d_sols.end())
631
  {
632
    return Node::null();
633
  }
634
  d_sols.insert(sol);
635
  return sol;
636
}
637
638
139
Node SygusUnifRl::DecisionTreeInfo::buildSolMinCond(Node cons,
639
                                                    std::vector<Node>& lemmas)
640
{
641
139
  NodeManager* nm = NodeManager::currentNM();
642
  // model values for evaluation heads
643
278
  std::map<Node, Node> hd_mv;
644
  // the current explanation of why there has not yet been a separation conflict
645
278
  std::vector<Node> exp;
646
  // is the above explanation ready to be sent out as a lemma?
647
139
  bool exp_conflict = false;
648
  // the index of the head we are considering
649
139
  unsigned hd_counter = 0;
650
  // the index of the condition we are considering
651
139
  unsigned c_counter = 0;
652
  // do we need to resolve a separation conflict?
653
139
  bool needs_sep_resolve = false;
654
  // This loop simultaneously builds the solution in terms of a lazy trie
655
  // (LazyTrieMulti), and checks whether a separation conflict exists. We
656
  // enforce that the separation conflicts we encounter while building
657
  // this solution are resolved, in order, by the condition enumerators.
658
  // If not, then we add a (conflict) lemma stating that the current model
659
  // value of the condition enumerator must be different. We also call this
660
  // a "separation lemma".
661
  //
662
  // As a simple example, say we have:
663
  //   evalution heads: (eval e1 0 0), (eval e2 1 2)
664
  //   conditions: c1
665
  // where M(e1) = x, M(e2) = y, and M(c1) = x>1. After adding e1 and e2, we are
666
  // in conflict since { e1, e2 } form a separation class, M(e1)!=M(e2), and
667
  // M(c1) does not separate e1 and e2 since:
668
  //   (x>1){x->0,y->0} = (x>1){x->1,y->2} = false
669
  // Hence, we would fail to build a solution in this case, and instead send a
670
  // separation lemma of the form:
671
  //   ~( e1 != e2 ^ c1 = [x<1] )
672
  //
673
  // Say we have:
674
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
675
  //   conditions: c1 c2
676
  // where M(e1) = x, M(e2) = y, M(e3) = x+1, M(c1) = x>0 and M(c2) = x<0.
677
  // After adding e1 and e2, { e1, e2 } form a separation class, M(e1)!=M(e2),
678
  // but M(c1) separates e1 and e2 since
679
  //   (x>0){x->0,y->0} = false, and
680
  //   (x>1){x->1,y->2} = true
681
  // Hence, we get new separation classes { e1 } and { e2 }, and afterwards
682
  // add e3. We then get { e2, e3 } as a separation class, which is also a
683
  // conflict since M(e2)!=M(e3). We check if M(c2) resolves this conflict.
684
  // It does not, since (x<1){x->0,y->0} = (x<1){x->1,y->2} = false. Hence,
685
  // we get a separation lemma:
686
  //  ~( c1 = [x>1] ^ e2 != e3 ^ c2 = [x<1] )
687
  //
688
  // Say we have:
689
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
690
  //   conditions: c1
691
  // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = x>0.
692
  // After adding e1 and e2, we have separation class { e1, e2 }. This is not a
693
  // conflict since M(e1)=M(e2). We then add e3, obtaining separation class
694
  // { e1, e2, e3 }, which is in conflict since M(e3)!=M(e1), and the condition
695
  // c1 does not separate e3 and the representative of this class, e1. Hence we
696
  // get a separation lemma of the form:
697
  //  ~( e1 = e2 ^ e1 != e3 ^ c1 = [x>0] )
698
  //
699
  // It also may be the case that we exhaust the pool of condition enumerators.
700
  // Say we have:
701
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
702
  //   conditions: c1
703
  // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = y>0. After adding e1, e2,
704
  // and e3, we have a separation class { e1, e2, e3 } that is in conflict
705
  // since M(e3)!=M(e1). We add the condition c1, which separates into new
706
  // equivalence classes { e1 }, { e2, e3 }. We are still in separation conflict
707
  // since M(e3)!=M(e2). However, we do not have any further conditions to use
708
  // to resolve this conflict. Thus, we add the separation lemma:
709
  //  ~( e1 = e2 ^ e1 != e3 ^ e2 != e3 ^ c1 = [y>0] ^ G_1 )
710
  // where G_1 is a guard stating that we use at most 1 condition.
711
278
  Node e;
712
278
  Node er;
713
939
  while (hd_counter < d_hds.size() || needs_sep_resolve)
714
  {
715
524
    if (!needs_sep_resolve)
716
    {
717
      // add the head to the trie
718
485
      e = d_hds[hd_counter];
719
485
      hd_mv[e] = d_unif->d_parent->getModelValue(e);
720
485
      if (Trace.isOn("sygus-unif-sol"))
721
      {
722
        std::stringstream ss;
723
        TermDbSygus::toStreamSygus(ss, hd_mv[e]);
724
        Trace("sygus-unif-sol")
725
            << "  add evaluation head (" << hd_counter << "/" << d_hds.size()
726
            << "): " << e << " -> " << ss.str() << std::endl;
727
      }
728
485
      hd_counter++;
729
      // get the representative of the trie
730
485
      er = d_pt_sep.d_trie.add(e, &d_pt_sep, c_counter);
731
485
      Trace("sygus-unif-sol") << "  ...separation class " << er << std::endl;
732
      // are we in conflict?
733
485
      if (er == e)
734
      {
735
        // new separation class, no conflict
736
500
        continue;
737
      }
738
346
      Assert(hd_mv.find(er) != hd_mv.end());
739
467
      if (hd_mv[er] == hd_mv[e])
740
      {
741
        // merged into separation class with same model value, no conflict
742
        // add to explanation
743
        // this states that it mattered that (er = e) at the time that e was
744
        // added to the trie. Notice that er and e may become separated later,
745
        // but to ensure the overall invariant, this equality must persist in
746
        // the explanation.
747
121
        exp.push_back(er.eqNode(e));
748
121
        Trace("sygus-unif-sol") << "  ...equal model values " << std::endl;
749
242
        Trace("sygus-unif-sol")
750
121
            << "  ...add to explanation " << er.eqNode(e) << std::endl;
751
121
        continue;
752
      }
753
    }
754
    // must include in the explanation that we hit a conflict at this point in
755
    // the construction
756
264
    exp.push_back(e.eqNode(er).negate());
757
    // we are in separation conflict, does the next condition resolve this?
758
    // check whether we have have exhausted our condition pool. If so, we
759
    // are in conflict and this conflict depends on the guard.
760
264
    if (c_counter >= d_conds.size())
761
    {
762
      // truncated separation lemma
763
40
      Assert(!d_guard.isNull());
764
40
      exp.push_back(d_guard);
765
40
      exp_conflict = true;
766
164
      break;
767
    }
768
224
    Assert(c_counter < d_conds.size());
769
263
    Node ce = d_enums[c_counter];
770
263
    Node cv = d_conds[c_counter];
771
224
    Assert(ce.getType() == cv.getType());
772
224
    if (Trace.isOn("sygus-unif-sol"))
773
    {
774
      std::stringstream ss;
775
      TermDbSygus::toStreamSygus(ss, cv);
776
      Trace("sygus-unif-sol")
777
          << "  add condition (" << c_counter << "/" << d_conds.size()
778
          << "): " << ce << " -> " << ss.str() << std::endl;
779
    }
780
    // cache the separation class
781
263
    std::vector<Node> prev_sep_c = d_pt_sep.d_trie.d_rep_to_class[er];
782
    // add new classifier
783
224
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, c_counter);
784
224
    c_counter++;
785
    // add to explanation
786
    // c_exp is a conjunction of testers applied to shared selector chains
787
263
    Node c_exp = d_unif->d_tds->getExplain()->getExplanationForEquality(ce, cv);
788
224
    exp.push_back(c_exp);
789
    std::map<Node, std::vector<Node>>::iterator itr =
790
224
        d_pt_sep.d_trie.d_rep_to_class.find(e);
791
    // since e is last in its separation class, if it becomes a representative,
792
    // then it is separated from all values in prev_sep_c
793
325
    if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
794
    {
795
202
      Trace("sygus-unif-sol")
796
101
          << "  ...resolves separation conflict with all" << std::endl;
797
101
      needs_sep_resolve = false;
798
101
      continue;
799
    }
800
123
    itr = d_pt_sep.d_trie.d_rep_to_class.find(er);
801
    // since er is first in its separation class, it remains a representative
802
123
    Assert(itr != d_pt_sep.d_trie.d_rep_to_class.end());
803
    // is e still in the separation class of er?
804
369
    if (std::find(itr->second.begin(), itr->second.end(), e)
805
369
        != itr->second.end())
806
    {
807
168
      Trace("sygus-unif-sol")
808
84
          << "  ...does not resolve separation conflict with current"
809
84
          << std::endl;
810
      // the condition does not separate e and er
811
      // this violates the invariant that the i^th conditional enumerator
812
      // resolves the i^th separation conflict
813
84
      exp_conflict = true;
814
84
      SygusTypeInfo& ti = d_unif->d_tds->getTypeInfo(ce.getType());
815
      // The reasoning below is only necessary if we use symbolic constructors.
816
84
      if (!ti.hasSubtermSymbolicCons())
817
      {
818
84
        break;
819
      }
820
      // Since the explanation of the condition (c_exp above) does not account
821
      // for builtin subterms, we additionally require that the valuation of
822
      // the condition is indeed different on the two points.
823
      // For example, say ce has model value equal to the SyGuS datatype term:
824
      //   C_leq_xy( 0, 1 )
825
      // where C_leq_xy is a SyGuS datatype constructor taking two integer
826
      // constants c_x and c_y, and whose builtin version is:
827
      //   (0*x + 1*y >= 0)
828
      // Then, c_exp above is:
829
      //   is-C_leq_xy( ce )
830
      // which is added to our explanation of the conflict, which does not
831
      // account for the values of the arguments of C_leq_xy.
832
      // Now, say that we are in a separation conflict due to f(1,2) and f(2,3)
833
      // being assigned different values; the value of ce does not separate
834
      // these two terms since:
835
      //   (y>=0) { x -> 1, y -> 2 } = (y>=0) { x -> 2, y -> 3 } = true
836
      // The code below adds a constraint that states that the above values are
837
      // the same, which is part of the reason for the conflict. In the above
838
      // example, we generate:
839
      //   (DT_SYGUS_EVAL ce 1 2) == (DT_SYGUS_EVAL ce 2 3) { ce -> M(ce) }
840
      // which unfolds via the SygusEvalUnfold utility to:
841
      //   ( (c_x ce)*1 + (c_y ce)*2 >= 0 ) == ( (c_x ce)*2 + (c_y ce)*3 >= 0 )
842
      // where c_x and c_y are the selectors of the subfields of C_leq_xy.
843
      Trace("sygus-unif-sol-sym")
844
          << "Explain symbolic separation conflict" << std::endl;
845
      std::map<Node, std::vector<Node>>::iterator ith;
846
      Node ceApp[2];
847
      SygusEvalUnfold* eunf = d_unif->d_tds->getEvalUnfold();
848
      std::map<Node, Node> vtm;
849
      vtm[ce] = cv;
850
      Trace("sygus-unif-sol-sym")
851
          << "Model value for " << ce << " is " << cv << std::endl;
852
      for (unsigned r = 0; r < 2; r++)
853
      {
854
        std::vector<Node> cechildren;
855
        cechildren.push_back(ce);
856
        Node ecurr = r == 0 ? e : er;
857
        ith = d_unif->d_hd_to_pt.find(ecurr);
858
        AlwaysAssert(ith != d_unif->d_hd_to_pt.end());
859
        cechildren.insert(
860
            cechildren.end(), ith->second.begin(), ith->second.end());
861
        Node cea = nm->mkNode(DT_SYGUS_EVAL, cechildren);
862
        Trace("sygus-unif-sol-sym")
863
            << "Sep conflict app #" << r << " : " << cea << std::endl;
864
        std::vector<Node> tmpExp;
865
        cea = eunf->unfold(cea, vtm, tmpExp, true, true);
866
        Trace("sygus-unif-sol-sym") << "Unfolded to : " << cea << std::endl;
867
        ceApp[r] = cea;
868
      }
869
      Node ceAppEq = ceApp[0].eqNode(ceApp[1]);
870
      Trace("sygus-unif-sol-sym")
871
          << "Sep conflict app explanation is : " << ceAppEq << std::endl;
872
      exp.push_back(ceAppEq);
873
      break;
874
    }
875
78
    Trace("sygus-unif-sol")
876
39
        << "  ...resolves separation conflict with current, but not all"
877
39
        << std::endl;
878
    // find the new term to resolve a separation
879
78
    Node new_er = Node::null();
880
    // scan the previous list and find the representative of the class that e is
881
    // now in
882
78
    for (unsigned i = 0, size = prev_sep_c.size(); i < size; i++)
883
    {
884
117
      Node check_er = prev_sep_c[i];
885
78
      if (check_er != er && check_er != e)
886
      {
887
39
        itr = d_pt_sep.d_trie.d_rep_to_class.find(check_er);
888
39
        if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
889
        {
890
117
          if (std::find(itr->second.begin(), itr->second.end(), e)
891
117
              != itr->second.end())
892
          {
893
39
            new_er = check_er;
894
39
            break;
895
          }
896
        }
897
      }
898
    }
899
    // should find exactly one
900
39
    Assert(!new_er.isNull());
901
39
    er = new_er;
902
39
    needs_sep_resolve = true;
903
  }
904
139
  if (exp_conflict)
905
  {
906
248
    Node lemma = exp.size() == 1 ? exp[0] : nm->mkNode(AND, exp);
907
124
    lemma = lemma.negate();
908
124
    Trace("sygus-unif-sol") << "  ......conflict is " << lemma << std::endl;
909
124
    lemmas.push_back(lemma);
910
124
    return Node::null();
911
  }
912
913
15
  Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
914
15
  return extractSol(cons, hd_mv);
915
}
916
917
15
Node SygusUnifRl::DecisionTreeInfo::extractSol(Node cons,
                                               std::map<Node, Node>& hd_mv)
{
  // Builds the solution term for this decision tree. When the information-gain
  // condition pool is enabled, the separation trie is first rebuilt
  // heuristically; the point separator then reads the term off the trie.
  const bool rebuildWithInfoGain = d_unif->usingConditionPoolInfoGain();
  if (rebuildWithInfoGain)
  {
    recomputeSolHeuristically(hd_mv);
  }
  return d_pt_sep.extractSol(cons, hd_mv);
}
927
928
15
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::extractSol(
    Node cons, std::map<Node, Node>& hd_mv)
{
  // Reads a solution term off the separation trie.
  //  - A leaf yields the model value (from hd_mv) of the head stored lazily
  //    at it.
  //  - An inner node at depth d yields cons(d_conds[d], <true child>,
  //    <false child>), or just its child's term when it has a single child or
  //    when both children produced the same term.
  // The walk is an explicit post-order traversal: on its first visit an inner
  // node is marked with a null entry and pushed back underneath its children,
  // so it is assembled when popped the second time.
  NodeManager* nm = NodeManager::currentNM();
  std::map<IndTriePair, Node> built;
  std::vector<IndTriePair> worklist;
  const IndTriePair root(0, &d_trie.d_trie);
  worklist.push_back(root);
  while (!worklist.empty())
  {
    const IndTriePair cur = worklist.back();
    worklist.pop_back();
    const unsigned depth = cur.first;
    LazyTrie* node = cur.second;
    std::map<IndTriePair, Node>::iterator found = built.find(cur);
    if (found == built.end())
    {
      if (node->d_children.empty())
      {
        // leaf: the term is the model value of the head stored here
        Assert(hd_mv.find(node->d_lazy_child) != hd_mv.end());
        Node leafTerm = hd_mv[node->d_lazy_child];
        built[cur] = leafTerm;
        Trace("sygus-unif-sol-debug")
            << "......leaf, build "
            << d_dt->d_unif->d_tds->sygusToBuiltin(leafTerm,
                                                   leafTerm.getType())
            << "\n";
        continue;
      }
      // first visit of an inner node: mark it pending and revisit it once all
      // of its children have been processed
      built[cur] = Node::null();
      worklist.push_back(cur);
      for (std::pair<const Node, LazyTrie>& child : node->d_children)
      {
        worklist.push_back(IndTriePair(depth + 1, &child.second));
      }
      continue;
    }
    // second visit: the terms of all children are now available in built
    Assert(found->second.isNull());
    Assert(node->d_children.size() == 1 || node->d_children.size() == 2);
    std::vector<Node> args(4);
    args[0] = cons;
    args[1] = d_dt->d_conds[depth];
    unsigned slot = 0;
    for (std::pair<const Node, LazyTrie>& child : node->d_children)
    {
      // the true branch is argument 2, the false branch is argument 3
      slot = child.first.getConst<bool>() ? 2 : 3;
      IndTriePair childKey(depth + 1, &child.second);
      Assert(built.find(childKey) != built.end());
      args[slot] = built[childKey];
      Assert(!args[slot].isNull());
    }
    // a single child means the condition is useless; equal branches make the
    // ITE redundant. Either way, reuse the (last filled) child term.
    if (node->d_children.size() == 1 || args[2] == args[3])
    {
      built[cur] = args[slot];
      Trace("sygus-unif-sol-debug")
          << "......no need for cond "
          << d_dt->d_unif->d_tds->sygusToBuiltin(d_dt->d_conds[depth],
                                                 d_dt->d_conds[depth].getType())
          << ", build "
          << d_dt->d_unif->d_tds->sygusToBuiltin(built[cur],
                                                 built[cur].getType())
          << "\n";
      continue;
    }
    Assert(node->d_children.size() == 2);
    Node iteTerm = nm->mkNode(APPLY_CONSTRUCTOR, args);
    built[cur] = iteTerm;
    Trace("sygus-unif-sol-debug")
        << "......build node "
        << d_dt->d_unif->d_tds->sygusToBuiltin(iteTerm, iteTerm.getType())
        << "\n";
  }
  Assert(built.find(root) != built.end());
  Assert(!built.find(root)->second.isNull());
  return built[root];
}
1008
1009
void SygusUnifRl::DecisionTreeInfo::recomputeSolHeuristically(
1010
    std::map<Node, Node>& hd_mv)
1011
{
1012
  // reset the trie
1013
  d_pt_sep.d_trie.clear();
1014
  // TODO workaround and not really sure this is the last condition, since I put
1015
  // a set here. Maybe make d_cond_mvs into a vector
1016
  Node backup_last_cond = d_conds.back();
1017
  d_conds.clear();
1018
  for (const Node& e : d_hds)
1019
  {
1020
    d_pt_sep.d_trie.add(e, &d_pt_sep, 0);
1021
  }
1022
  // init vector of conds
1023
  std::vector<Node> conds;
1024
  conds.insert(conds.end(), d_cond_mvs.begin(), d_cond_mvs.end());
1025
1026
  // recursively build trie by picking best condition for respective points
1027
  buildDtInfoGain(d_hds, conds, hd_mv, 1);
1028
  // if no condition was added (i.e. points are already classified at root
1029
  // level), use last condition as candidate
1030
  if (d_conds.empty())
1031
  {
1032
    Trace("sygus-unif-dt") << "......using last condition "
1033
                           << d_unif->d_tds->sygusToBuiltin(
1034
                                  backup_last_cond, backup_last_cond.getType())
1035
                           << " as candidate\n";
1036
    d_conds.push_back(backup_last_cond);
1037
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
1038
  }
1039
}
1040
1041
void SygusUnifRl::DecisionTreeInfo::buildDtInfoGain(std::vector<Node>& hds,
1042
                                                    std::vector<Node> conds,
1043
                                                    std::map<Node, Node>& hd_mv,
1044
                                                    int ind)
1045
{
1046
  // test if fully classified
1047
  if (hds.size() < 2)
1048
  {
1049
    indent("sygus-unif-dt", ind);
1050
    Trace("sygus-unif-dt") << "..set fully classified: "
1051
                           << (hds.empty() ? "empty" : "unary") << "\n";
1052
    return;
1053
  }
1054
  Node v1 = hd_mv[hds[0]];
1055
  unsigned i = 1, size = hds.size();
1056
  for (; i < size; ++i)
1057
  {
1058
    if (hd_mv[hds[i]] != v1)
1059
    {
1060
      break;
1061
    }
1062
  }
1063
  if (i == size)
1064
  {
1065
    indent("sygus-unif-dt", ind);
1066
    Trace("sygus-unif-dt") << "..set fully classified: " << hds.size() << " "
1067
                           << (d_unif->d_tds->sygusToBuiltin(v1, v1.getType())
1068
                                       == d_true
1069
                                   ? "good"
1070
                                   : "bad")
1071
                           << " points\n";
1072
    return;
1073
  }
1074
  // pick condition to further classify
1075
  double maxgain = -1;
1076
  unsigned picked_cond = 0;
1077
  std::vector<std::pair<std::vector<Node>, std::vector<Node>>> splits;
1078
  double current_set_entropy = getEntropy(hds, hd_mv, ind);
1079
  for (unsigned j = 0, conds_size = conds.size(); j < conds_size; ++j)
1080
  {
1081
    std::pair<std::vector<Node>, std::vector<Node>> split =
1082
        evaluateCond(hds, conds[j]);
1083
    splits.push_back(split);
1084
    Assert(hds.size() == split.first.size() + split.second.size());
1085
    double gain =
1086
        current_set_entropy
1087
        - (split.first.size() * getEntropy(split.first, hd_mv, ind)
1088
           + split.second.size() * getEntropy(split.second, hd_mv, ind))
1089
              / hds.size();
1090
    indent("sygus-unif-dt-debug", ind);
1091
    Trace("sygus-unif-dt-debug")
1092
        << "..gain of "
1093
        << d_unif->d_tds->sygusToBuiltin(conds[j], conds[j].getType()) << " is "
1094
        << gain << "\n";
1095
    if (gain > maxgain)
1096
    {
1097
      maxgain = gain;
1098
      picked_cond = j;
1099
    }
1100
  }
1101
  // add picked condition
1102
  indent("sygus-unif-dt", ind);
1103
  Trace("sygus-unif-dt") << "..picked condition "
1104
                         << d_unif->d_tds->sygusToBuiltin(
1105
                                conds[picked_cond],
1106
                                conds[picked_cond].getType())
1107
                         << "\n";
1108
  d_conds.push_back(conds[picked_cond]);
1109
  conds.erase(conds.begin() + picked_cond);
1110
  d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
1111
  // recurse
1112
  buildDtInfoGain(splits[picked_cond].first, conds, hd_mv, ind + 1);
1113
  buildDtInfoGain(splits[picked_cond].second, conds, hd_mv, ind + 1);
1114
}
1115
1116
std::pair<std::vector<Node>, std::vector<Node>>
1117
SygusUnifRl::DecisionTreeInfo::evaluateCond(std::vector<Node>& pts, Node cond)
1118
{
1119
  std::vector<Node> good, bad;
1120
  for (const Node& pt : pts)
1121
  {
1122
    if (d_pt_sep.computeCond(cond, pt) == d_true)
1123
    {
1124
      good.push_back(pt);
1125
      continue;
1126
    }
1127
    Assert(d_pt_sep.computeCond(cond, pt) == d_false);
1128
    bad.push_back(pt);
1129
  }
1130
  return std::pair<std::vector<Node>, std::vector<Node>>(good, bad);
1131
}
1132
1133
double SygusUnifRl::DecisionTreeInfo::getEntropy(const std::vector<Node>& hds,
1134
                                                 std::map<Node, Node>& hd_mv,
1135
                                                 int ind)
1136
{
1137
  double p = 0, n = 0;
1138
  TermDbSygus* tds = d_unif->d_tds;
1139
  // get number of points evaluated positively and negatively with feature
1140
  for (const Node& e : hds)
1141
  {
1142
    if (tds->sygusToBuiltin(hd_mv[e]) == d_true)
1143
    {
1144
      p++;
1145
      continue;
1146
    }
1147
    Assert(tds->sygusToBuiltin(hd_mv[e]) == d_false);
1148
    n++;
1149
  }
1150
  // compute entropy
1151
  return p == 0 || n == 0 ? 0
1152
                          : ((-p / (p + n)) * log2(p / (p + n)))
1153
                                - ((n / (p + n)) * log2(n / (p + n)));
1154
}
1155
1156
11
void SygusUnifRl::DecisionTreeInfo::PointSeparator::initialize(
1157
    DecisionTreeInfo* dt)
1158
{
1159
11
  d_dt = dt;
1160
11
}
1161
1162
813
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::evaluate(Node n,
                                                             unsigned index)
{
  // Value of the index-th condition of the owning decision tree on the point
  // of head n; the work (and its memoization) is done by computeCond.
  Assert(index < d_dt->d_conds.size());
  return computeCond(d_dt->d_conds[index], n);
}
1170
1171
813
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::computeCond(Node cond,
                                                                Node hd)
{
  // Evaluates the sygus condition cond on the evaluation point associated with
  // head hd and returns the (constant) result. Results are memoized in
  // d_eval_cond_hd per (cond, hd) pair. If the decision tree has a template,
  // the raw result is substituted for the template variable and rewritten.
  std::pair<Node, Node> key(cond, hd);
  std::map<std::pair<Node, Node>, Node>::iterator memo =
      d_eval_cond_hd.find(key);
  if (memo != d_eval_cond_hd.end())
  {
    return memo->second;
  }
  TermDbSygus* tds = d_dt->d_unif->d_tds;
  TypeNode condType = cond.getType();
  Node builtinCond = tds->sygusToBuiltin(cond, condType);
  // every head is expected to have a registered evaluation point
  Assert(d_dt->d_unif->d_hd_to_pt.find(hd) != d_dt->d_unif->d_hd_to_pt.end());
  std::vector<Node> point = d_dt->d_unif->d_hd_to_pt[hd];
  if (Trace.isOn("sygus-unif-rl-sep"))
  {
    Trace("sygus-unif-rl-sep")
        << "Evaluate cond " << builtinCond << " on pt " << hd << " ( ";
    for (const Node& coord : point)
    {
      Trace("sygus-unif-rl-sep") << coord << " ";
    }
    Trace("sygus-unif-rl-sep") << ")\n";
  }
  Node result = tds->evaluateBuiltin(condType, builtinCond, point);
  Trace("sygus-unif-rl-sep") << "...got res = " << result << "\n";
  // a templated condition is evaluated by plugging the raw result into the
  // template and rewriting
  Node templ = d_dt->d_template.first;
  TNode templ_var = d_dt->d_template.second;
  if (!templ.isNull())
  {
    result = Rewriter::rewrite(templ.substitute(templ_var, result));
    Trace("sygus-unif-rl-sep")
        << "...after template res = " << result << std::endl;
  }
  Assert(result.isConst());
  d_eval_cond_hd[key] = result;
  return result;
}
1213
1214
}  // namespace quantifiers
1215
}  // namespace theory
1216
29577
}  // namespace cvc5