GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/sygus/sygus_unif_rl.cpp Lines: 437 638 68.5 %
Date: 2021-05-22 Branches: 738 2486 29.7 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Haniel Barbosa, Andrew Reynolds, Mathias Preiner
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of sygus_unif_rl.
14
 */
15
16
#include "theory/quantifiers/sygus/sygus_unif_rl.h"
17
18
#include "expr/skolem_manager.h"
19
#include "options/base_options.h"
20
#include "options/quantifiers_options.h"
21
#include "printer/printer.h"
22
#include "theory/quantifiers/sygus/synth_conjecture.h"
23
#include "theory/quantifiers/sygus/term_database_sygus.h"
24
#include "theory/rewriter.h"
25
#include "util/random.h"
26
27
#include <math.h>
28
29
using namespace cvc5::kind;
30
31
namespace cvc5 {
32
namespace theory {
33
namespace quantifiers {
34
35
1529
SygusUnifRl::SygusUnifRl(SynthConjecture* p)
36
1529
    : d_parent(p), d_useCondPool(false), d_useCondPoolIGain(false)
37
{
38
1529
}
39
1529
SygusUnifRl::~SygusUnifRl() {}
40
12
void SygusUnifRl::initializeCandidate(
41
    TermDbSygus* tds,
42
    Node f,
43
    std::vector<Node>& enums,
44
    std::map<Node, std::vector<Node>>& strategy_lemmas)
45
{
46
  // initialize
47
24
  std::vector<Node> all_enums;
48
12
  SygusUnif::initializeCandidate(tds, f, all_enums, strategy_lemmas);
49
  // based on the strategy inferred for each function, determine if we are
50
  // using a unification strategy that is compatible our approach.
51
24
  StrategyRestrictions restrictions;
52
25
  if (options::sygusBoolIteReturnConst())
53
  {
54
12
    restrictions.d_iteReturnBoolConst = true;
55
  }
56
  // register the strategy
57
12
  registerStrategy(f, enums, restrictions.d_unused_strategies);
58
12
  d_strategy[f].staticLearnRedundantOps(strategy_lemmas, restrictions);
59
  // Copy candidates and check whether CegisUnif for any of them
60
12
  if (d_unif_candidates.find(f) != d_unif_candidates.end())
61
  {
62
9
    d_hd_to_pt[f].clear();
63
9
    d_cand_to_eval_hds[f].clear();
64
9
    d_cand_to_hd_count[f] = 0;
65
  }
66
  // check whether we are using condition enumeration
67
12
  options::SygusUnifPiMode mode = options::sygusUnifPi();
68
12
  d_useCondPool = mode == options::SygusUnifPiMode::CENUM
69
12
                  || mode == options::SygusUnifPiMode::CENUM_IGAIN;
70
12
  d_useCondPoolIGain = mode == options::SygusUnifPiMode::CENUM_IGAIN;
71
12
}
72
73
/**
 * Enumeration notifications are not used by this utility.
 *
 * Reaching this method is flagged by an assertion.
 */
void SygusUnifRl::notifyEnumeration(Node e,
                                    Node v,
                                    std::vector<Node>& lemmas)
{
  Assert(false);
}
78
79
948
/**
 * Purify (a subterm of) a refinement lemma.
 *
 * Each distinct application of a unification function-to-synthesize is
 * replaced by an application of a fresh skolem "evaluation head". The
 * replacement is remembered in d_app_to_purified. The head is recorded in
 * d_cand_to_eval_hds, and its argument tuple in d_hd_to_pt.
 *
 * @param n the term to purify
 * @param ensureConst whether n must be reduced to a constant. Arguments of a
 * unification function application are purified with this flag set.
 * Arguments of a non-unification function application are purified without
 * it. Otherwise the flag is inherited by the children.
 * @param model_guards receives the guard (not (= v p)) whenever an
 * application p of a function-to-synthesize is replaced by a value v because
 * ensureConst is set. v is the rewritten application of the already-built
 * solution if one exists, or the model value of the application otherwise.
 * @param cache maps (ensureConst, term) to its previously computed result
 * @return the purified and rewritten version of n
 */
Node SygusUnifRl::purifyLemma(Node n,
                              bool ensureConst,
                              std::vector<Node>& model_guards,
                              BoolNodePairMap& cache)
{
  Trace("sygus-unif-rl-purify") << "PurifyLemma : " << n << "\n";
  BoolNodePairMap::const_iterator it0 =
      cache.find(BoolNodePair(ensureConst, n));
  if (it0 != cache.end())
  {
    Trace("sygus-unif-rl-purify-debug") << "... already visited " << n << "\n";
    return it0->second;
  }
  // Recurse
  unsigned size = n.getNumChildren();
  Kind k = n.getKind();
  // We retrieve the model value now because the purified node may not have a
  // value
  Node nv = n;
  // Whether application of a function-to-synthesize
  bool fapp = (k == DT_SYGUS_EVAL);
  // Whether application of a unification (u_fapp) / non-unification (nu_fapp)
  // function-to-synthesize
  bool u_fapp = false;
  bool nu_fapp = false;
  if (fapp)
  {
    Assert(std::find(d_candidates.begin(), d_candidates.end(), n[0])
           != d_candidates.end());
    // look up the unification status of the candidate only once
    u_fapp = usingUnif(n[0]);
    nu_fapp = !u_fapp;
    // get model value of non-top level applications of functions-to-synthesize
    // occurring under a unification function-to-synthesize
    if (ensureConst)
    {
      std::map<Node, Node>::iterator it1 = d_cand_to_sol.find(n[0]);
      // if function-to-synthesize, retrieve its built solution to replace in
      // the application before computing the model value
      AlwaysAssert(!u_fapp || it1 != d_cand_to_sol.end());
      if (it1 != d_cand_to_sol.end())
      {
        TNode cand = n[0];
        Node tmp = n.substitute(cand, it1->second);
        // should be concrete, can just use the rewriter
        nv = Rewriter::rewrite(tmp);
        Trace("sygus-unif-rl-purify")
            << "PurifyLemma : model value for " << tmp << " is " << nv << "\n";
      }
      else
      {
        nv = d_parent->getModelValue(n);
        Trace("sygus-unif-rl-purify")
            << "PurifyLemma : model value for " << n << " is " << nv << "\n";
      }
      Assert(n != nv);
    }
  }
  // Traverse to purify
  bool childChanged = false;
  std::vector<Node> children;
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();
  for (unsigned i = 0; i < size; ++i)
  {
    if (i == 0 && fapp)
    {
      // the function-to-synthesize itself is kept as is
      children.push_back(n[i]);
      continue;
    }
    // Arguments of non-unif functions do not need to be constant
    Node child = purifyLemma(
        n[i], !nu_fapp && (ensureConst || u_fapp), model_guards, cache);
    children.push_back(child);
    childChanged = childChanged || child != n[i];
  }
  Node nb;
  if (childChanged)
  {
    if (n.getMetaKind() == metakind::PARAMETERIZED)
    {
      Trace("sygus-unif-rl-purify-debug")
          << "Node " << n << " is parameterized\n";
      children.insert(children.begin(), n.getOperator());
    }
    if (Trace.isOn("sygus-unif-rl-purify-debug"))
    {
      Trace("sygus-unif-rl-purify-debug")
          << "...rebuilding " << n << " with kind " << k << " and children:\n";
      for (const Node& child : children)
      {
        Trace("sygus-unif-rl-purify-debug") << "...... " << child << "\n";
      }
    }
    nb = nm->mkNode(k, children);
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : transformed " << n << " into " << nb << "\n";
  }
  else
  {
    nb = n;
  }
  // Map to point enumerator every unification function-to-synthesize
  if (u_fapp)
  {
    Node np;
    std::map<Node, Node>::const_iterator it2 = d_app_to_purified.find(nb);
    if (it2 == d_app_to_purified.end())
    {
      // Build purified head with fresh skolem and recreate node
      std::stringstream ss;
      ss << nb[0] << "_" << d_cand_to_hd_count[nb[0]]++;
      Node new_f = sm->mkDummySkolem(ss.str(),
                                     nb[0].getType(),
                                     "head of unif evaluation point",
                                     NodeManager::SKOLEM_EXACT_NAME);
      // Adds new enumerator to map from candidate
      Trace("sygus-unif-rl-purify")
          << "...new enum " << new_f << " for candidate " << nb[0] << "\n";
      d_cand_to_eval_hds[nb[0]].push_back(new_f);
      // Maps new enumerator to its respective tuple of arguments
      d_hd_to_pt[new_f] =
          std::vector<Node>(children.begin() + 1, children.end());
      if (Trace.isOn("sygus-unif-rl-purify-debug"))
      {
        Trace("sygus-unif-rl-purify-debug") << "...[" << new_f << "] --> ( ";
        for (const Node& pt_i : d_hd_to_pt[new_f])
        {
          Trace("sygus-unif-rl-purify-debug") << pt_i << " ";
        }
        Trace("sygus-unif-rl-purify-debug") << ")\n";
      }
      // replace first child and rebuild node
      Assert(children.size() > 0);
      children[0] = new_f;
      Trace("sygus-unif-rl-purify-debug")
          << "Make sygus eval app " << children << std::endl;
      np = nm->mkNode(DT_SYGUS_EVAL, children);
      d_app_to_purified[nb] = np;
    }
    else
    {
      np = it2->second;
    }
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : purified head and transformed " << nb << " into "
        << np << "\n";
    nb = np;
  }
  // Add equality between purified fapp and model value
  if (ensureConst && fapp)
  {
    model_guards.push_back(nm->mkNode(EQUAL, nv, nb).negate());
    nb = nv;
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : adding model eq " << model_guards.back() << "\n";
  }
  nb = Rewriter::rewrite(nb);
  // every non-top level application of function-to-synthesize must be reduced
  // to a concrete constant
  Assert(!ensureConst || nb.isConst());
  Trace("sygus-unif-rl-purify-debug")
      << "... caching [" << n << "] = " << nb << "\n";
  cache[BoolNodePair(ensureConst, n)] = nb;
  return nb;
}
243
244
11
/**
 * Register a refinement lemma with the unification utility.
 *
 * The lemma is purified (see purifyLemma). If purification produced model
 * guards, the result is the disjunction of those guards and the purified
 * lemma.
 *
 * Every evaluation head created during this call is appended to eval_hds
 * under its candidate. It is also added to the points (d_hds) of every
 * decision tree reachable from that candidate's condition enumerators.
 *
 * @return the rewritten purified lemma
 */
Node SygusUnifRl::addRefLemma(Node lemma,
                              std::map<Node, std::vector<Node>>& eval_hds)
{
  Trace("sygus-unif-rl-purify")
      << "Registering lemma at SygusUnif : " << lemma << "\n";
  // remember how many heads each candidate had before purifying, so that the
  // heads created below can be told apart from the old ones
  std::map<Node, unsigned> prev_n_eval_hds;
  for (const auto& candHeads : d_cand_to_eval_hds)
  {
    prev_n_eval_hds[candHeads.first] = candHeads.second.size();
  }

  // Make the purified lemma which will guide the unification utility.
  std::vector<Node> model_guards;
  BoolNodePairMap cache;
  Node plem = purifyLemma(lemma, false, model_guards, cache);
  if (!model_guards.empty())
  {
    model_guards.push_back(plem);
    plem = NodeManager::currentNM()->mkNode(OR, model_guards);
  }
  plem = Rewriter::rewrite(plem);
  Trace("sygus-unif-rl-purify") << "Purified lemma : " << plem << "\n";

  Trace("sygus-unif-rl-purify") << "Collect new evaluation points...\n";
  for (const auto& candHeads : d_cand_to_eval_hds)
  {
    Node c = candHeads.first;
    const std::vector<Node>& heads = candHeads.second;
    std::map<Node, unsigned>::const_iterator itPrev = prev_n_eval_hds.find(c);
    const size_t firstNew =
        (itPrev == prev_n_eval_hds.end()) ? 0 : itPrev->second;
    for (size_t j = firstNew, nheads = heads.size(); j < nheads; ++j)
    {
      Node hd = heads[j];
      eval_hds[c].push_back(hd);
      // Add new point to respective decision trees
      Assert(d_cand_cenums.find(c) != d_cand_cenums.end());
      for (const Node& cenum : d_cand_cenums[c])
      {
        Assert(d_cenum_to_stratpt.find(cenum) != d_cenum_to_stratpt.end());
        for (const Node& stratpt : d_cenum_to_stratpt[cenum])
        {
          Assert(d_stratpt_to_dt.find(stratpt) != d_stratpt_to_dt.end());
          Trace("sygus-unif-rl-dt")
              << "Register point with head " << hd << " to strategy point "
              << stratpt << "\n";
          // Register new point from new head
          d_stratpt_to_dt[stratpt].d_hds.push_back(hd);
        }
      }
    }
  }
  return plem;
}
301
302
141
void SygusUnifRl::initializeConstructSol() {}
303
141
void SygusUnifRl::initializeConstructSolFor(Node f) {}
304
141
bool SygusUnifRl::constructSolution(std::vector<Node>& sols,
305
                                    std::vector<Node>& lemmas)
306
{
307
141
  initializeConstructSol();
308
141
  bool successful = true;
309
288
  for (const Node& c : d_candidates)
310
  {
311
147
    if (!usingUnif(c))
312
    {
313
12
      Node v = d_parent->getModelValue(c);
314
6
      sols.push_back(v);
315
6
      continue;
316
    }
317
141
    initializeConstructSolFor(c);
318
    Node v = constructSol(
319
159
        c, d_strategy[c].getRootEnumerator(), role_equal, 0, lemmas);
320
264
    if (v.isNull())
321
    {
322
      // we continue trying to build solutions to accumulate potentitial
323
      // separation conditions from other decision trees
324
123
      successful = false;
325
123
      continue;
326
    }
327
18
    sols.push_back(v);
328
18
    d_cand_to_sol[c] = v;
329
  }
330
141
  return successful;
331
}
332
333
141
/**
 * Construct a solution for strategy point e of candidate f.
 *
 * Only the role_equal role of a strategy point that has a registered decision
 * tree (d_stratpt_to_dt) is supported; in any other case the null node is
 * returned. If f has no evaluation points yet, the current model value of e is
 * returned. Otherwise the decision tree of e builds the solution, which may
 * add lemmas to lemmas.
 *
 * @param ind indentation level used for tracing
 */
Node SygusUnifRl::constructSol(
    Node f, Node e, NodeRole nrole, int ind, std::vector<Node>& lemmas)
{
  indent("sygus-unif-sol", ind);
  Trace("sygus-unif-sol") << "ConstructSol: SygusRL : " << e << std::endl;
  // retrieve strategy information
  TypeNode etn = e.getType();
  EnumTypeInfo& tinfo = d_strategy[f].getEnumTypeInfo(etn);
  StrategyNode& snode = tinfo.getStrategyNode(nrole);
  if (nrole != role_equal)
  {
    return Node::null();
  }
  // is there a decision tree strategy?
  std::map<Node, DecisionTreeInfo>::iterator itd = d_stratpt_to_dt.find(e);
  // for now only considering simple case of sole "ITE(cond, e, e)" strategy
  if (itd == d_stratpt_to_dt.end())
  {
    return Node::null();
  }
  indent("sygus-unif-sol", ind);
  Trace("sygus-unif-sol") << "...it has a decision tree strategy.\n";
  // whether empty set of points
  if (d_cand_to_eval_hds[f].empty())
  {
    // query the model value once; it is both traced and returned
    Node rootValue = d_parent->getModelValue(e);
    Trace("sygus-unif-sol") << "...... no points, return root enum value "
                            << rootValue << "\n";
    return rootValue;
  }
  EnumTypeInfoStrat* etis = snode.d_strats[itd->second.getStrategyIndex()];
  Node sol = itd->second.buildSol(etis->d_cons, lemmas);
  Assert(d_useCondPool || !sol.isNull() || !lemmas.empty());
  return sol;
}
367
368
221
/** Whether candidate f was registered as a unification candidate. */
bool SygusUnifRl::usingUnif(Node f) const
{
  return d_unif_candidates.count(f) != 0;
}
372
373
9
/**
 * Return the condition enumerator of the decision tree of strategy point e.
 *
 * e must have a registered decision tree (asserted).
 */
Node SygusUnifRl::getConditionForEvaluationPoint(Node e) const
{
  const auto itDt = d_stratpt_to_dt.find(e);
  Assert(itDt != d_stratpt_to_dt.end());
  const DecisionTreeInfo& dt = itDt->second;
  return dt.getConditionEnumerator();
}
379
380
141
void SygusUnifRl::setConditions(Node e,
381
                                Node guard,
382
                                const std::vector<Node>& enums,
383
                                const std::vector<Node>& conds)
384
{
385
141
  std::map<Node, DecisionTreeInfo>::iterator it = d_stratpt_to_dt.find(e);
386
141
  Assert(it != d_stratpt_to_dt.end());
387
  // set the conditions for the appropriate tree
388
141
  it->second.setConditions(guard, enums, conds);
389
141
}
390
391
325
/**
 * Return (a copy of) the evaluation heads registered for candidate c.
 *
 * The result is empty if c has none.
 */
std::vector<Node> SygusUnifRl::getEvalPointHeads(Node c)
{
  const auto itHeads = d_cand_to_eval_hds.find(c);
  if (itHeads != d_cand_to_eval_hds.end())
  {
    return itHeads->second;
  }
  return std::vector<Node>();
}
400
401
906
bool SygusUnifRl::usingConditionPool() const { return d_useCondPool; }
402
11
bool SygusUnifRl::usingConditionPoolInfoGain() const
403
{
404
11
  return d_useCondPoolIGain;
405
}
406
12
void SygusUnifRl::registerStrategy(
407
    Node f,
408
    std::vector<Node>& enums,
409
    std::map<Node, std::unordered_set<unsigned>>& unused_strats)
410
{
411
12
  if (Trace.isOn("sygus-unif-rl-strat"))
412
  {
413
    Trace("sygus-unif-rl-strat")
414
        << "Strategy for " << f << " is : " << std::endl;
415
    d_strategy[f].debugPrint("sygus-unif-rl-strat");
416
  }
417
12
  Trace("sygus-unif-rl-strat") << "Register..." << std::endl;
418
24
  Node e = d_strategy[f].getRootEnumerator();
419
24
  std::map<Node, std::map<NodeRole, bool>> visited;
420
12
  registerStrategyNode(f, e, role_equal, visited, enums, unused_strats);
421
12
}
422
423
12
void SygusUnifRl::registerStrategyNode(
424
    Node f,
425
    Node e,
426
    NodeRole nrole,
427
    std::map<Node, std::map<NodeRole, bool>>& visited,
428
    std::vector<Node>& enums,
429
    std::map<Node, std::unordered_set<unsigned>>& unused_strats)
430
{
431
12
  Trace("sygus-unif-rl-strat") << "  register node " << e << std::endl;
432
12
  if (visited[e].find(nrole) != visited[e].end())
433
  {
434
    return;
435
  }
436
12
  visited[e][nrole] = true;
437
24
  TypeNode etn = e.getType();
438
12
  EnumTypeInfo& tinfo = d_strategy[f].getEnumTypeInfo(etn);
439
12
  StrategyNode& snode = tinfo.getStrategyNode(nrole);
440
23
  for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
441
  {
442
11
    EnumTypeInfoStrat* etis = snode.d_strats[j];
443
11
    StrategyType strat = etis->d_this;
444
    // is this a simple recursive ITE strategy?
445
11
    bool success = false;
446
11
    if (strat == strat_ITE && nrole == role_equal)
447
    {
448
9
      success = true;
449
27
      for (unsigned c = 1; c <= 2; c++)
450
      {
451
36
        std::pair<Node, NodeRole> child = etis->d_cenum[c];
452
18
        if (child.first != e || child.second != nrole)
453
        {
454
          success = false;
455
          break;
456
        }
457
      }
458
9
      if (success)
459
      {
460
18
        Node cond = etis->d_cenum[0].first;
461
9
        Assert(etis->d_cenum[0].second == role_ite_condition);
462
18
        Trace("sygus-unif-rl-strat")
463
9
            << "  ...detected recursive ITE strategy, condition enumerator : "
464
9
            << cond << std::endl;
465
        // indicate that we will be enumerating values for cond
466
9
        registerConditionalEnumerator(f, e, cond, j);
467
        // we will be using a strategy for e
468
9
        enums.push_back(e);
469
      }
470
    }
471
11
    if (!success)
472
    {
473
2
      unused_strats[e].insert(j);
474
    }
475
    // TODO: recurse? for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
476
  }
477
}
478
479
9
void SygusUnifRl::registerConditionalEnumerator(Node f,
480
                                                Node e,
481
                                                Node cond,
482
                                                unsigned strategy_index)
483
{
484
  // only allow one decision tree per strategy point
485
9
  if (d_stratpt_to_dt.find(e) != d_stratpt_to_dt.end())
486
  {
487
    return;
488
  }
489
  // we will do unification for this candidate
490
9
  d_unif_candidates.insert(f);
491
  // add to the list of all conditional enumerators
492
27
  if (std::find(d_cond_enums.begin(), d_cond_enums.end(), cond)
493
27
      == d_cond_enums.end())
494
  {
495
9
    d_cond_enums.push_back(cond);
496
9
    d_cand_cenums[f].push_back(cond);
497
9
    d_cenum_to_stratpt[cond].clear();
498
  }
499
  // register that this strategy node has a decision tree construction
500
9
  d_stratpt_to_dt[e].initialize(cond, this, &d_strategy[f], strategy_index);
501
  // associate conditional enumerator with strategy node
502
9
  d_cenum_to_stratpt[cond].push_back(e);
503
}
504
505
9
void SygusUnifRl::DecisionTreeInfo::initialize(Node cond_enum,
506
                                               SygusUnifRl* unif,
507
                                               SygusUnifStrategy* strategy,
508
                                               unsigned strategy_index)
509
{
510
9
  d_cond_enum = cond_enum;
511
9
  d_unif = unif;
512
9
  d_strategy = strategy;
513
9
  d_strategy_index = strategy_index;
514
9
  d_true = NodeManager::currentNM()->mkConst(true);
515
9
  d_false = NodeManager::currentNM()->mkConst(false);
516
  // Retrieve template
517
9
  EnumInfo& eiv = d_strategy->getEnumInfo(d_cond_enum);
518
9
  d_template = NodePair(eiv.d_template, eiv.d_template_arg);
519
  // Initialize classifier
520
9
  d_pt_sep.initialize(this);
521
9
}
522
523
141
void SygusUnifRl::DecisionTreeInfo::setConditions(
524
    Node guard, const std::vector<Node>& enums, const std::vector<Node>& conds)
525
{
526
141
  Assert(enums.size() == conds.size());
527
  // set the guard
528
141
  d_guard = guard;
529
  // clear old condition values
530
141
  d_enums.clear();
531
141
  d_conds.clear();
532
  // set new condition values
533
141
  d_enums.insert(d_enums.end(), enums.begin(), enums.end());
534
141
  d_conds.insert(d_conds.end(), conds.begin(), conds.end());
535
  // add to condition pool
536
141
  if (d_unif->usingConditionPool())
537
  {
538
    d_cond_mvs.insert(conds.begin(), conds.end());
539
    if (Trace.isOn("sygus-unif-cond-pool"))
540
    {
541
      for (const Node& condv : conds)
542
      {
543
        if (d_cond_mvs.find(condv) == d_cond_mvs.end())
544
        {
545
          Trace("sygus-unif-cond-pool")
546
              << "  ...adding to condition pool : "
547
              << d_unif->d_tds->sygusToBuiltin(condv, condv.getType()) << "\n";
548
        }
549
      }
550
    }
551
  }
552
141
}
553
554
134
unsigned SygusUnifRl::DecisionTreeInfo::getStrategyIndex() const
555
{
556
134
  return d_strategy_index;
557
}
558
559
134
/**
 * Build a solution from this decision tree, using constructor cons for the
 * ITE nodes.
 *
 * Templated condition enumerators are unsupported; in that case the null node
 * is returned. Otherwise the trie is reset and construction is delegated to:
 *  - buildSolAllCond when the condition pool is in use;
 *  - buildSolMinCond otherwise.
 */
Node SygusUnifRl::DecisionTreeInfo::buildSol(Node cons,
                                             std::vector<Node>& lemmas)
{
  const bool hasTemplate = !d_template.first.isNull();
  if (hasTemplate)
  {
    Trace("sygus-unif-sol") << "...templated conditions unsupported\n";
    return Node::null();
  }
  Trace("sygus-unif-sol") << "Decision::buildSol with " << d_hds.size()
                          << " evaluation heads and " << d_conds.size()
                          << " conditions..." << std::endl;
  // start from an empty trie
  d_pt_sep.d_trie.clear();
  if (d_unif->usingConditionPool())
  {
    return buildSolAllCond(cons, lemmas);
  }
  return buildSolMinCond(cons, lemmas);
}
575
576
/**
 * Build a solution using every condition value in the pool d_cond_mvs.
 *
 * Steps:
 *  1. The pool becomes the condition list. It is optionally shuffled first,
 *     when options::sygusUnifShuffleCond() is set.
 *  2. Each condition is installed as a classifier of the trie.
 *  3. Every evaluation head is added to the trie. If a head lands in a
 *     separation class whose representative has a different model value, the
 *     heads cannot be separated and the null node is returned.
 *  4. Otherwise the solution is extracted from the tree.
 *
 * A solution that was already built before is rejected (null) when
 * options::sygusUnifCondIndNoRepeatSol() is set. This variant never adds to
 * lemmas.
 */
Node SygusUnifRl::DecisionTreeInfo::buildSolAllCond(Node cons,
                                                    std::vector<Node>& lemmas)
{
  // the conditions of this tree are exactly the values in the pool
  d_conds.clear();
  d_conds.insert(d_conds.end(), d_cond_mvs.begin(), d_cond_mvs.end());
  // Shuffling does not affect whether a solution can be built. It can affect
  // the size of the resulting solution and which conditions appear in the
  // tree. That in turn affects how the solution behaves on inputs not covered
  // by the current data points.
  if (options::sygusUnifShuffleCond())
  {
    std::shuffle(d_conds.begin(), d_conds.end(), Random::getRandom());
  }
  const unsigned num_conds = d_conds.size();
  for (unsigned ci = 0; ci < num_conds; ++ci)
  {
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, ci);
  }
  // model values for evaluation heads
  std::map<Node, Node> hd_mv;
  for (const Node& hd : d_hds)
  {
    hd_mv[hd] = d_unif->d_parent->getModelValue(hd);
    Node rep = d_pt_sep.d_trie.add(hd, &d_pt_sep, num_conds);
    // no conflict if hd starts a new separation class
    if (rep != hd)
    {
      Assert(hd_mv.find(rep) != hd_mv.end());
      // no conflict either if merged into a class with the same model value
      if (hd_mv[hd] != hd_mv[rep])
      {
        Trace("sygus-unif-sol")
            << "  ...can't separate " << hd << " from " << rep << std::endl;
        return Node::null();
      }
    }
  }
  Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
  Node sol = extractSol(cons, hd_mv);
  // insertion fails exactly when sol was built before
  const bool isFreshSol = d_sols.insert(sol).second;
  if (!isFreshSol && options::sygusUnifCondIndNoRepeatSol())
  {
    return Node::null();
  }
  return sol;
}
634
635
134
Node SygusUnifRl::DecisionTreeInfo::buildSolMinCond(Node cons,
636
                                                    std::vector<Node>& lemmas)
637
{
638
134
  NodeManager* nm = NodeManager::currentNM();
639
  // model values for evaluation heads
640
268
  std::map<Node, Node> hd_mv;
641
  // the current explanation of why there has not yet been a separation conflict
642
268
  std::vector<Node> exp;
643
  // is the above explanation ready to be sent out as a lemma?
644
134
  bool exp_conflict = false;
645
  // the index of the head we are considering
646
134
  unsigned hd_counter = 0;
647
  // the index of the condition we are considering
648
134
  unsigned c_counter = 0;
649
  // do we need to resolve a separation conflict?
650
134
  bool needs_sep_resolve = false;
651
  // This loop simultaneously builds the solution in terms of a lazy trie
652
  // (LazyTrieMulti), and checks whether a separation conflict exists. We
653
  // enforce that the separation conflicts we encounter while building
654
  // this solution are resolved, in order, by the condition enumerators.
655
  // If not, then we add a (conflict) lemma stating that the current model
656
  // value of the condition enumerator must be different. We also call this
657
  // a "separation lemma".
658
  //
659
  // As a simple example, say we have:
660
  //   evalution heads: (eval e1 0 0), (eval e2 1 2)
661
  //   conditions: c1
662
  // where M(e1) = x, M(e2) = y, and M(c1) = x>1. After adding e1 and e2, we are
663
  // in conflict since { e1, e2 } form a separation class, M(e1)!=M(e2), and
664
  // M(c1) does not separate e1 and e2 since:
665
  //   (x>1){x->0,y->0} = (x>1){x->1,y->2} = false
666
  // Hence, we would fail to build a solution in this case, and instead send a
667
  // separation lemma of the form:
668
  //   ~( e1 != e2 ^ c1 = [x<1] )
669
  //
670
  // Say we have:
671
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
672
  //   conditions: c1 c2
673
  // where M(e1) = x, M(e2) = y, M(e3) = x+1, M(c1) = x>0 and M(c2) = x<0.
674
  // After adding e1 and e2, { e1, e2 } form a separation class, M(e1)!=M(e2),
675
  // but M(c1) separates e1 and e2 since
676
  //   (x>0){x->0,y->0} = false, and
677
  //   (x>1){x->1,y->2} = true
678
  // Hence, we get new separation classes { e1 } and { e2 }, and afterwards
679
  // add e3. We then get { e2, e3 } as a separation class, which is also a
680
  // conflict since M(e2)!=M(e3). We check if M(c2) resolves this conflict.
681
  // It does not, since (x<1){x->0,y->0} = (x<1){x->1,y->2} = false. Hence,
682
  // we get a separation lemma:
683
  //  ~( c1 = [x>1] ^ e2 != e3 ^ c2 = [x<1] )
684
  //
685
  // Say we have:
686
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
687
  //   conditions: c1
688
  // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = x>0.
689
  // After adding e1 and e2, we have separation class { e1, e2 }. This is not a
690
  // conflict since M(e1)=M(e2). We then add e3, obtaining separation class
691
  // { e1, e2, e3 }, which is in conflict since M(e3)!=M(e1), and the condition
692
  // c1 does not separate e3 and the representative of this class, e1. Hence we
693
  // get a separation lemma of the form:
694
  //  ~( e1 = e2 ^ e1 != e3 ^ c1 = [x>0] )
695
  //
696
  // It also may be the case that we exhaust the pool of condition enumerators.
697
  // Say we have:
698
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
699
  //   conditions: c1
700
  // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = y>0. After adding e1, e2,
701
  // and e3, we have a separation class { e1, e2, e3 } that is in conflict
702
  // since M(e3)!=M(e1). We add the condition c1, which separates into new
703
  // equivalence classes { e1 }, { e2, e3 }. We are still in separation conflict
704
  // since M(e3)!=M(e2). However, we do not have any further conditions to use
705
  // to resolve this conflict. Thus, we add the separation lemma:
706
  //  ~( e1 = e2 ^ e1 != e3 ^ e2 != e3 ^ c1 = [y>0] ^ G_1 )
707
  // where G_1 is a guard stating that we use at most 1 condition.
708
268
  Node e;
709
268
  Node er;
710
804
  while (hd_counter < d_hds.size() || needs_sep_resolve)
711
  {
712
458
    if (!needs_sep_resolve)
713
    {
714
      // add the head to the trie
715
444
      e = d_hds[hd_counter];
716
444
      hd_mv[e] = d_unif->d_parent->getModelValue(e);
717
444
      if (Trace.isOn("sygus-unif-sol"))
718
      {
719
        std::stringstream ss;
720
        TermDbSygus::toStreamSygus(ss, hd_mv[e]);
721
        Trace("sygus-unif-sol")
722
            << "  add evaluation head (" << hd_counter << "/" << d_hds.size()
723
            << "): " << e << " -> " << ss.str() << std::endl;
724
      }
725
444
      hd_counter++;
726
      // get the representative of the trie
727
444
      er = d_pt_sep.d_trie.add(e, &d_pt_sep, c_counter);
728
444
      Trace("sygus-unif-sol") << "  ...separation class " << er << std::endl;
729
      // are we in conflict?
730
444
      if (er == e)
731
      {
732
        // new separation class, no conflict
733
455
        continue;
734
      }
735
310
      Assert(hd_mv.find(er) != hd_mv.end());
736
471
      if (hd_mv[er] == hd_mv[e])
737
      {
738
        // merged into separation class with same model value, no conflict
739
        // add to explanation
740
        // this states that it mattered that (er = e) at the time that e was
741
        // added to the trie. Notice that er and e may become separated later,
742
        // but to ensure the overall invariant, this equality must persist in
743
        // the explanation.
744
161
        exp.push_back(er.eqNode(e));
745
161
        Trace("sygus-unif-sol") << "  ...equal model values " << std::endl;
746
322
        Trace("sygus-unif-sol")
747
161
            << "  ...add to explanation " << er.eqNode(e) << std::endl;
748
161
        continue;
749
      }
750
    }
751
    // must include in the explanation that we hit a conflict at this point in
752
    // the construction
753
163
    exp.push_back(e.eqNode(er).negate());
754
    // we are in separation conflict, does the next condition resolve this?
755
    // check whether we have have exhausted our condition pool. If so, we
756
    // are in conflict and this conflict depends on the guard.
757
163
    if (c_counter >= d_conds.size())
758
    {
759
      // truncated separation lemma
760
31
      Assert(!d_guard.isNull());
761
31
      exp.push_back(d_guard);
762
31
      exp_conflict = true;
763
154
      break;
764
    }
765
132
    Assert(c_counter < d_conds.size());
766
146
    Node ce = d_enums[c_counter];
767
146
    Node cv = d_conds[c_counter];
768
132
    Assert(ce.getType() == cv.getType());
769
132
    if (Trace.isOn("sygus-unif-sol"))
770
    {
771
      std::stringstream ss;
772
      TermDbSygus::toStreamSygus(ss, cv);
773
      Trace("sygus-unif-sol")
774
          << "  add condition (" << c_counter << "/" << d_conds.size()
775
          << "): " << ce << " -> " << ss.str() << std::endl;
776
    }
777
    // cache the separation class
778
146
    std::vector<Node> prev_sep_c = d_pt_sep.d_trie.d_rep_to_class[er];
779
    // add new classifier
780
132
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, c_counter);
781
132
    c_counter++;
782
    // add to explanation
783
    // c_exp is a conjunction of testers applied to shared selector chains
784
146
    Node c_exp = d_unif->d_tds->getExplain()->getExplanationForEquality(ce, cv);
785
132
    exp.push_back(c_exp);
786
    std::map<Node, std::vector<Node>>::iterator itr =
787
132
        d_pt_sep.d_trie.d_rep_to_class.find(e);
788
    // since e is last in its separation class, if it becomes a representative,
789
    // then it is separated from all values in prev_sep_c
790
158
    if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
791
    {
792
52
      Trace("sygus-unif-sol")
793
26
          << "  ...resolves separation conflict with all" << std::endl;
794
26
      needs_sep_resolve = false;
795
26
      continue;
796
    }
797
106
    itr = d_pt_sep.d_trie.d_rep_to_class.find(er);
798
    // since er is first in its separation class, it remains a representative
799
106
    Assert(itr != d_pt_sep.d_trie.d_rep_to_class.end());
800
    // is e still in the separation class of er?
801
318
    if (std::find(itr->second.begin(), itr->second.end(), e)
802
318
        != itr->second.end())
803
    {
804
184
      Trace("sygus-unif-sol")
805
92
          << "  ...does not resolve separation conflict with current"
806
92
          << std::endl;
807
      // the condition does not separate e and er
808
      // this violates the invariant that the i^th conditional enumerator
809
      // resolves the i^th separation conflict
810
92
      exp_conflict = true;
811
92
      SygusTypeInfo& ti = d_unif->d_tds->getTypeInfo(ce.getType());
812
      // The reasoning below is only necessary if we use symbolic constructors.
813
92
      if (!ti.hasSubtermSymbolicCons())
814
      {
815
92
        break;
816
      }
817
      // Since the explanation of the condition (c_exp above) does not account
818
      // for builtin subterms, we additionally require that the valuation of
819
      // the condition is indeed different on the two points.
820
      // For example, say ce has model value equal to the SyGuS datatype term:
821
      //   C_leq_xy( 0, 1 )
822
      // where C_leq_xy is a SyGuS datatype constructor taking two integer
823
      // constants c_x and c_y, and whose builtin version is:
824
      //   (0*x + 1*y >= 0)
825
      // Then, c_exp above is:
826
      //   is-C_leq_xy( ce )
827
      // which is added to our explanation of the conflict, which does not
828
      // account for the values of the arguments of C_leq_xy.
829
      // Now, say that we are in a separation conflict due to f(1,2) and f(2,3)
830
      // being assigned different values; the value of ce does not separate
831
      // these two terms since:
832
      //   (y>=0) { x -> 1, y -> 2 } = (y>=0) { x -> 2, y -> 3 } = true
833
      // The code below adds a constraint that states that the above values are
834
      // the same, which is part of the reason for the conflict. In the above
835
      // example, we generate:
836
      //   (DT_SYGUS_EVAL ce 1 2) == (DT_SYGUS_EVAL ce 2 3) { ce -> M(ce) }
837
      // which unfolds via the SygusEvalUnfold utility to:
838
      //   ( (c_x ce)*1 + (c_y ce)*2 >= 0 ) == ( (c_x ce)*2 + (c_y ce)*3 >= 0 )
839
      // where c_x and c_y are the selectors of the subfields of C_leq_xy.
840
      Trace("sygus-unif-sol-sym")
841
          << "Explain symbolic separation conflict" << std::endl;
842
      std::map<Node, std::vector<Node>>::iterator ith;
843
      Node ceApp[2];
844
      SygusEvalUnfold* eunf = d_unif->d_tds->getEvalUnfold();
845
      std::map<Node, Node> vtm;
846
      vtm[ce] = cv;
847
      Trace("sygus-unif-sol-sym")
848
          << "Model value for " << ce << " is " << cv << std::endl;
849
      for (unsigned r = 0; r < 2; r++)
850
      {
851
        std::vector<Node> cechildren;
852
        cechildren.push_back(ce);
853
        Node ecurr = r == 0 ? e : er;
854
        ith = d_unif->d_hd_to_pt.find(ecurr);
855
        AlwaysAssert(ith != d_unif->d_hd_to_pt.end());
856
        cechildren.insert(
857
            cechildren.end(), ith->second.begin(), ith->second.end());
858
        Node cea = nm->mkNode(DT_SYGUS_EVAL, cechildren);
859
        Trace("sygus-unif-sol-sym")
860
            << "Sep conflict app #" << r << " : " << cea << std::endl;
861
        std::vector<Node> tmpExp;
862
        cea = eunf->unfold(cea, vtm, tmpExp, true, true);
863
        Trace("sygus-unif-sol-sym") << "Unfolded to : " << cea << std::endl;
864
        ceApp[r] = cea;
865
      }
866
      Node ceAppEq = ceApp[0].eqNode(ceApp[1]);
867
      Trace("sygus-unif-sol-sym")
868
          << "Sep conflict app explanation is : " << ceAppEq << std::endl;
869
      exp.push_back(ceAppEq);
870
      break;
871
    }
872
28
    Trace("sygus-unif-sol")
873
14
        << "  ...resolves separation conflict with current, but not all"
874
14
        << std::endl;
875
    // find the new term to resolve a separation
876
28
    Node new_er = Node::null();
877
    // scan the previous list and find the representative of the class that e is
878
    // now in
879
29
    for (unsigned i = 0, size = prev_sep_c.size(); i < size; i++)
880
    {
881
44
      Node check_er = prev_sep_c[i];
882
29
      if (check_er != er && check_er != e)
883
      {
884
15
        itr = d_pt_sep.d_trie.d_rep_to_class.find(check_er);
885
15
        if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
886
        {
887
42
          if (std::find(itr->second.begin(), itr->second.end(), e)
888
42
              != itr->second.end())
889
          {
890
14
            new_er = check_er;
891
14
            break;
892
          }
893
        }
894
      }
895
    }
896
    // should find exactly one
897
14
    Assert(!new_er.isNull());
898
14
    er = new_er;
899
14
    needs_sep_resolve = true;
900
  }
901
134
  if (exp_conflict)
902
  {
903
246
    Node lemma = exp.size() == 1 ? exp[0] : nm->mkNode(AND, exp);
904
123
    lemma = lemma.negate();
905
123
    Trace("sygus-unif-sol") << "  ......conflict is " << lemma << std::endl;
906
123
    lemmas.push_back(lemma);
907
123
    return Node::null();
908
  }
909
910
11
  Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
911
11
  return extractSol(cons, hd_mv);
912
}
913
914
11
Node SygusUnifRl::DecisionTreeInfo::extractSol(Node cons,
                                               std::map<Node, Node>& hd_mv)
{
  // Produces the solution term for this decision tree. When the
  // information-gain condition pool is enabled, the separation trie is first
  // rebuilt heuristically from the pool; the (possibly rebuilt) trie is then
  // traversed by the point separator to assemble the ITE term with
  // constructor `cons`, using the model values in hd_mv at the leaves.
  const bool rebuildWithInfoGain = d_unif->usingConditionPoolInfoGain();
  if (rebuildWithInfoGain)
  {
    recomputeSolHeuristically(hd_mv);
  }
  Node solution = d_pt_sep.extractSol(cons, hd_mv);
  return solution;
}
924
925
11
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::extractSol(
    Node cons, std::map<Node, Node>& hd_mv)
{
  // Builds a SyGuS ITE term (applications of constructor `cons`) from the
  // separation trie using an explicit-stack post-order traversal.
  // Each internal trie node is popped twice: on the first pop it receives a
  // null placeholder in `built` and its children are scheduled; on the second
  // pop (placeholder present) the children's terms are combined. Leaves map
  // directly to the model value (in hd_mv) of their lazy child.
  NodeManager* nm = NodeManager::currentNM();
  std::map<IndTriePair, Node> built;
  std::vector<IndTriePair> toVisit;
  IndTriePair root = IndTriePair(0, &d_trie.d_trie);
  toVisit.push_back(root);
  while (!toVisit.empty())
  {
    IndTriePair cur = toVisit.back();
    toVisit.pop_back();
    unsigned depth = cur.first;
    LazyTrie* node = cur.second;
    std::map<IndTriePair, Node>::iterator found = built.find(cur);
    if (found != built.end())
    {
      // Second visit of an internal node: all of its children are built.
      Assert(found->second.isNull());
      Assert(node->d_children.size() == 1 || node->d_children.size() == 2);
      // Arguments of the ITE constructor: cons, the condition at this depth,
      // then the term for the true edge (slot 2) and the false edge (slot 3).
      std::vector<Node> iteArgs(4);
      iteArgs[0] = cons;
      iteArgs[1] = d_dt->d_conds[depth];
      unsigned lastSlot = 0;
      for (std::pair<const Node, LazyTrie>& edge : node->d_children)
      {
        lastSlot = edge.first.getConst<bool>() ? 2 : 3;
        IndTriePair childKey(depth + 1, &edge.second);
        Assert(built.find(childKey) != built.end());
        iteArgs[lastSlot] = built[childKey];
        Assert(!iteArgs[lastSlot].isNull());
      }
      const bool singleChild = node->d_children.size() == 1;
      if (!singleChild && iteArgs[2] != iteArgs[3])
      {
        Assert(node->d_children.size() == 2);
        found->second = nm->mkNode(APPLY_CONSTRUCTOR, iteArgs);
        Trace("sygus-unif-sol-debug")
            << "......build node "
            << d_dt->d_unif->d_tds->sygusToBuiltin(found->second,
                                                   found->second.getType())
            << "\n";
      }
      else
      {
        // The condition is useless (one edge) or both branches yield the same
        // term, so no ITE is needed: reuse the child term directly.
        found->second = iteArgs[lastSlot];
        Trace("sygus-unif-sol-debug")
            << "......no need for cond "
            << d_dt->d_unif->d_tds->sygusToBuiltin(
                   d_dt->d_conds[depth], d_dt->d_conds[depth].getType())
            << ", build "
            << d_dt->d_unif->d_tds->sygusToBuiltin(found->second,
                                                   found->second.getType())
            << "\n";
      }
      continue;
    }
    if (node->d_children.empty())
    {
      // Leaf: its value is the model value of the representative point.
      Assert(hd_mv.find(node->d_lazy_child) != hd_mv.end());
      built[cur] = hd_mv[node->d_lazy_child];
      Trace("sygus-unif-sol-debug")
          << "......leaf, build "
          << d_dt->d_unif->d_tds->sygusToBuiltin(built[cur],
                                                 built[cur].getType())
          << "\n";
      continue;
    }
    // First visit of an internal node: mark it, revisit it after its children.
    built[cur] = Node::null();
    toVisit.push_back(cur);
    for (std::pair<const Node, LazyTrie>& edge : node->d_children)
    {
      toVisit.push_back(IndTriePair(depth + 1, &edge.second));
    }
  }
  Assert(built.find(root) != built.end());
  Assert(!built.find(root)->second.isNull());
  return built[root];
}
1005
1006
void SygusUnifRl::DecisionTreeInfo::recomputeSolHeuristically(
1007
    std::map<Node, Node>& hd_mv)
1008
{
1009
  // reset the trie
1010
  d_pt_sep.d_trie.clear();
1011
  // TODO workaround and not really sure this is the last condition, since I put
1012
  // a set here. Maybe make d_cond_mvs into a vector
1013
  Node backup_last_cond = d_conds.back();
1014
  d_conds.clear();
1015
  for (const Node& e : d_hds)
1016
  {
1017
    d_pt_sep.d_trie.add(e, &d_pt_sep, 0);
1018
  }
1019
  // init vector of conds
1020
  std::vector<Node> conds;
1021
  conds.insert(conds.end(), d_cond_mvs.begin(), d_cond_mvs.end());
1022
1023
  // recursively build trie by picking best condition for respective points
1024
  buildDtInfoGain(d_hds, conds, hd_mv, 1);
1025
  // if no condition was added (i.e. points are already classified at root
1026
  // level), use last condition as candidate
1027
  if (d_conds.empty())
1028
  {
1029
    Trace("sygus-unif-dt") << "......using last condition "
1030
                           << d_unif->d_tds->sygusToBuiltin(
1031
                                  backup_last_cond, backup_last_cond.getType())
1032
                           << " as candidate\n";
1033
    d_conds.push_back(backup_last_cond);
1034
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
1035
  }
1036
}
1037
1038
void SygusUnifRl::DecisionTreeInfo::buildDtInfoGain(std::vector<Node>& hds,
1039
                                                    std::vector<Node> conds,
1040
                                                    std::map<Node, Node>& hd_mv,
1041
                                                    int ind)
1042
{
1043
  // test if fully classified
1044
  if (hds.size() < 2)
1045
  {
1046
    indent("sygus-unif-dt", ind);
1047
    Trace("sygus-unif-dt") << "..set fully classified: "
1048
                           << (hds.empty() ? "empty" : "unary") << "\n";
1049
    return;
1050
  }
1051
  Node v1 = hd_mv[hds[0]];
1052
  unsigned i = 1, size = hds.size();
1053
  for (; i < size; ++i)
1054
  {
1055
    if (hd_mv[hds[i]] != v1)
1056
    {
1057
      break;
1058
    }
1059
  }
1060
  if (i == size)
1061
  {
1062
    indent("sygus-unif-dt", ind);
1063
    Trace("sygus-unif-dt") << "..set fully classified: " << hds.size() << " "
1064
                           << (d_unif->d_tds->sygusToBuiltin(v1, v1.getType())
1065
                                       == d_true
1066
                                   ? "good"
1067
                                   : "bad")
1068
                           << " points\n";
1069
    return;
1070
  }
1071
  // pick condition to further classify
1072
  double maxgain = -1;
1073
  unsigned picked_cond = 0;
1074
  std::vector<std::pair<std::vector<Node>, std::vector<Node>>> splits;
1075
  double current_set_entropy = getEntropy(hds, hd_mv, ind);
1076
  for (unsigned j = 0, conds_size = conds.size(); j < conds_size; ++j)
1077
  {
1078
    std::pair<std::vector<Node>, std::vector<Node>> split =
1079
        evaluateCond(hds, conds[j]);
1080
    splits.push_back(split);
1081
    Assert(hds.size() == split.first.size() + split.second.size());
1082
    double gain =
1083
        current_set_entropy
1084
        - (split.first.size() * getEntropy(split.first, hd_mv, ind)
1085
           + split.second.size() * getEntropy(split.second, hd_mv, ind))
1086
              / hds.size();
1087
    indent("sygus-unif-dt-debug", ind);
1088
    Trace("sygus-unif-dt-debug")
1089
        << "..gain of "
1090
        << d_unif->d_tds->sygusToBuiltin(conds[j], conds[j].getType()) << " is "
1091
        << gain << "\n";
1092
    if (gain > maxgain)
1093
    {
1094
      maxgain = gain;
1095
      picked_cond = j;
1096
    }
1097
  }
1098
  // add picked condition
1099
  indent("sygus-unif-dt", ind);
1100
  Trace("sygus-unif-dt") << "..picked condition "
1101
                         << d_unif->d_tds->sygusToBuiltin(
1102
                                conds[picked_cond],
1103
                                conds[picked_cond].getType())
1104
                         << "\n";
1105
  d_conds.push_back(conds[picked_cond]);
1106
  conds.erase(conds.begin() + picked_cond);
1107
  d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
1108
  // recurse
1109
  buildDtInfoGain(splits[picked_cond].first, conds, hd_mv, ind + 1);
1110
  buildDtInfoGain(splits[picked_cond].second, conds, hd_mv, ind + 1);
1111
}
1112
1113
std::pair<std::vector<Node>, std::vector<Node>>
1114
SygusUnifRl::DecisionTreeInfo::evaluateCond(std::vector<Node>& pts, Node cond)
1115
{
1116
  std::vector<Node> good, bad;
1117
  for (const Node& pt : pts)
1118
  {
1119
    if (d_pt_sep.computeCond(cond, pt) == d_true)
1120
    {
1121
      good.push_back(pt);
1122
      continue;
1123
    }
1124
    Assert(d_pt_sep.computeCond(cond, pt) == d_false);
1125
    bad.push_back(pt);
1126
  }
1127
  return std::pair<std::vector<Node>, std::vector<Node>>(good, bad);
1128
}
1129
1130
double SygusUnifRl::DecisionTreeInfo::getEntropy(const std::vector<Node>& hds,
1131
                                                 std::map<Node, Node>& hd_mv,
1132
                                                 int ind)
1133
{
1134
  double p = 0, n = 0;
1135
  TermDbSygus* tds = d_unif->d_tds;
1136
  // get number of points evaluated positively and negatively with feature
1137
  for (const Node& e : hds)
1138
  {
1139
    if (tds->sygusToBuiltin(hd_mv[e]) == d_true)
1140
    {
1141
      p++;
1142
      continue;
1143
    }
1144
    Assert(tds->sygusToBuiltin(hd_mv[e]) == d_false);
1145
    n++;
1146
  }
1147
  // compute entropy
1148
  return p == 0 || n == 0 ? 0
1149
                          : ((-p / (p + n)) * log2(p / (p + n)))
1150
                                - ((n / (p + n)) * log2(n / (p + n)));
1151
}
1152
1153
9
void SygusUnifRl::DecisionTreeInfo::PointSeparator::initialize(
1154
    DecisionTreeInfo* dt)
1155
{
1156
9
  d_dt = dt;
1157
9
}
1158
1159
442
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::evaluate(Node n,
                                                             unsigned index)
{
  // Evaluates the index-th condition of the decision tree on the point whose
  // head is n; the result is memoized by computeCond.
  Assert(index < d_dt->d_conds.size());
  return computeCond(d_dt->d_conds[index], n);
}
1167
1168
442
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::computeCond(Node cond,
                                                                Node hd)
{
  // Evaluates SyGuS condition `cond` on the input point associated with head
  // `hd` (looked up in d_hd_to_pt). The condition is converted to its builtin
  // form, evaluated on the point, and, if the decision tree has a condition
  // template, plugged into that template and rewritten. Results are memoized
  // in d_eval_cond_hd keyed by (cond, hd); the result is asserted constant.
  std::pair<Node, Node> key(cond, hd);
  std::map<std::pair<Node, Node>, Node>::iterator memo =
      d_eval_cond_hd.find(key);
  if (memo != d_eval_cond_hd.end())
  {
    return memo->second;
  }
  TypeNode condType = cond.getType();
  Node builtinCond = d_dt->d_unif->d_tds->sygusToBuiltin(cond, condType);
  // Retrieve evaluation point
  Assert(d_dt->d_unif->d_hd_to_pt.find(hd) != d_dt->d_unif->d_hd_to_pt.end());
  std::vector<Node> point = d_dt->d_unif->d_hd_to_pt[hd];
  if (Trace.isOn("sygus-unif-rl-sep"))
  {
    Trace("sygus-unif-rl-sep")
        << "Evaluate cond " << builtinCond << " on pt " << hd << " ( ";
    for (size_t k = 0; k < point.size(); ++k)
    {
      Trace("sygus-unif-rl-sep") << point[k] << " ";
    }
    Trace("sygus-unif-rl-sep") << ")\n";
  }
  Node result =
      d_dt->d_unif->d_tds->evaluateBuiltin(condType, builtinCond, point);
  Trace("sygus-unif-rl-sep") << "...got res = " << result << "\n";
  // A templated condition is evaluated by substituting the raw result for the
  // template variable and rewriting.
  if (!d_dt->d_template.first.isNull())
  {
    Node templ = d_dt->d_template.first;
    TNode templVar = d_dt->d_template.second;
    result = Rewriter::rewrite(templ.substitute(templVar, result));
    Trace("sygus-unif-rl-sep")
        << "...after template res = " << result << std::endl;
  }
  Assert(result.isConst());
  d_eval_cond_hd[key] = result;
  return result;
}
1210
1211
}  // namespace quantifiers
1212
}  // namespace theory
1213
28207
}  // namespace cvc5