GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/sygus/sygus_unif_rl.cpp Lines: 436 637 68.4 %
Date: 2021-03-22 Branches: 738 2486 29.7 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file sygus_unif_rl.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Haniel Barbosa, Andrew Reynolds, Mathias Preiner
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Implementation of sygus_unif_rl
13
 **/
14
15
#include "theory/quantifiers/sygus/sygus_unif_rl.h"
16
17
#include "options/base_options.h"
18
#include "options/quantifiers_options.h"
19
#include "printer/printer.h"
20
#include "theory/quantifiers/sygus/synth_conjecture.h"
21
#include "theory/quantifiers/sygus/term_database_sygus.h"
22
#include "theory/rewriter.h"
23
#include "util/random.h"
24
25
#include <math.h>
26
27
using namespace CVC4::kind;
28
29
namespace CVC4 {
30
namespace theory {
31
namespace quantifiers {
32
33
2190
/**
 * Creates the unification utility for the synthesis conjecture p.
 * Condition-pool usage starts disabled; initializeCandidate sets both flags
 * from the sygusUnifPi option.
 */
SygusUnifRl::SygusUnifRl(SynthConjecture* p)
    : d_parent(p),
      d_useCondPool(false),
      d_useCondPoolIGain(false)
{
}
37
2188
SygusUnifRl::~SygusUnifRl()
{
  // no explicit cleanup is performed here
}
38
23
void SygusUnifRl::initializeCandidate(
39
    QuantifiersEngine* qe,
40
    Node f,
41
    std::vector<Node>& enums,
42
    std::map<Node, std::vector<Node>>& strategy_lemmas)
43
{
44
  // initialize
45
46
  std::vector<Node> all_enums;
46
23
  SygusUnif::initializeCandidate(qe, f, all_enums, strategy_lemmas);
47
  // based on the strategy inferred for each function, determine if we are
48
  // using a unification strategy that is compatible our approach.
49
46
  StrategyRestrictions restrictions;
50
50
  if (options::sygusBoolIteReturnConst())
51
  {
52
23
    restrictions.d_iteReturnBoolConst = true;
53
  }
54
  // register the strategy
55
23
  registerStrategy(f, enums, restrictions.d_unused_strategies);
56
23
  d_strategy[f].staticLearnRedundantOps(strategy_lemmas, restrictions);
57
  // Copy candidates and check whether CegisUnif for any of them
58
23
  if (d_unif_candidates.find(f) != d_unif_candidates.end())
59
  {
60
17
    d_hd_to_pt[f].clear();
61
17
    d_cand_to_eval_hds[f].clear();
62
17
    d_cand_to_hd_count[f] = 0;
63
  }
64
  // check whether we are using condition enumeration
65
23
  options::SygusUnifPiMode mode = options::sygusUnifPi();
66
23
  d_useCondPool = mode == options::SygusUnifPiMode::CENUM
67
23
                  || mode == options::SygusUnifPiMode::CENUM_IGAIN;
68
23
  d_useCondPoolIGain = mode == options::SygusUnifPiMode::CENUM_IGAIN;
69
23
}
70
71
/**
 * Part of the SygusUnif interface, but not used by this utility: it must
 * never be called, which the assertion below checks.
 */
void SygusUnifRl::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
{
  Assert(false);
}
76
77
3112
/**
 * Purifies the (sub)term n of a refinement lemma.
 *
 * Each application (DT_SYGUS_EVAL) of a unification function-to-synthesize
 * gets its head replaced by a fresh skolem "evaluation head". The head is
 * recorded in d_cand_to_eval_hds for its candidate, and its argument tuple is
 * recorded in d_hd_to_pt. Identical applications reuse the same purified node
 * via d_app_to_purified.
 *
 * @param n the term to purify
 * @param ensureConst whether n must be reduced to a constant. When true,
 *   applications of functions-to-synthesize are replaced by their model value
 *   (using the built solution when one exists), and the guard
 *   (not (= value purified-app)) is appended to model_guards.
 * @param model_guards accumulates the model-value guards described above
 * @param cache maps (ensureConst, n) to an already computed result
 * @return the purified and rewritten version of n
 */
Node SygusUnifRl::purifyLemma(Node n,
                              bool ensureConst,
                              std::vector<Node>& model_guards,
                              BoolNodePairMap& cache)
{
  Trace("sygus-unif-rl-purify") << "PurifyLemma : " << n << "\n";
  BoolNodePairMap::const_iterator it0 =
      cache.find(BoolNodePair(ensureConst, n));
  if (it0 != cache.end())
  {
    Trace("sygus-unif-rl-purify-debug") << "... already visited " << n << "\n";
    return it0->second;
  }
  NodeManager* nm = NodeManager::currentNM();
  unsigned size = n.getNumChildren();
  Kind k = n.getKind();
  // We retrieve the model value now because the purified node may not have a
  // value
  Node nv = n;
  // Whether n is an application of a function-to-synthesize
  bool fapp = (k == DT_SYGUS_EVAL);
  // Whether n is an application of a unification (u_fapp) or of a
  // non-unification (nu_fapp) function-to-synthesize
  bool u_fapp = false;
  bool nu_fapp = false;
  if (fapp)
  {
    Assert(std::find(d_candidates.begin(), d_candidates.end(), n[0])
           != d_candidates.end());
    // usingUnif is a read-only lookup; query it once
    u_fapp = usingUnif(n[0]);
    nu_fapp = !u_fapp;
    // get model value of non-top level applications of functions-to-synthesize
    // occurring under a unification function-to-synthesize
    if (ensureConst)
    {
      std::map<Node, Node>::iterator it1 = d_cand_to_sol.find(n[0]);
      // if function-to-synthesize, retrieve its built solution to replace in
      // the application before computing the model value
      AlwaysAssert(!u_fapp || it1 != d_cand_to_sol.end());
      if (it1 != d_cand_to_sol.end())
      {
        TNode cand = n[0];
        Node tmp = n.substitute(cand, it1->second);
        // should be concrete, can just use the rewriter
        nv = Rewriter::rewrite(tmp);
        Trace("sygus-unif-rl-purify")
            << "PurifyLemma : model value for " << tmp << " is " << nv << "\n";
      }
      else
      {
        nv = d_parent->getModelValue(n);
        Trace("sygus-unif-rl-purify")
            << "PurifyLemma : model value for " << n << " is " << nv << "\n";
      }
      Assert(n != nv);
    }
  }
  // Traverse the children to purify them
  bool childChanged = false;
  std::vector<Node> children;
  children.reserve(size);
  for (unsigned i = 0; i < size; ++i)
  {
    if (i == 0 && fapp)
    {
      // the head of an application is kept here; for unification functions
      // it may be replaced by an evaluation head below
      children.push_back(n[i]);
      continue;
    }
    // Arguments of non-unif functions do not need to be constant
    Node child = purifyLemma(
        n[i], !nu_fapp && (ensureConst || u_fapp), model_guards, cache);
    children.push_back(child);
    childChanged = childChanged || child != n[i];
  }
  Node nb;
  if (childChanged)
  {
    if (n.getMetaKind() == metakind::PARAMETERIZED)
    {
      Trace("sygus-unif-rl-purify-debug")
          << "Node " << n << " is parameterized\n";
      children.insert(children.begin(), n.getOperator());
    }
    if (Trace.isOn("sygus-unif-rl-purify-debug"))
    {
      Trace("sygus-unif-rl-purify-debug")
          << "...rebuilding " << n << " with kind " << k << " and children:\n";
      for (const Node& child : children)
      {
        Trace("sygus-unif-rl-purify-debug") << "...... " << child << "\n";
      }
    }
    nb = nm->mkNode(k, children);
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : transformed " << n << " into " << nb << "\n";
  }
  else
  {
    nb = n;
  }
  // Map to point enumerator every unification function-to-synthesize
  if (u_fapp)
  {
    Node np;
    std::map<Node, Node>::const_iterator it2 = d_app_to_purified.find(nb);
    if (it2 == d_app_to_purified.end())
    {
      // Build purified head with fresh skolem and recreate node
      std::stringstream ss;
      ss << nb[0] << "_" << d_cand_to_hd_count[nb[0]]++;
      Node new_f = nm->mkSkolem(ss.str(),
                                nb[0].getType(),
                                "head of unif evaluation point",
                                NodeManager::SKOLEM_EXACT_NAME);
      // Adds new enumerator to map from candidate
      Trace("sygus-unif-rl-purify")
          << "...new enum " << new_f << " for candidate " << nb[0] << "\n";
      d_cand_to_eval_hds[nb[0]].push_back(new_f);
      // Maps new enumerator to its respective tuple of arguments
      d_hd_to_pt[new_f] =
          std::vector<Node>(children.begin() + 1, children.end());
      if (Trace.isOn("sygus-unif-rl-purify-debug"))
      {
        Trace("sygus-unif-rl-purify-debug") << "...[" << new_f << "] --> ( ";
        for (const Node& pt_i : d_hd_to_pt[new_f])
        {
          Trace("sygus-unif-rl-purify-debug") << pt_i << " ";
        }
        Trace("sygus-unif-rl-purify-debug") << ")\n";
      }
      // replace first child and rebuild node
      Assert(children.size() > 0);
      children[0] = new_f;
      Trace("sygus-unif-rl-purify-debug")
          << "Make sygus eval app " << children << std::endl;
      np = nm->mkNode(DT_SYGUS_EVAL, children);
      d_app_to_purified[nb] = np;
    }
    else
    {
      np = it2->second;
    }
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : purified head and transformed " << nb << " into "
        << np << "\n";
    nb = np;
  }
  // Add equality between purified fapp and model value
  if (ensureConst && fapp)
  {
    model_guards.push_back(nm->mkNode(EQUAL, nv, nb).negate());
    nb = nv;
    Trace("sygus-unif-rl-purify")
        << "PurifyLemma : adding model eq " << model_guards.back() << "\n";
  }
  nb = Rewriter::rewrite(nb);
  // every non-top level application of function-to-synthesize must be reduced
  // to a concrete constant
  Assert(!ensureConst || nb.isConst());
  Trace("sygus-unif-rl-purify-debug")
      << "... caching [" << n << "] = " << nb << "\n";
  cache[BoolNodePair(ensureConst, n)] = nb;
  return nb;
}
240
241
46
/**
 * Registers a refinement lemma with this utility.
 *
 * Purifying the lemma (see purifyLemma) may create new evaluation heads for
 * unification candidates. Each new head is appended to eval_hds[c] for its
 * candidate c. It is also registered with the decision tree of every strategy
 * point associated with c's conditional enumerators.
 *
 * @param lemma the refinement lemma
 * @param eval_hds receives, per candidate, the evaluation heads created while
 *   purifying lemma
 * @return the purified lemma, disjoined with any model-value guards, rewritten
 */
Node SygusUnifRl::addRefLemma(Node lemma,
                              std::map<Node, std::vector<Node>>& eval_hds)
{
  Trace("sygus-unif-rl-purify")
      << "Registering lemma at SygusUnif : " << lemma << "\n";
  // snapshot how many heads each candidate has, so that the heads created by
  // purifying this lemma can be told apart from the existing ones
  std::map<Node, unsigned> prev_n_eval_hds;
  for (std::map<Node, std::vector<Node>>::const_iterator itc =
           d_cand_to_eval_hds.begin();
       itc != d_cand_to_eval_hds.end();
       ++itc)
  {
    prev_n_eval_hds[itc->first] = itc->second.size();
  }

  // Make the purified lemma which will guide the unification utility.
  std::vector<Node> model_guards;
  BoolNodePairMap cache;
  Node plem = purifyLemma(lemma, false, model_guards, cache);
  if (!model_guards.empty())
  {
    // the guards are negated model-value equalities, so the purified lemma
    // only has to hold while those model values are unchanged
    model_guards.push_back(plem);
    plem = NodeManager::currentNM()->mkNode(OR, model_guards);
  }
  plem = Rewriter::rewrite(plem);
  Trace("sygus-unif-rl-purify") << "Purified lemma : " << plem << "\n";

  Trace("sygus-unif-rl-purify") << "Collect new evaluation points...\n";
  for (std::map<Node, std::vector<Node>>::const_iterator itc =
           d_cand_to_eval_hds.begin();
       itc != d_cand_to_eval_hds.end();
       ++itc)
  {
    const Node& cand = itc->first;
    const std::vector<Node>& hds = itc->second;
    std::map<Node, unsigned>::const_iterator itp = prev_n_eval_hds.find(cand);
    const unsigned firstNew = (itp == prev_n_eval_hds.end()) ? 0 : itp->second;
    for (unsigned j = firstNew, nhds = hds.size(); j < nhds; ++j)
    {
      const Node& hd = hds[j];
      eval_hds[cand].push_back(hd);
      // Add the new point to the decision trees of the candidate
      Assert(d_cand_cenums.find(cand) != d_cand_cenums.end());
      const std::vector<Node>& cenums = d_cand_cenums[cand];
      for (const Node& cenum : cenums)
      {
        Assert(d_cenum_to_stratpt.find(cenum) != d_cenum_to_stratpt.end());
        const std::vector<Node>& stratpts = d_cenum_to_stratpt[cenum];
        for (const Node& stratpt : stratpts)
        {
          Assert(d_stratpt_to_dt.find(stratpt) != d_stratpt_to_dt.end());
          Trace("sygus-unif-rl-dt")
              << "Register point with head " << hd << " to strategy point "
              << stratpt << "\n";
          d_stratpt_to_dt[stratpt].d_hds.push_back(hd);
        }
      }
    }
  }
  return plem;
}
298
299
1348
void SygusUnifRl::initializeConstructSol()
{
  // intentionally empty: no per-round setup is done here
}
300
1348
void SygusUnifRl::initializeConstructSolFor(Node f)
{
  // intentionally empty: no per-candidate setup is done here
}
301
1348
bool SygusUnifRl::constructSolution(std::vector<Node>& sols,
302
                                    std::vector<Node>& lemmas)
303
{
304
1348
  initializeConstructSol();
305
1348
  bool successful = true;
306
2708
  for (const Node& c : d_candidates)
307
  {
308
1360
    if (!usingUnif(c))
309
    {
310
24
      Node v = d_parent->getModelValue(c);
311
12
      sols.push_back(v);
312
12
      continue;
313
    }
314
1348
    initializeConstructSolFor(c);
315
    Node v = constructSol(
316
1408
        c, d_strategy[c].getRootEnumerator(), role_equal, 0, lemmas);
317
2636
    if (v.isNull())
318
    {
319
      // we continue trying to build solutions to accumulate potentitial
320
      // separation conditions from other decision trees
321
1288
      successful = false;
322
1288
      continue;
323
    }
324
60
    sols.push_back(v);
325
60
    d_cand_to_sol[c] = v;
326
  }
327
1348
  return successful;
328
}
329
330
1348
/**
 * Constructs a solution for candidate f from strategy point e in role nrole.
 *
 * Only the "equal" role of a strategy point with a registered decision tree
 * is supported; in other cases the null node is returned. If f has no
 * evaluation points yet, the current model value of e is returned. Otherwise
 * the decision tree builds the solution and may add lemmas to lemmas.
 *
 * @param f the function-to-synthesize
 * @param e the strategy point (enumerator)
 * @param nrole the role of e
 * @param ind indentation level for tracing
 * @param lemmas receives lemmas generated while building the solution
 * @return the solution, or the null node if none could be built
 */
Node SygusUnifRl::constructSol(
    Node f, Node e, NodeRole nrole, int ind, std::vector<Node>& lemmas)
{
  indent("sygus-unif-sol", ind);
  Trace("sygus-unif-sol") << "ConstructSol: SygusRL : " << e << std::endl;
  // retrieve strategy information
  TypeNode etn = e.getType();
  EnumTypeInfo& tinfo = d_strategy[f].getEnumTypeInfo(etn);
  StrategyNode& snode = tinfo.getStrategyNode(nrole);
  if (nrole != role_equal)
  {
    return Node::null();
  }
  // is there a decision tree strategy?
  std::map<Node, DecisionTreeInfo>::iterator itd = d_stratpt_to_dt.find(e);
  // for now only considering simple case of sole "ITE(cond, e, e)" strategy
  if (itd == d_stratpt_to_dt.end())
  {
    return Node::null();
  }
  indent("sygus-unif-sol", ind);
  Trace("sygus-unif-sol") << "...it has a decision tree strategy.\n";
  // whether empty set of points; this is a pure query, so look it up with
  // find rather than operator[] (which would insert an entry for f)
  std::map<Node, std::vector<Node>>::const_iterator ith =
      d_cand_to_eval_hds.find(f);
  if (ith == d_cand_to_eval_hds.end() || ith->second.empty())
  {
    // compute the model value once and use it for both the trace and the
    // result
    Node rootv = d_parent->getModelValue(e);
    Trace("sygus-unif-sol") << "...... no points, return root enum value "
                            << rootv << "\n";
    return rootv;
  }
  EnumTypeInfoStrat* etis = snode.d_strats[itd->second.getStrategyIndex()];
  Node sol = itd->second.buildSol(etis->d_cons, lemmas);
  Assert(d_useCondPool || !sol.isNull() || !lemmas.empty());
  return sol;
}
364
365
1587
/**
 * Returns whether candidate f is solved by unification, i.e. whether it was
 * added to d_unif_candidates (done in registerConditionalEnumerator).
 */
bool SygusUnifRl::usingUnif(Node f) const
{
  return d_unif_candidates.count(f) != 0;
}
369
370
17
/**
 * Returns the conditional enumerator of the decision tree registered for the
 * strategy point e. The strategy point must have a decision tree.
 */
Node SygusUnifRl::getConditionForEvaluationPoint(Node e) const
{
  std::map<Node, DecisionTreeInfo>::const_iterator itd =
      d_stratpt_to_dt.find(e);
  Assert(itd != d_stratpt_to_dt.end());
  const DecisionTreeInfo& tree = itd->second;
  return tree.getConditionEnumerator();
}
376
377
1348
void SygusUnifRl::setConditions(Node e,
378
                                Node guard,
379
                                const std::vector<Node>& enums,
380
                                const std::vector<Node>& conds)
381
{
382
1348
  std::map<Node, DecisionTreeInfo>::iterator it = d_stratpt_to_dt.find(e);
383
1348
  Assert(it != d_stratpt_to_dt.end());
384
  // set the conditions for the appropriate tree
385
1348
  it->second.setConditions(guard, enums, conds);
386
1348
}
387
388
1886
/**
 * Returns (a copy of) the evaluation heads created so far for candidate c,
 * or an empty vector if c has none recorded.
 */
std::vector<Node> SygusUnifRl::getEvalPointHeads(Node c)
{
  std::map<Node, std::vector<Node>>::const_iterator ith =
      d_cand_to_eval_hds.find(c);
  return ith == d_cand_to_eval_hds.end() ? std::vector<Node>() : ith->second;
}
397
398
7238
/**
 * Whether conditions are enumerated into a pool; set by initializeCandidate
 * for the CENUM and CENUM_IGAIN modes of the sygusUnifPi option.
 */
bool SygusUnifRl::usingConditionPool() const
{
  return d_useCondPool;
}
399
46
bool SygusUnifRl::usingConditionPoolInfoGain() const
400
{
401
46
  return d_useCondPoolIGain;
402
}
403
23
void SygusUnifRl::registerStrategy(
404
    Node f,
405
    std::vector<Node>& enums,
406
    std::map<Node, std::unordered_set<unsigned>>& unused_strats)
407
{
408
23
  if (Trace.isOn("sygus-unif-rl-strat"))
409
  {
410
    Trace("sygus-unif-rl-strat")
411
        << "Strategy for " << f << " is : " << std::endl;
412
    d_strategy[f].debugPrint("sygus-unif-rl-strat");
413
  }
414
23
  Trace("sygus-unif-rl-strat") << "Register..." << std::endl;
415
46
  Node e = d_strategy[f].getRootEnumerator();
416
46
  std::map<Node, std::map<NodeRole, bool>> visited;
417
23
  registerStrategyNode(f, e, role_equal, visited, enums, unused_strats);
418
23
}
419
420
23
void SygusUnifRl::registerStrategyNode(
421
    Node f,
422
    Node e,
423
    NodeRole nrole,
424
    std::map<Node, std::map<NodeRole, bool>>& visited,
425
    std::vector<Node>& enums,
426
    std::map<Node, std::unordered_set<unsigned>>& unused_strats)
427
{
428
23
  Trace("sygus-unif-rl-strat") << "  register node " << e << std::endl;
429
23
  if (visited[e].find(nrole) != visited[e].end())
430
  {
431
    return;
432
  }
433
23
  visited[e][nrole] = true;
434
46
  TypeNode etn = e.getType();
435
23
  EnumTypeInfo& tinfo = d_strategy[f].getEnumTypeInfo(etn);
436
23
  StrategyNode& snode = tinfo.getStrategyNode(nrole);
437
43
  for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
438
  {
439
20
    EnumTypeInfoStrat* etis = snode.d_strats[j];
440
20
    StrategyType strat = etis->d_this;
441
    // is this a simple recursive ITE strategy?
442
20
    bool success = false;
443
20
    if (strat == strat_ITE && nrole == role_equal)
444
    {
445
17
      success = true;
446
51
      for (unsigned c = 1; c <= 2; c++)
447
      {
448
68
        std::pair<Node, NodeRole> child = etis->d_cenum[c];
449
34
        if (child.first != e || child.second != nrole)
450
        {
451
          success = false;
452
          break;
453
        }
454
      }
455
17
      if (success)
456
      {
457
34
        Node cond = etis->d_cenum[0].first;
458
17
        Assert(etis->d_cenum[0].second == role_ite_condition);
459
34
        Trace("sygus-unif-rl-strat")
460
17
            << "  ...detected recursive ITE strategy, condition enumerator : "
461
17
            << cond << std::endl;
462
        // indicate that we will be enumerating values for cond
463
17
        registerConditionalEnumerator(f, e, cond, j);
464
        // we will be using a strategy for e
465
17
        enums.push_back(e);
466
      }
467
    }
468
20
    if (!success)
469
    {
470
3
      unused_strats[e].insert(j);
471
    }
472
    // TODO: recurse? for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
473
  }
474
}
475
476
17
void SygusUnifRl::registerConditionalEnumerator(Node f,
477
                                                Node e,
478
                                                Node cond,
479
                                                unsigned strategy_index)
480
{
481
  // only allow one decision tree per strategy point
482
17
  if (d_stratpt_to_dt.find(e) != d_stratpt_to_dt.end())
483
  {
484
    return;
485
  }
486
  // we will do unification for this candidate
487
17
  d_unif_candidates.insert(f);
488
  // add to the list of all conditional enumerators
489
51
  if (std::find(d_cond_enums.begin(), d_cond_enums.end(), cond)
490
51
      == d_cond_enums.end())
491
  {
492
17
    d_cond_enums.push_back(cond);
493
17
    d_cand_cenums[f].push_back(cond);
494
17
    d_cenum_to_stratpt[cond].clear();
495
  }
496
  // register that this strategy node has a decision tree construction
497
17
  d_stratpt_to_dt[e].initialize(cond, this, &d_strategy[f], strategy_index);
498
  // associate conditional enumerator with strategy node
499
17
  d_cenum_to_stratpt[cond].push_back(e);
500
}
501
502
17
void SygusUnifRl::DecisionTreeInfo::initialize(Node cond_enum,
503
                                               SygusUnifRl* unif,
504
                                               SygusUnifStrategy* strategy,
505
                                               unsigned strategy_index)
506
{
507
17
  d_cond_enum = cond_enum;
508
17
  d_unif = unif;
509
17
  d_strategy = strategy;
510
17
  d_strategy_index = strategy_index;
511
17
  d_true = NodeManager::currentNM()->mkConst(true);
512
17
  d_false = NodeManager::currentNM()->mkConst(false);
513
  // Retrieve template
514
17
  EnumInfo& eiv = d_strategy->getEnumInfo(d_cond_enum);
515
17
  d_template = NodePair(eiv.d_template, eiv.d_template_arg);
516
  // Initialize classifier
517
17
  d_pt_sep.initialize(this);
518
17
}
519
520
1348
void SygusUnifRl::DecisionTreeInfo::setConditions(
521
    Node guard, const std::vector<Node>& enums, const std::vector<Node>& conds)
522
{
523
1348
  Assert(enums.size() == conds.size());
524
  // set the guard
525
1348
  d_guard = guard;
526
  // clear old condition values
527
1348
  d_enums.clear();
528
1348
  d_conds.clear();
529
  // set new condition values
530
1348
  d_enums.insert(d_enums.end(), enums.begin(), enums.end());
531
1348
  d_conds.insert(d_conds.end(), conds.begin(), conds.end());
532
  // add to condition pool
533
1348
  if (d_unif->usingConditionPool())
534
  {
535
    d_cond_mvs.insert(conds.begin(), conds.end());
536
    if (Trace.isOn("sygus-unif-cond-pool"))
537
    {
538
      for (const Node& condv : conds)
539
      {
540
        if (d_cond_mvs.find(condv) == d_cond_mvs.end())
541
        {
542
          Trace("sygus-unif-cond-pool")
543
              << "  ...adding to condition pool : "
544
              << d_unif->d_tds->sygusToBuiltin(condv, condv.getType()) << "\n";
545
        }
546
      }
547
    }
548
  }
549
1348
}
550
551
1334
unsigned SygusUnifRl::DecisionTreeInfo::getStrategyIndex() const
552
{
553
1334
  return d_strategy_index;
554
}
555
556
1334
/**
 * Builds a solution from this decision tree using constructor cons.
 *
 * Templated conditions are not supported (the null node is returned). The
 * point-separation trie is cleared first. The work is then delegated to
 * buildSolAllCond when the condition pool is in use, and to buildSolMinCond
 * otherwise.
 */
Node SygusUnifRl::DecisionTreeInfo::buildSol(Node cons,
                                             std::vector<Node>& lemmas)
{
  if (!d_template.first.isNull())
  {
    Trace("sygus-unif-sol") << "...templated conditions unsupported\n";
    return Node::null();
  }
  Trace("sygus-unif-sol") << "Decision::buildSol with " << d_hds.size()
                          << " evaluation heads and " << d_conds.size()
                          << " conditions..." << std::endl;
  // start from an empty trie
  d_pt_sep.d_trie.clear();
  if (d_unif->usingConditionPool())
  {
    return buildSolAllCond(cons, lemmas);
  }
  return buildSolMinCond(cons, lemmas);
}
572
573
/**
 * Builds a solution using every condition in the condition pool
 * (d_cond_mvs) as a classifier.
 *
 * Returns the null node in two cases:
 * - an evaluation head ends up in the separation class of a head with a
 *   different model value;
 * - the sygusUnifCondIndNoRepeatSol option is on and the extracted solution
 *   was already produced before.
 * Every returned solution is recorded in d_sols.
 */
Node SygusUnifRl::DecisionTreeInfo::buildSolAllCond(Node cons,
                                                    std::vector<Node>& lemmas)
{
  // the classifiers are all conditions of the pool
  d_conds.assign(d_cond_mvs.begin(), d_cond_mvs.end());
  // Shuffling does not affect whether a solution can be built. It does affect
  // the size of the solution (smaller, bigger, or unchanged) and which
  // conditions appear in the decision tree, and hence the "quality" of the
  // solution on cases not covered by the current data points.
  if (options::sygusUnifShuffleCond())
  {
    std::shuffle(d_conds.begin(), d_conds.end(), Random::getRandom());
  }
  const unsigned num_conds = d_conds.size();
  for (unsigned ci = 0; ci < num_conds; ++ci)
  {
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, ci);
  }
  // model values of the evaluation heads added so far
  std::map<Node, Node> hd_mv;
  for (const Node& hd : d_hds)
  {
    hd_mv[hd] = d_unif->d_parent->getModelValue(hd);
    Node rep = d_pt_sep.d_trie.add(hd, &d_pt_sep, num_conds);
    if (rep == hd)
    {
      // hd starts a new separation class
      continue;
    }
    Assert(hd_mv.find(rep) != hd_mv.end());
    if (hd_mv[hd] != hd_mv[rep])
    {
      // hd cannot be separated from a head with a different model value
      Trace("sygus-unif-sol")
          << "  ...can't separate " << hd << " from " << rep << std::endl;
      return Node::null();
    }
  }
  Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
  Node sol = extractSol(cons, hd_mv);
  // optionally reject a solution that was already produced
  if (options::sygusUnifCondIndNoRepeatSol()
      && d_sols.find(sol) != d_sols.end())
  {
    return Node::null();
  }
  d_sols.insert(sol);
  return sol;
}
631
632
1334
Node SygusUnifRl::DecisionTreeInfo::buildSolMinCond(Node cons,
633
                                                    std::vector<Node>& lemmas)
634
{
635
1334
  NodeManager* nm = NodeManager::currentNM();
636
  // model values for evaluation heads
637
2668
  std::map<Node, Node> hd_mv;
638
  // the current explanation of why there has not yet been a separation conflict
639
2668
  std::vector<Node> exp;
640
  // is the above explanation ready to be sent out as a lemma?
641
1334
  bool exp_conflict = false;
642
  // the index of the head we are considering
643
1334
  unsigned hd_counter = 0;
644
  // the index of the condition we are considering
645
1334
  unsigned c_counter = 0;
646
  // do we need to resolve a separation conflict?
647
1334
  bool needs_sep_resolve = false;
648
  // This loop simultaneously builds the solution in terms of a lazy trie
649
  // (LazyTrieMulti), and checks whether a separation conflict exists. We
650
  // enforce that the separation conflicts we encounter while building
651
  // this solution are resolved, in order, by the condition enumerators.
652
  // If not, then we add a (conflict) lemma stating that the current model
653
  // value of the condition enumerator must be different. We also call this
654
  // a "separation lemma".
655
  //
656
  // As a simple example, say we have:
657
  //   evalution heads: (eval e1 0 0), (eval e2 1 2)
658
  //   conditions: c1
659
  // where M(e1) = x, M(e2) = y, and M(c1) = x>1. After adding e1 and e2, we are
660
  // in conflict since { e1, e2 } form a separation class, M(e1)!=M(e2), and
661
  // M(c1) does not separate e1 and e2 since:
662
  //   (x>1){x->0,y->0} = (x>1){x->1,y->2} = false
663
  // Hence, we would fail to build a solution in this case, and instead send a
664
  // separation lemma of the form:
665
  //   ~( e1 != e2 ^ c1 = [x<1] )
666
  //
667
  // Say we have:
668
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
669
  //   conditions: c1 c2
670
  // where M(e1) = x, M(e2) = y, M(e3) = x+1, M(c1) = x>0 and M(c2) = x<0.
671
  // After adding e1 and e2, { e1, e2 } form a separation class, M(e1)!=M(e2),
672
  // but M(c1) separates e1 and e2 since
673
  //   (x>0){x->0,y->0} = false, and
674
  //   (x>1){x->1,y->2} = true
675
  // Hence, we get new separation classes { e1 } and { e2 }, and afterwards
676
  // add e3. We then get { e2, e3 } as a separation class, which is also a
677
  // conflict since M(e2)!=M(e3). We check if M(c2) resolves this conflict.
678
  // It does not, since (x<1){x->0,y->0} = (x<1){x->1,y->2} = false. Hence,
679
  // we get a separation lemma:
680
  //  ~( c1 = [x>1] ^ e2 != e3 ^ c2 = [x<1] )
681
  //
682
  // Say we have:
683
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
684
  //   conditions: c1
685
  // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = x>0.
686
  // After adding e1 and e2, we have separation class { e1, e2 }. This is not a
687
  // conflict since M(e1)=M(e2). We then add e3, obtaining separation class
688
  // { e1, e2, e3 }, which is in conflict since M(e3)!=M(e1), and the condition
689
  // c1 does not separate e3 and the representative of this class, e1. Hence we
690
  // get a separation lemma of the form:
691
  //  ~( e1 = e2 ^ e1 != e3 ^ c1 = [x>0] )
692
  //
693
  // It also may be the case that we exhaust the pool of condition enumerators.
694
  // Say we have:
695
  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
696
  //   conditions: c1
697
  // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = y>0. After adding e1, e2,
698
  // and e3, we have a separation class { e1, e2, e3 } that is in conflict
699
  // since M(e3)!=M(e1). We add the condition c1, which separates into new
700
  // equivalence classes { e1 }, { e2, e3 }. We are still in separation conflict
701
  // since M(e3)!=M(e2). However, we do not have any further conditions to use
702
  // to resolve this conflict. Thus, we add the separation lemma:
703
  //  ~( e1 = e2 ^ e1 != e3 ^ e2 != e3 ^ c1 = [y>0] ^ G_1 )
704
  // where G_1 is a guard stating that we use at most 1 condition.
705
2668
  Node e;
706
2668
  Node er;
707
15546
  while (hd_counter < d_hds.size() || needs_sep_resolve)
708
  {
709
8394
    if (!needs_sep_resolve)
710
    {
711
      // add the head to the trie
712
8322
      e = d_hds[hd_counter];
713
8322
      hd_mv[e] = d_unif->d_parent->getModelValue(e);
714
8322
      if (Trace.isOn("sygus-unif-sol"))
715
      {
716
        std::stringstream ss;
717
        TermDbSygus::toStreamSygus(ss, hd_mv[e]);
718
        Trace("sygus-unif-sol")
719
            << "  add evaluation head (" << hd_counter << "/" << d_hds.size()
720
            << "): " << e << " -> " << ss.str() << std::endl;
721
      }
722
8322
      hd_counter++;
723
      // get the representative of the trie
724
8322
      er = d_pt_sep.d_trie.add(e, &d_pt_sep, c_counter);
725
8322
      Trace("sygus-unif-sol") << "  ...separation class " << er << std::endl;
726
      // are we in conflict?
727
8322
      if (er == e)
728
      {
729
        // new separation class, no conflict
730
8368
        continue;
731
      }
732
6988
      Assert(hd_mv.find(er) != hd_mv.end());
733
12006
      if (hd_mv[er] == hd_mv[e])
734
      {
735
        // merged into separation class with same model value, no conflict
736
        // add to explanation
737
        // this states that it mattered that (er = e) at the time that e was
738
        // added to the trie. Notice that er and e may become separated later,
739
        // but to ensure the overall invariant, this equality must persist in
740
        // the explanation.
741
5018
        exp.push_back(er.eqNode(e));
742
5018
        Trace("sygus-unif-sol") << "  ...equal model values " << std::endl;
743
10036
        Trace("sygus-unif-sol")
744
5018
            << "  ...add to explanation " << er.eqNode(e) << std::endl;
745
5018
        continue;
746
      }
747
    }
748
    // must include in the explanation that we hit a conflict at this point in
749
    // the construction
750
2042
    exp.push_back(e.eqNode(er).negate());
751
    // we are in separation conflict, does the next condition resolve this?
752
    // check whether we have have exhausted our condition pool. If so, we
753
    // are in conflict and this conflict depends on the guard.
754
2042
    if (c_counter >= d_conds.size())
755
    {
756
      // truncated separation lemma
757
716
      Assert(!d_guard.isNull());
758
716
      exp.push_back(d_guard);
759
716
      exp_conflict = true;
760
2004
      break;
761
    }
762
1326
    Assert(c_counter < d_conds.size());
763
1398
    Node ce = d_enums[c_counter];
764
1398
    Node cv = d_conds[c_counter];
765
1326
    Assert(ce.getType() == cv.getType());
766
1326
    if (Trace.isOn("sygus-unif-sol"))
767
    {
768
      std::stringstream ss;
769
      TermDbSygus::toStreamSygus(ss, cv);
770
      Trace("sygus-unif-sol")
771
          << "  add condition (" << c_counter << "/" << d_conds.size()
772
          << "): " << ce << " -> " << ss.str() << std::endl;
773
    }
774
    // cache the separation class
775
1398
    std::vector<Node> prev_sep_c = d_pt_sep.d_trie.d_rep_to_class[er];
776
    // add new classifier
777
1326
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, c_counter);
778
1326
    c_counter++;
779
    // add to explanation
780
    // c_exp is a conjunction of testers applied to shared selector chains
781
1398
    Node c_exp = d_unif->d_tds->getExplain()->getExplanationForEquality(ce, cv);
782
1326
    exp.push_back(c_exp);
783
    std::map<Node, std::vector<Node>>::iterator itr =
784
1326
        d_pt_sep.d_trie.d_rep_to_class.find(e);
785
    // since e is last in its separation class, if it becomes a representative,
786
    // then it is separated from all values in prev_sep_c
787
2008
    if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
788
    {
789
1364
      Trace("sygus-unif-sol")
790
682
          << "  ...resolves separation conflict with all" << std::endl;
791
682
      needs_sep_resolve = false;
792
682
      continue;
793
    }
794
644
    itr = d_pt_sep.d_trie.d_rep_to_class.find(er);
795
    // since er is first in its separation class, it remains a representative
796
644
    Assert(itr != d_pt_sep.d_trie.d_rep_to_class.end());
797
    // is e still in the separation class of er?
798
1932
    if (std::find(itr->second.begin(), itr->second.end(), e)
799
1932
        != itr->second.end())
800
    {
801
1144
      Trace("sygus-unif-sol")
802
572
          << "  ...does not resolve separation conflict with current"
803
572
          << std::endl;
804
      // the condition does not separate e and er
805
      // this violates the invariant that the i^th conditional enumerator
806
      // resolves the i^th separation conflict
807
572
      exp_conflict = true;
808
572
      SygusTypeInfo& ti = d_unif->d_tds->getTypeInfo(ce.getType());
809
      // The reasoning below is only necessary if we use symbolic constructors.
810
572
      if (!ti.hasSubtermSymbolicCons())
811
      {
812
572
        break;
813
      }
814
      // Since the explanation of the condition (c_exp above) does not account
815
      // for builtin subterms, we additionally require that the valuation of
816
      // the condition is indeed different on the two points.
817
      // For example, say ce has model value equal to the SyGuS datatype term:
818
      //   C_leq_xy( 0, 1 )
819
      // where C_leq_xy is a SyGuS datatype constructor taking two integer
820
      // constants c_x and c_y, and whose builtin version is:
821
      //   (0*x + 1*y >= 0)
822
      // Then, c_exp above is:
823
      //   is-C_leq_xy( ce )
824
      // which is added to our explanation of the conflict, which does not
825
      // account for the values of the arguments of C_leq_xy.
826
      // Now, say that we are in a separation conflict due to f(1,2) and f(2,3)
827
      // being assigned different values; the value of ce does not separate
828
      // these two terms since:
829
      //   (y>=0) { x -> 1, y -> 2 } = (y>=0) { x -> 2, y -> 3 } = true
830
      // The code below adds a constraint that states that the above values are
831
      // the same, which is part of the reason for the conflict. In the above
832
      // example, we generate:
833
      //   (DT_SYGUS_EVAL ce 1 2) == (DT_SYGUS_EVAL ce 2 3) { ce -> M(ce) }
834
      // which unfolds via the SygusEvalUnfold utility to:
835
      //   ( (c_x ce)*1 + (c_y ce)*2 >= 0 ) == ( (c_x ce)*2 + (c_y ce)*3 >= 0 )
836
      // where c_x and c_y are the selectors of the subfields of C_leq_xy.
837
      Trace("sygus-unif-sol-sym")
838
          << "Explain symbolic separation conflict" << std::endl;
839
      std::map<Node, std::vector<Node>>::iterator ith;
840
      Node ceApp[2];
841
      SygusEvalUnfold* eunf = d_unif->d_tds->getEvalUnfold();
842
      std::map<Node, Node> vtm;
843
      vtm[ce] = cv;
844
      Trace("sygus-unif-sol-sym")
845
          << "Model value for " << ce << " is " << cv << std::endl;
846
      for (unsigned r = 0; r < 2; r++)
847
      {
848
        std::vector<Node> cechildren;
849
        cechildren.push_back(ce);
850
        Node ecurr = r == 0 ? e : er;
851
        ith = d_unif->d_hd_to_pt.find(ecurr);
852
        AlwaysAssert(ith != d_unif->d_hd_to_pt.end());
853
        cechildren.insert(
854
            cechildren.end(), ith->second.begin(), ith->second.end());
855
        Node cea = nm->mkNode(DT_SYGUS_EVAL, cechildren);
856
        Trace("sygus-unif-sol-sym")
857
            << "Sep conflict app #" << r << " : " << cea << std::endl;
858
        std::vector<Node> tmpExp;
859
        cea = eunf->unfold(cea, vtm, tmpExp, true, true);
860
        Trace("sygus-unif-sol-sym") << "Unfolded to : " << cea << std::endl;
861
        ceApp[r] = cea;
862
      }
863
      Node ceAppEq = ceApp[0].eqNode(ceApp[1]);
864
      Trace("sygus-unif-sol-sym")
865
          << "Sep conflict app explanation is : " << ceAppEq << std::endl;
866
      exp.push_back(ceAppEq);
867
      break;
868
    }
869
144
    Trace("sygus-unif-sol")
870
72
        << "  ...resolves separation conflict with current, but not all"
871
72
        << std::endl;
872
    // find the new term to resolve a separation
873
144
    Node new_er = Node::null();
874
    // scan the previous list and find the representative of the class that e is
875
    // now in
876
160
    for (unsigned i = 0, size = prev_sep_c.size(); i < size; i++)
877
    {
878
248
      Node check_er = prev_sep_c[i];
879
160
      if (check_er != er && check_er != e)
880
      {
881
88
        itr = d_pt_sep.d_trie.d_rep_to_class.find(check_er);
882
88
        if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
883
        {
884
216
          if (std::find(itr->second.begin(), itr->second.end(), e)
885
216
              != itr->second.end())
886
          {
887
72
            new_er = check_er;
888
72
            break;
889
          }
890
        }
891
      }
892
    }
893
    // should find exactly one
894
72
    Assert(!new_er.isNull());
895
72
    er = new_er;
896
72
    needs_sep_resolve = true;
897
  }
898
1334
  if (exp_conflict)
899
  {
900
2576
    Node lemma = exp.size() == 1 ? exp[0] : nm->mkNode(AND, exp);
901
1288
    lemma = lemma.negate();
902
1288
    Trace("sygus-unif-sol") << "  ......conflict is " << lemma << std::endl;
903
1288
    lemmas.push_back(lemma);
904
1288
    return Node::null();
905
  }
906
907
46
  Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
908
46
  return extractSol(cons, hd_mv);
909
}
910
911
46
Node SygusUnifRl::DecisionTreeInfo::extractSol(Node cons,
                                               std::map<Node, Node>& hd_mv)
{
  // Produces the solution term for this decision tree from its point
  // separator. cons is the constructor applied to (condition, then, else)
  // when building ITE nodes; hd_mv maps each head term to its model value.
  // When the information-gain heuristic is enabled, the separation trie is
  // first rebuilt by greedily choosing conditions.
  const bool rebuildWithInfoGain = d_unif->usingConditionPoolInfoGain();
  if (rebuildWithInfoGain)
  {
    recomputeSolHeuristically(hd_mv);
  }
  return d_pt_sep.extractSol(cons, hd_mv);
}
921
922
46
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::extractSol(
    Node cons, std::map<Node, Node>& hd_mv)
{
  // Builds the solution bottom-up from the separation trie using an explicit
  // stack (post-order traversal). A trie node is identified by its depth,
  // which is also the index of the condition that splits it, paired with its
  // address.
  //  - A leaf's value is the model value of the head term stored in it.
  //  - An inner node becomes cons(condition, then-value, else-value). The
  //    ITE is omitted when only one branch exists or both branches agree.
  // A null entry in `built` marks an inner node whose children have been
  // scheduled but not yet combined.
  NodeManager* nm = NodeManager::currentNM();
  std::map<IndTriePair, Node> built;
  std::vector<IndTriePair> stack;
  const IndTriePair root(0, &d_trie.d_trie);
  stack.push_back(root);
  while (!stack.empty())
  {
    const IndTriePair cur = stack.back();
    stack.pop_back();
    const unsigned depth = cur.first;
    LazyTrie* node = cur.second;
    std::map<IndTriePair, Node>::iterator found = built.find(cur);
    if (found == built.end())
    {
      if (node->d_children.empty())
      {
        // leaf: take the model value of its representative head term
        Assert(hd_mv.find(node->d_lazy_child) != hd_mv.end());
        built[cur] = hd_mv[node->d_lazy_child];
        Trace("sygus-unif-sol-debug") << "......leaf, build "
                                      << d_dt->d_unif->d_tds->sygusToBuiltin(
                                             built[cur], built[cur].getType())
                                      << "\n";
        continue;
      }
      // first visit of an inner node: revisit it once its children are built
      built[cur] = Node::null();
      stack.push_back(cur);
      for (std::pair<const Node, LazyTrie>& child : node->d_children)
      {
        stack.push_back(IndTriePair(depth + 1, &child.second));
      }
      continue;
    }
    // second visit: every child has a value, combine them
    Assert(found->second.isNull());
    Assert(node->d_children.size() == 1 || node->d_children.size() == 2);
    std::vector<Node> children(4);
    children[0] = cons;
    children[1] = d_dt->d_conds[depth];
    // slot 2 holds the "true" branch, slot 3 the "false" branch
    unsigned lastSlot = 0;
    for (std::pair<const Node, LazyTrie>& child : node->d_children)
    {
      lastSlot = child.first.getConst<bool>() ? 2 : 3;
      const IndTriePair childKey(depth + 1, &child.second);
      Assert(built.find(childKey) != built.end());
      children[lastSlot] = built[childKey];
      Assert(!children[lastSlot].isNull());
    }
    // a condition with a single branch, or with identical branch results,
    // is useless: reuse the child value directly instead of building an ITE
    const bool singleBranch = node->d_children.size() == 1;
    if (singleBranch || children[2] == children[3])
    {
      built[cur] = children[lastSlot];
      Trace("sygus-unif-sol-debug")
          << "......no need for cond "
          << d_dt->d_unif->d_tds->sygusToBuiltin(d_dt->d_conds[depth],
                                                 d_dt->d_conds[depth].getType())
          << ", build "
          << d_dt->d_unif->d_tds->sygusToBuiltin(built[cur],
                                                 built[cur].getType())
          << "\n";
      continue;
    }
    Assert(node->d_children.size() == 2);
    built[cur] = nm->mkNode(APPLY_CONSTRUCTOR, children);
    Trace("sygus-unif-sol-debug")
        << "......build node "
        << d_dt->d_unif->d_tds->sygusToBuiltin(built[cur], built[cur].getType())
        << "\n";
  }
  Assert(built.find(root) != built.end());
  Assert(!built.find(root)->second.isNull());
  return built[root];
}
1002
1003
void SygusUnifRl::DecisionTreeInfo::recomputeSolHeuristically(
1004
    std::map<Node, Node>& hd_mv)
1005
{
1006
  // reset the trie
1007
  d_pt_sep.d_trie.clear();
1008
  // TODO workaround and not really sure this is the last condition, since I put
1009
  // a set here. Maybe make d_cond_mvs into a vector
1010
  Node backup_last_cond = d_conds.back();
1011
  d_conds.clear();
1012
  for (const Node& e : d_hds)
1013
  {
1014
    d_pt_sep.d_trie.add(e, &d_pt_sep, 0);
1015
  }
1016
  // init vector of conds
1017
  std::vector<Node> conds;
1018
  conds.insert(conds.end(), d_cond_mvs.begin(), d_cond_mvs.end());
1019
1020
  // recursively build trie by picking best condition for respective points
1021
  buildDtInfoGain(d_hds, conds, hd_mv, 1);
1022
  // if no condition was added (i.e. points are already classified at root
1023
  // level), use last condition as candidate
1024
  if (d_conds.empty())
1025
  {
1026
    Trace("sygus-unif-dt") << "......using last condition "
1027
                           << d_unif->d_tds->sygusToBuiltin(
1028
                                  backup_last_cond, backup_last_cond.getType())
1029
                           << " as candidate\n";
1030
    d_conds.push_back(backup_last_cond);
1031
    d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
1032
  }
1033
}
1034
1035
void SygusUnifRl::DecisionTreeInfo::buildDtInfoGain(std::vector<Node>& hds,
1036
                                                    std::vector<Node> conds,
1037
                                                    std::map<Node, Node>& hd_mv,
1038
                                                    int ind)
1039
{
1040
  // test if fully classified
1041
  if (hds.size() < 2)
1042
  {
1043
    indent("sygus-unif-dt", ind);
1044
    Trace("sygus-unif-dt") << "..set fully classified: "
1045
                           << (hds.empty() ? "empty" : "unary") << "\n";
1046
    return;
1047
  }
1048
  Node v1 = hd_mv[hds[0]];
1049
  unsigned i = 1, size = hds.size();
1050
  for (; i < size; ++i)
1051
  {
1052
    if (hd_mv[hds[i]] != v1)
1053
    {
1054
      break;
1055
    }
1056
  }
1057
  if (i == size)
1058
  {
1059
    indent("sygus-unif-dt", ind);
1060
    Trace("sygus-unif-dt") << "..set fully classified: " << hds.size() << " "
1061
                           << (d_unif->d_tds->sygusToBuiltin(v1, v1.getType())
1062
                                       == d_true
1063
                                   ? "good"
1064
                                   : "bad")
1065
                           << " points\n";
1066
    return;
1067
  }
1068
  // pick condition to further classify
1069
  double maxgain = -1;
1070
  unsigned picked_cond = 0;
1071
  std::vector<std::pair<std::vector<Node>, std::vector<Node>>> splits;
1072
  double current_set_entropy = getEntropy(hds, hd_mv, ind);
1073
  for (unsigned j = 0, conds_size = conds.size(); j < conds_size; ++j)
1074
  {
1075
    std::pair<std::vector<Node>, std::vector<Node>> split =
1076
        evaluateCond(hds, conds[j]);
1077
    splits.push_back(split);
1078
    Assert(hds.size() == split.first.size() + split.second.size());
1079
    double gain =
1080
        current_set_entropy
1081
        - (split.first.size() * getEntropy(split.first, hd_mv, ind)
1082
           + split.second.size() * getEntropy(split.second, hd_mv, ind))
1083
              / hds.size();
1084
    indent("sygus-unif-dt-debug", ind);
1085
    Trace("sygus-unif-dt-debug")
1086
        << "..gain of "
1087
        << d_unif->d_tds->sygusToBuiltin(conds[j], conds[j].getType()) << " is "
1088
        << gain << "\n";
1089
    if (gain > maxgain)
1090
    {
1091
      maxgain = gain;
1092
      picked_cond = j;
1093
    }
1094
  }
1095
  // add picked condition
1096
  indent("sygus-unif-dt", ind);
1097
  Trace("sygus-unif-dt") << "..picked condition "
1098
                         << d_unif->d_tds->sygusToBuiltin(
1099
                                conds[picked_cond],
1100
                                conds[picked_cond].getType())
1101
                         << "\n";
1102
  d_conds.push_back(conds[picked_cond]);
1103
  conds.erase(conds.begin() + picked_cond);
1104
  d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
1105
  // recurse
1106
  buildDtInfoGain(splits[picked_cond].first, conds, hd_mv, ind + 1);
1107
  buildDtInfoGain(splits[picked_cond].second, conds, hd_mv, ind + 1);
1108
}
1109
1110
std::pair<std::vector<Node>, std::vector<Node>>
1111
SygusUnifRl::DecisionTreeInfo::evaluateCond(std::vector<Node>& pts, Node cond)
1112
{
1113
  std::vector<Node> good, bad;
1114
  for (const Node& pt : pts)
1115
  {
1116
    if (d_pt_sep.computeCond(cond, pt) == d_true)
1117
    {
1118
      good.push_back(pt);
1119
      continue;
1120
    }
1121
    Assert(d_pt_sep.computeCond(cond, pt) == d_false);
1122
    bad.push_back(pt);
1123
  }
1124
  return std::pair<std::vector<Node>, std::vector<Node>>(good, bad);
1125
}
1126
1127
double SygusUnifRl::DecisionTreeInfo::getEntropy(const std::vector<Node>& hds,
1128
                                                 std::map<Node, Node>& hd_mv,
1129
                                                 int ind)
1130
{
1131
  double p = 0, n = 0;
1132
  TermDbSygus* tds = d_unif->d_tds;
1133
  // get number of points evaluated positively and negatively with feature
1134
  for (const Node& e : hds)
1135
  {
1136
    if (tds->sygusToBuiltin(hd_mv[e]) == d_true)
1137
    {
1138
      p++;
1139
      continue;
1140
    }
1141
    Assert(tds->sygusToBuiltin(hd_mv[e]) == d_false);
1142
    n++;
1143
  }
1144
  // compute entropy
1145
  return p == 0 || n == 0 ? 0
1146
                          : ((-p / (p + n)) * log2(p / (p + n)))
1147
                                - ((n / (p + n)) * log2(n / (p + n)));
1148
}
1149
1150
17
void SygusUnifRl::DecisionTreeInfo::PointSeparator::initialize(
1151
    DecisionTreeInfo* dt)
1152
{
1153
17
  d_dt = dt;
1154
17
}
1155
1156
8308
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::evaluate(Node n,
                                                             unsigned index)
{
  // Value of the index-th condition of the decision tree on the point whose
  // head term is n (delegates to the memoized computeCond).
  Assert(index < d_dt->d_conds.size());
  return computeCond(d_dt->d_conds[index], n);
}
1164
1165
8308
Node SygusUnifRl::DecisionTreeInfo::PointSeparator::computeCond(Node cond,
                                                                Node hd)
{
  // Evaluates the sygus condition cond on the point associated with the head
  // term hd, and memoizes the result per (cond, hd) pair in d_eval_cond_hd.
  // If the decision tree has a template, the raw result is substituted into
  // it and rewritten. The final result is asserted to be a constant.
  const std::pair<Node, Node> key(cond, hd);
  std::map<std::pair<Node, Node>, Node>::iterator cached =
      d_eval_cond_hd.find(key);
  if (cached != d_eval_cond_hd.end())
  {
    return cached->second;
  }
  TermDbSygus* tds = d_dt->d_unif->d_tds;
  TypeNode tn = cond.getType();
  Node builtinCond = tds->sygusToBuiltin(cond, tn);
  // retrieve the evaluation point of hd
  Assert(d_dt->d_unif->d_hd_to_pt.find(hd) != d_dt->d_unif->d_hd_to_pt.end());
  std::vector<Node> pt = d_dt->d_unif->d_hd_to_pt[hd];
  if (Trace.isOn("sygus-unif-rl-sep"))
  {
    Trace("sygus-unif-rl-sep")
        << "Evaluate cond " << builtinCond << " on pt " << hd << " ( ";
    for (size_t k = 0, npt = pt.size(); k < npt; ++k)
    {
      Trace("sygus-unif-rl-sep") << pt[k] << " ";
    }
    Trace("sygus-unif-rl-sep") << ")\n";
  }
  Node res = tds->evaluateBuiltin(tn, builtinCond, pt);
  Trace("sygus-unif-rl-sep") << "...got res = " << res << "\n";
  // a templated condition plugs the raw result into its template
  Node templ = d_dt->d_template.first;
  if (!templ.isNull())
  {
    TNode templVar = d_dt->d_template.second;
    res = Rewriter::rewrite(templ.substitute(templVar, res));
    Trace("sygus-unif-rl-sep")
        << "...after template res = " << res << std::endl;
  }
  Assert(res.isConst());
  d_eval_cond_hd[key] = res;
  return res;
}
1207
1208
}  // namespace quantifiers
1209
}  // namespace theory
1210
26703
}  // namespace CVC4