GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/sygus_inst.cpp Lines: 211 254 83.1 %
Date: 2021-03-22 Branches: 379 1020 37.2 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file sygus_inst.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Mathias Preiner, Aina Niemetz, Andrew Reynolds
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief SyGuS instantiator class.
13
 **/
14
15
#include "theory/quantifiers/sygus_inst.h"
16
17
#include <sstream>
18
#include <unordered_set>
19
20
#include "expr/node_algorithm.h"
21
#include "options/quantifiers_options.h"
22
#include "theory/bv/theory_bv_utils.h"
23
#include "theory/datatypes/sygus_datatype_utils.h"
24
#include "theory/quantifiers/first_order_model.h"
25
#include "theory/quantifiers/sygus/sygus_enumerator.h"
26
#include "theory/quantifiers/sygus/sygus_grammar_cons.h"
27
#include "theory/quantifiers/sygus/synth_engine.h"
28
#include "theory/quantifiers/term_util.h"
29
#include "theory/quantifiers_engine.h"
30
#include "theory/rewriter.h"
31
32
namespace CVC4 {
33
namespace theory {
34
namespace quantifiers {
35
36
namespace {
37
38
/**
39
 * Collect maximal ground terms with type tn in node n.
40
 *
41
 * @param n: Node to traverse.
42
 * @param tn: Collects only terms with type tn.
43
 * @param terms: Collected terms.
44
 * @param cache: Caches visited nodes.
45
 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
46
 */
47
41
void getMaxGroundTerms(TNode n,
48
                       TypeNode tn,
49
                       std::unordered_set<Node, NodeHashFunction>& terms,
50
                       std::unordered_set<TNode, TNodeHashFunction>& cache,
51
                       bool skip_quant = false)
52
{
53
205
  if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX
54
41
      && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
55
  {
56
41
    return;
57
  }
58
59
  Trace("sygus-inst-term") << "Find maximal terms with type " << tn
60
                           << " in: " << n << std::endl;
61
62
  Node cur;
63
  std::vector<TNode> visit;
64
65
  visit.push_back(n);
66
  do
67
  {
68
    cur = visit.back();
69
    visit.pop_back();
70
71
    if (cache.find(cur) != cache.end())
72
    {
73
      continue;
74
    }
75
    cache.insert(cur);
76
77
    if (expr::hasBoundVar(cur) || cur.getType() != tn)
78
    {
79
      if (!skip_quant || cur.getKind() != kind::FORALL)
80
      {
81
        visit.insert(visit.end(), cur.begin(), cur.end());
82
      }
83
    }
84
    else
85
    {
86
      terms.insert(cur);
87
      Trace("sygus-inst-term") << "  found: " << cur << std::endl;
88
    }
89
  } while (!visit.empty());
90
}
91
92
/*
93
 * Collect minimal ground terms with type tn in node n.
94
 *
95
 * @param n: Node to traverse.
96
 * @param tn: Collects only terms with type tn.
97
 * @param terms: Collected terms.
98
 * @param cache: Caches visited nodes and flags indicating whether a minimal
99
 *               term was already found in a subterm.
100
 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
101
 */
102
41
void getMinGroundTerms(
103
    TNode n,
104
    TypeNode tn,
105
    std::unordered_set<Node, NodeHashFunction>& terms,
106
    std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>& cache,
107
    bool skip_quant = false)
108
{
109
82
  if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN
110
41
      && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
111
  {
112
    return;
113
  }
114
115
82
  Trace("sygus-inst-term") << "Find minimal terms with type " << tn
116
41
                           << " in: " << n << std::endl;
117
118
82
  Node cur;
119
82
  std::vector<TNode> visit;
120
121
41
  visit.push_back(n);
122
80030
  do
123
  {
124
80071
    cur = visit.back();
125
80071
    visit.pop_back();
126
127
80071
    auto it = cache.find(cur);
128
80071
    if (it == cache.end())
129
    {
130
10572
      cache.emplace(cur, std::make_pair(false, false));
131
5286
      if (!skip_quant || cur.getKind() != kind::FORALL)
132
      {
133
5286
        visit.push_back(cur);
134
5286
        visit.insert(visit.end(), cur.begin(), cur.end());
135
      }
136
    }
137
    /* up-traversal */
138
74785
    else if (!it->second.first)
139
    {
140
5286
      bool found_min_term = false;
141
142
      /* Check if we found a minimal term in one of the children. */
143
24967
      for (const Node& c : cur)
144
      {
145
23386
        found_min_term |= cache[c].second;
146
23386
        if (found_min_term) break;
147
      }
148
149
      /* If we haven't found a minimal term yet, add this term if it has the
150
       * right type. */
151
5286
      if (cur.getType() == tn && !expr::hasBoundVar(cur) && !found_min_term)
152
      {
153
133
        terms.insert(cur);
154
133
        found_min_term = true;
155
133
        Trace("sygus-inst-term") << "  found: " << cur << std::endl;
156
      }
157
158
5286
      it->second.first = true;
159
5286
      it->second.second = found_min_term;
160
    }
161
80071
  } while (!visit.empty());
162
}
163
164
/*
165
 * Add special values for a given type.
166
 *
167
 * @param tn: The type node.
168
 * @param extra_cons: A map of TypeNode to constants, which are added in
169
 *                    addition to the default grammar.
170
 */
171
121
void addSpecialValues(
172
    const TypeNode& tn,
173
    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& extra_cons)
174
{
175
121
  if (tn.isBitVector())
176
  {
177
6
    uint32_t size = tn.getBitVectorSize();
178
6
    extra_cons[tn].insert(bv::utils::mkOnes(size));
179
6
    extra_cons[tn].insert(bv::utils::mkMinSigned(size));
180
6
    extra_cons[tn].insert(bv::utils::mkMaxSigned(size));
181
  }
182
121
}
183
184
}  // namespace
185
186
19
/* Constructs the SyGuS instantiator. The CE-lemma bookkeeping, the global
 * ground-term cache and the set of notified assertions are all constructed
 * with the user context of 'qs'. */
SygusInst::SygusInst(QuantifiersEngine* qe,
                     QuantifiersState& qs,
                     QuantifiersInferenceManager& qim,
                     QuantifiersRegistry& qr)
    : QuantifiersModule(qs, qim, qr, qe),
      d_ce_lemma_added(qs.getUserContext()),
      d_global_terms(qs.getUserContext()),
      d_notified_assertions(qs.getUserContext())
{
  /* Nothing else to set up. */
}
196
197
399
/* The instantiator only runs at last-call effort (or above). */
bool SygusInst::needsCheck(Theory::Effort e)
{
  const bool atLastCall = e >= Theory::EFFORT_LAST_CALL;
  return atLastCall;
}
201
202
85
/* A model is requested at standard quantifiers effort, whatever 'e' is. */
QuantifiersModule::QEffort SygusInst::needsModel(Theory::Effort e)
{
  const QEffort required = QEFFORT_STANDARD;
  return required;
}
206
207
85
/* Recomputes the active/inactive quantifier sets for this round.
 *
 * Every asserted quantifier that the model reports as active starts out
 * active. It is deactivated (in the model and here) when its counterexample
 * literal has a SAT value of false that is not a decision. */
void SygusInst::reset_round(Theory::Effort e)
{
  d_active_quant.clear();
  d_inactive_quant.clear();

  FirstOrderModel* fm = d_quantEngine->getModel();
  const uint32_t num_asserted = fm->getNumAssertedQuantifiers();

  for (uint32_t idx = 0; idx < num_asserted; ++idx)
  {
    Node quant = fm->getAssertedQuantifier(idx);
    if (!fm->isQuantifierActive(quant))
    {
      continue;
    }

    d_active_quant.insert(quant);
    Node ce_lit = getCeLiteral(quant);

    // Short-circuiting guarantees 'sat_value' is only read once assigned, and
    // isDecision is only queried for a literal that is false.
    bool sat_value;
    const bool assigned =
        d_qstate.getValuation().hasSatValue(ce_lit, sat_value);
    if (!assigned || sat_value || d_qstate.getValuation().isDecision(ce_lit))
    {
      continue;
    }

    fm->setQuantifierActive(quant, false);
    d_active_quant.erase(quant);
    d_inactive_quant.insert(quant);
    Trace("sygus-inst") << "Set inactive: " << quant << std::endl;
  }
}
241
242
183
/* Instantiation round (standard quantifiers effort only).
 *
 * For each active quantifier q, the model values of its datatype
 * instantiation constants are converted to builtin terms. Those terms
 * instantiate q. In addition, one evaluation-unfolding lemma per variable is
 * built, of the form (explanation => eval_i = term_i), or just
 * (eval_i = term_i) when the explanation is empty. The sygus-inst mode
 * decides whether the instantiation, the lemmas, or both are sent. */
void SygusInst::check(Theory::Effort e, QEffort quant_e)
{
  Trace("sygus-inst") << "Check " << e << ", " << quant_e << std::endl;

  if (quant_e != QEFFORT_STANDARD)
  {
    return;
  }

  FirstOrderModel* fm = d_quantEngine->getModel();
  Instantiate* instantiate = d_quantEngine->getInstantiate();
  TermDbSygus* tds = d_quantEngine->getTermDatabaseSygus();
  SygusExplain explainer(tds);
  NodeManager* nm = NodeManager::currentNM();
  const options::SygusInstMode mode = options::sygusInstMode();

  for (const Node& q : d_active_quant)
  {
    const std::vector<Node>& ics = d_inst_constants.at(q);
    const std::vector<Node>& evals = d_var_eval.at(q);
    Assert(ics.size() == evals.size());
    Assert(ics.size() == q[0].getNumChildren());

    const size_t nvars = q[0].getNumChildren();
    std::vector<Node> terms;
    std::vector<Node> unfold_lemmas;
    terms.reserve(nvars);
    unfold_lemmas.reserve(nvars);

    for (size_t idx = 0; idx < nvars; ++idx)
    {
      const Node& ic = ics[idx];
      Node dt_eval = evals[idx];
      Node value = fm->getValue(ic);
      Node builtin = datatypes::utils::sygusToBuiltin(value);
      terms.push_back(builtin);

      std::vector<Node> exp;
      explainer.getExplanationForEquality(ic, value, exp);
      Node conclusion = dt_eval.eqNode(builtin);
      if (exp.empty())
      {
        unfold_lemmas.push_back(conclusion);
      }
      else
      {
        Node premise = exp.size() == 1 ? exp[0] : nm->mkNode(kind::AND, exp);
        unfold_lemmas.push_back(
            nm->mkNode(kind::IMPLIES, premise, conclusion));
      }
    }

    switch (mode)
    {
      case options::SygusInstMode::PRIORITY_INST:
        // Unfolding lemmas are only a fallback for a failed instantiation.
        if (!instantiate->addInstantiation(
                q, terms, InferenceId::QUANTIFIERS_INST_SYQI))
        {
          sendEvalUnfoldLemmas(unfold_lemmas);
        }
        break;
      case options::SygusInstMode::PRIORITY_EVAL:
        // Instantiation is only a fallback when no unfolding lemma was added.
        if (!sendEvalUnfoldLemmas(unfold_lemmas))
        {
          instantiate->addInstantiation(
              q, terms, InferenceId::QUANTIFIERS_INST_SYQI);
        }
        break;
      default:
        Assert(mode == options::SygusInstMode::INTERLEAVE);
        instantiate->addInstantiation(
            q, terms, InferenceId::QUANTIFIERS_INST_SYQI);
        sendEvalUnfoldLemmas(unfold_lemmas);
        break;
    }
  }
}
309
310
48
bool SygusInst::sendEvalUnfoldLemmas(const std::vector<Node>& lemmas)
311
{
312
48
  bool added_lemma = false;
313
469
  for (const Node& lem : lemmas)
314
  {
315
421
    Trace("sygus-inst") << "Evaluation unfolding: " << lem << std::endl;
316
421
    added_lemma |=
317
842
        d_qim.addPendingLemma(lem, InferenceId::QUANTIFIERS_SYQI_EVAL_UNFOLD);
318
  }
319
48
  return added_lemma;
320
}
321
322
3
/* 'q' is reported complete iff reset_round placed it in d_inactive_quant. */
bool SygusInst::checkCompleteFor(Node q)
{
  const auto pos = d_inactive_quant.find(q);
  return pos != d_inactive_quant.end();
}
326
327
36
/* Registers quantifier 'q'.
 *
 * Depending on the sygus-inst scope option, ground terms of the bound
 * variables' types are collected from 'q' itself (IN/BOTH) and/or from the
 * notified assertions (OUT/BOTH), and added as extra constructors. A default
 * sygus grammar (plus special values) is then built for each bound variable,
 * and the counterexample lemma for 'q' is registered from these grammars. */
void SygusInst::registerQuantifier(Node q)
{
  Assert(d_ce_lemmas.find(q) == d_ce_lemmas.end());

  Trace("sygus-inst") << "Register " << q << std::endl;

  std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> extra_cons;
  std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> exclude_cons;
  std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons;
  std::unordered_set<Node, NodeHashFunction> term_irrelevant;

  /* Distinct types of the bound variables of 'q', in order of first
   * occurrence. The collected ground terms only depend on the type, so they
   * are collected and added to 'extra_cons' once per type instead of once per
   * variable. */
  std::vector<TypeNode> var_types;
  std::unordered_set<TypeNode, TypeNodeHashFunction> seen_types;
  for (const Node& var : q[0])
  {
    TypeNode tn = var.getType();
    if (seen_types.insert(tn).second)
    {
      var_types.push_back(tn);
    }
  }

  /* Collect relevant local ground terms for each variable type. */
  if (options::sygusInstScope() == options::SygusInstScope::IN
      || options::sygusInstScope() == options::SygusInstScope::BOTH)
  {
    for (const TypeNode& tn : var_types)
    {
      std::unordered_set<Node, NodeHashFunction> terms;
      std::unordered_set<TNode, TNodeHashFunction> cache_max;
      std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
          cache_min;

      getMinGroundTerms(q, tn, terms, cache_min);
      getMaxGroundTerms(q, tn, terms, cache_max);

      /* Add relevant ground terms to grammar. */
      for (const Node& t : terms)
      {
        TypeNode ttn = t.getType();
        extra_cons[ttn].insert(t);
        Trace("sygus-inst") << "Adding (local) extra cons: " << t << std::endl;
      }
    }
  }

  /* Collect relevant global ground terms for each variable type. */
  if (options::sygusInstScope() == options::SygusInstScope::OUT
      || options::sygusInstScope() == options::SygusInstScope::BOTH)
  {
    for (const TypeNode& tn : var_types)
    {
      /* Collect relevant ground terms for type tn once; d_global_terms keeps
       * the result for later quantifiers. */
      if (d_global_terms.find(tn) == d_global_terms.end())
      {
        std::unordered_set<Node, NodeHashFunction> terms;
        std::unordered_set<TNode, TNodeHashFunction> cache_max;
        std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
            cache_min;

        for (const Node& a : d_notified_assertions)
        {
          getMinGroundTerms(a, tn, terms, cache_min, true);
          getMaxGroundTerms(a, tn, terms, cache_max, true);
        }
        d_global_terms.insert(tn, terms);
      }

      /* Add relevant ground terms to grammar. */
      auto it = d_global_terms.find(tn);
      if (it != d_global_terms.end())
      {
        for (const auto& t : (*it).second)
        {
          TypeNode ttn = t.getType();
          extra_cons[ttn].insert(t);
          Trace("sygus-inst")
              << "Adding (global) extra cons: " << t << std::endl;
        }
      }
    }
  }

  /* Construct grammar for each bound variable of 'q'. */
  Trace("sygus-inst") << "Process variables of " << q << std::endl;
  std::vector<TypeNode> types;
  for (const Node& var : q[0])
  {
    addSpecialValues(var.getType(), extra_cons);
    TypeNode tn = CegGrammarConstructor::mkSygusDefaultType(var.getType(),
                                                            Node(),
                                                            var.toString(),
                                                            extra_cons,
                                                            exclude_cons,
                                                            include_cons,
                                                            term_irrelevant);
    types.push_back(tn);

    Trace("sygus-inst") << "Construct (default) datatype for " << var
                        << std::endl
                        << tn << std::endl;
  }

  registerCeLemma(q, types);
}
435
436
/* Construct grammars for all bound variables of quantifier 'q'. Currently,
437
 * we use the default grammar of the variable's type.
438
 */
439
36
void SygusInst::preRegisterQuantifier(Node q)
440
{
441
36
  Trace("sygus-inst") << "preRegister " << q << std::endl;
442
36
  addCeLemma(q);
443
36
}
444
445
19
void SygusInst::ppNotifyAssertions(const std::vector<Node>& assertions)
446
{
447
66
  for (const Node& a : assertions)
448
  {
449
47
    d_notified_assertions.insert(a);
450
  }
451
19
}
452
453
/*****************************************************************************/
454
/* private methods                                                           */
455
/*****************************************************************************/
456
457
125
/* Returns the counterexample literal of 'q'. It is created on first request
 * from a fresh Boolean skolem, passed through ensureLiteral, and cached in
 * d_ce_lits. */
Node SygusInst::getCeLiteral(Node q)
{
  auto cached = d_ce_lits.find(q);
  if (cached == d_ce_lits.end())
  {
    NodeManager* nm = NodeManager::currentNM();
    Node skolem = nm->mkSkolem("CeLiteral", nm->booleanType());
    Node fresh = d_qstate.getValuation().ensureLiteral(skolem);
    d_ce_lits[q] = fresh;
    return fresh;
  }
  return cached->second;
}
471
472
36
/* Builds (but does not send) the counterexample lemma for 'q'.
 *
 * Each bound variable x_i gets a fresh instantiation constant ic_i of sygus
 * datatype types[i]. The constant is registered as an enumerator and wrapped
 * in DT_SYGUS_EVAL(ic_i, <sygus var list>), which replaces x_i in the body.
 * A decision strategy is registered on the counterexample literal of 'q', and
 * the lemma (lit => ~P[x_i/eval_i]) is rewritten and stored in d_ce_lemmas. */
void SygusInst::registerCeLemma(Node q, std::vector<TypeNode>& types)
{
  Assert(q[0].getNumChildren() == types.size());
  Assert(d_ce_lemmas.find(q) == d_ce_lemmas.end());
  Assert(d_inst_constants.find(q) == d_inst_constants.end());
  Assert(d_var_eval.find(q) == d_var_eval.end());

  Trace("sygus-inst") << "Register CE Lemma for " << q << std::endl;

  NodeManager* nm = NodeManager::currentNM();
  TermDbSygus* tds = d_quantEngine->getTermDatabaseSygus();

  const size_t num_vars = types.size();
  std::vector<Node> ics;
  std::vector<Node> evals;
  ics.reserve(num_vars);
  evals.reserve(num_vars);
  for (size_t idx = 0; idx < num_vars; ++idx)
  {
    const TypeNode& dt_type = types[idx];
    TNode bound_var = q[0][idx];

    /* Fresh instantiation constant, tagged with its quantifier. */
    Node ic = nm->mkInstConstant(dt_type);
    ic.setAttribute(InstConstantAttribute(), q);
    Trace("sygus-inst") << "Create " << ic << " for " << bound_var
                        << std::endl;

    tds->registerEnumerator(ic, ic, nullptr, ROLE_ENUM_MULTI_SOLUTION);

    std::vector<Node> eval_args;
    eval_args.push_back(ic);
    Node sygus_vars = dt_type.getDType().getSygusVarList();
    if (!sygus_vars.isNull())
    {
      eval_args.insert(eval_args.end(), sygus_vars.begin(), sygus_vars.end());
    }

    ics.push_back(ic);
    evals.push_back(nm->mkNode(kind::DT_SYGUS_EVAL, eval_args));
  }

  d_inst_constants.emplace(q, ics);
  d_var_eval.emplace(q, evals);

  Node ce_lit = getCeLiteral(q);
  d_qim.addPendingPhaseRequirement(ce_lit, true);

  /* Decision strategy on the counterexample literal, meant to have it decided
   * on first. Note: it is constructed with the SAT context. */
  Assert(d_dstrat.find(q) == d_dstrat.end());
  d_dstrat[q].reset(new DecisionStrategySingleton(
      "CeLiteral", ce_lit, d_qstate.getSatContext(), d_qstate.getValuation()));
  d_quantEngine->getDecisionManager()->registerStrategy(
      DecisionManager::STRAT_QUANT_CEGQI_FEASIBLE, d_dstrat[q].get());

  /* Counterexample lemma (lit => ~P[x_i/eval_i]), written as a disjunction. */
  Node ce_body =
      q[1].substitute(q[0].begin(), q[0].end(), evals.begin(), evals.end());
  Node ce_lemma = Rewriter::rewrite(
      nm->mkNode(kind::OR, ce_lit.negate(), ce_body.negate()));

  d_ce_lemmas.emplace(q, ce_lemma);
  Trace("sygus-inst") << "Register CE Lemma: " << ce_lemma << std::endl;
}
541
542
36
/* Sends the counterexample lemma registered for 'q' (see registerCeLemma) as
 * a pending lemma, at most once per user context (d_ce_lemma_added is a
 * user-context set). */
void SygusInst::addCeLemma(Node q)
{
  /* Look the lemma up with find(): operator[] would default-insert, and then
   * send, a null lemma if 'q' was never registered. */
  auto it = d_ce_lemmas.find(q);
  Assert(it != d_ce_lemmas.end());
  if (it == d_ce_lemmas.end()) return;

  /* Already added in previous contexts. */
  if (d_ce_lemma_added.find(q) != d_ce_lemma_added.end()) return;

  Node lem = it->second;
  d_qim.addPendingLemma(lem, InferenceId::UNKNOWN);
  d_ce_lemma_added.insert(q);
  Trace("sygus-inst") << "Add CE Lemma: " << lem << std::endl;
}
554
555
}  // namespace quantifiers
556
}  // namespace theory
557
26676
}  // namespace CVC4