GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/term_tuple_enumerator.cpp Lines: 155 205 75.6 %
Date: 2021-08-16 Branches: 236 624 37.8 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   MikolasJanota, Andrew Reynolds
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of an enumeration of tuples of terms for the purpose of
14
 * quantifier instantiation.
15
 */
16
#include "theory/quantifiers/term_tuple_enumerator.h"
17
18
#include <algorithm>
19
#include <functional>
20
#include <iterator>
21
#include <map>
22
#include <vector>
23
24
#include "base/map_util.h"
25
#include "base/output.h"
26
#include "options/quantifiers_options.h"
27
#include "smt/smt_statistics_registry.h"
28
#include "theory/quantifiers/index_trie.h"
29
#include "theory/quantifiers/quant_module.h"
30
#include "theory/quantifiers/relevant_domain.h"
31
#include "theory/quantifiers/term_pools.h"
32
#include "theory/quantifiers/term_registry.h"
33
#include "theory/quantifiers/term_util.h"
34
#include "util/statistics_stats.h"
35
36
namespace cvc5 {
37
38
template <typename T>
39
3795
static Cvc5ostream& operator<<(Cvc5ostream& out, const std::vector<T>& v)
40
{
41
3795
  out << "[ ";
42
3795
  std::copy(v.begin(), v.end(), std::ostream_iterator<T>(out, " "));
43
3795
  return out << "]";
44
}
45
46
/** Tracing purposes, printing a masked vector of indices. */
47
static void traceMaskedVector(const char* trace,
48
                              const char* name,
49
                              const std::vector<bool>& mask,
50
                              const std::vector<size_t>& values)
51
{
52
  Assert(mask.size() == values.size());
53
  Trace(trace) << name << " [ ";
54
  for (size_t variableIx = 0; variableIx < mask.size(); variableIx++)
55
  {
56
    if (mask[variableIx])
57
    {
58
      Trace(trace) << values[variableIx] << " ";
59
    }
60
    else
61
    {
62
      Trace(trace) << "_ ";
63
    }
64
  }
65
  Trace(trace) << "]" << std::endl;
66
}
67
68
namespace theory {
69
namespace quantifiers {
70
/**
71
 * Base class for enumerators of tuples of terms for the purpose of
72
 * quantification instantiation. The tuples are represented as tuples of
73
 * indices of  terms, where the tuple has as many elements as there are
74
 * quantified variables in the considered quantifier.
75
 *
76
 * Like so, we see a tuple as a number, where the digits may have different
77
 * ranges. The most significant digits are stored first.
78
 *
79
 * Tuples are enumerated  in a lexicographic order in stages. There are 2
80
 * possible strategies, either  all tuples in a given stage have the same sum of
81
 * digits, or, the maximum  over these digits is the same.
82
 * */
83
class TermTupleEnumeratorBase : public TermTupleEnumeratorInterface
84
{
85
 public:
86
  /** Initialize the class with the quantifier to be instantiated. */
87
1284
  TermTupleEnumeratorBase(Node quantifier, const TermTupleEnumeratorEnv* env)
88
1284
      : d_quantifier(quantifier),
89
2568
        d_variableCount(d_quantifier[0].getNumChildren()),
90
        d_env(env),
91
        d_stepCounter(0),
92
        d_disabledCombinations(
93
3852
            true)  // do not record combinations with no blanks
94
95
  {
96
1284
    d_changePrefix = d_variableCount;
97
1284
  }
98
99
1284
  virtual ~TermTupleEnumeratorBase() = default;
100
101
  // implementation of the TermTupleEnumeratorInterface
102
  virtual void init() override;
103
  virtual bool hasNext() override;
104
  virtual void next(/*out*/ std::vector<Node>& terms) override;
105
  virtual void failureReason(const std::vector<bool>& mask) override;
106
  // end of implementation of the TermTupleEnumeratorInterface
107
108
 protected:
109
  /** the quantifier whose variables are being instantiated */
110
  const Node d_quantifier;
111
  /** number of variables in the quantifier */
112
  const size_t d_variableCount;
113
  /** env of structures with a longer lifespan */
114
  const TermTupleEnumeratorEnv* const d_env;
115
  /** type for each variable */
116
  std::vector<TypeNode> d_typeCache;
117
  /** number of candidate terms for each variable */
118
  std::vector<size_t> d_termsSizes;
119
  /** tuple of indices of the current terms */
120
  std::vector<size_t> d_termIndex;
121
  /** total number of steps of the enumerator */
122
  uint32_t d_stepCounter;
123
124
  /** a data structure storing disabled combinations of terms */
125
  IndexTrie d_disabledCombinations;
126
127
  /** current sum/max  of digits, depending on the strategy */
128
  size_t d_currentStage;
129
  /**total number of stages*/
130
  size_t d_stageCount;
131
  /**becomes false once the enumerator runs out of options*/
132
  bool d_hasNext;
133
  /** the length of the prefix that has to be changed in the next
134
  combination, i.e.  the number of the most significant digits that need to be
135
  changed in order to escape a  useless instantiation */
136
  size_t d_changePrefix;
137
  /** Move onto the next stage */
138
  bool increaseStage();
139
  /** Move onto the next stage, sum strategy. */
140
  bool increaseStageSum();
141
  /** Move onto the next stage, max strategy. */
142
  bool increaseStageMax();
143
  /** Move on in the current stage */
144
  bool nextCombination();
145
  /** Move onto the next combination. */
146
  bool nextCombinationInternal();
147
  /** Find the next lexicographically smallest combination of terms that change
148
   * on the change prefix, each digit is within the current state,  and there is
149
   * at least one digit not in the previous stage. */
150
  bool nextCombinationSum();
151
  /** Find the next lexicographically smallest combination of terms that change
152
   * on the change prefix and their sum is equal to d_currentStage. */
153
  bool nextCombinationMax();
154
  /** Set up terms for given variable.  */
155
  virtual size_t prepareTerms(size_t variableIx) = 0;
156
  /** Get a given term for a given variable.  */
157
  virtual Node getTerm(size_t variableIx,
158
                       size_t term_index) CVC5_WARN_UNUSED_RESULT = 0;
159
};
160
161
/**
162
 * Enumerate ground terms as they come from the term database.
163
 */
164
class TermTupleEnumeratorBasic : public TermTupleEnumeratorBase
165
{
166
 public:
167
220
  TermTupleEnumeratorBasic(Node quantifier,
168
                           const TermTupleEnumeratorEnv* env,
169
                           QuantifiersState& qs,
170
                           TermDb* td)
171
220
      : TermTupleEnumeratorBase(quantifier, env), d_qs(qs), d_tdb(td)
172
  {
173
220
  }
174
175
440
  virtual ~TermTupleEnumeratorBasic() = default;
176
177
 protected:
178
  /**  a list of terms for each type */
179
  std::map<TypeNode, std::vector<Node> > d_termDbList;
180
  virtual size_t prepareTerms(size_t variableIx) override;
181
  virtual Node getTerm(size_t variableIx, size_t term_index) override;
182
  /** Reference to quantifiers state */
183
  QuantifiersState& d_qs;
184
  /** Pointer to term database */
185
  TermDb* d_tdb;
186
};
187
188
/**
189
 * Enumerate ground terms according to the relevant domain class.
190
 */
191
class TermTupleEnumeratorRD : public TermTupleEnumeratorBase
192
{
193
 public:
194
1060
  TermTupleEnumeratorRD(Node quantifier,
195
                        const TermTupleEnumeratorEnv* env,
196
                        RelevantDomain* rd)
197
1060
      : TermTupleEnumeratorBase(quantifier, env), d_rd(rd)
198
  {
199
1060
  }
200
2120
  virtual ~TermTupleEnumeratorRD() = default;
201
202
 protected:
203
9383
  virtual size_t prepareTerms(size_t variableIx) override
204
  {
205
9383
    return d_rd->getRDomain(d_quantifier, variableIx)->d_terms.size();
206
  }
207
16716
  virtual Node getTerm(size_t variableIx, size_t term_index) override
208
  {
209
16716
    return d_rd->getRDomain(d_quantifier, variableIx)->d_terms[term_index];
210
  }
211
  /** The relevant domain */
212
  RelevantDomain* d_rd;
213
};
214
215
1284
void TermTupleEnumeratorBase::init()
216
{
217
2568
  Trace("inst-alg-rd") << "Initializing enumeration " << d_quantifier
218
1284
                       << std::endl;
219
1284
  d_currentStage = 0;
220
1284
  d_hasNext = true;
221
1284
  d_stageCount = 1;  // in the case of full effort we do at least one stage
222
223
1284
  if (d_variableCount == 0)
224
  {
225
    d_hasNext = false;
226
    return;
227
  }
228
229
  // prepare a sequence of terms for each quantified variable
230
  // additionally initialize the cache for variable types
231
11013
  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
232
  {
233
9729
    d_typeCache.push_back(d_quantifier[0][variableIx].getType());
234
9729
    const size_t termsSize = prepareTerms(variableIx);
235
19458
    Trace("inst-alg-rd") << "Variable " << variableIx << " has " << termsSize
236
9729
                         << " in relevant domain." << std::endl;
237
9729
    if (termsSize == 0 && !d_env->d_fullEffort)
238
    {
239
      d_hasNext = false;
240
      return;  // give up on an empty dommain
241
    }
242
9729
    d_termsSizes.push_back(termsSize);
243
9729
    d_stageCount = std::max(d_stageCount, termsSize);
244
  }
245
246
2568
  Trace("inst-alg-rd") << "Will do " << d_stageCount
247
1284
                       << " stages of instantiation." << std::endl;
248
1284
  d_termIndex.resize(d_variableCount, 0);
249
}
250
251
3936
bool TermTupleEnumeratorBase::hasNext()
252
{
253
3936
  if (!d_hasNext)
254
  {
255
    return false;
256
  }
257
258
3936
  if (d_stepCounter++ == 0)
259
  {  // TODO:any (nice)  way of avoiding this special if?
260
1284
    Assert(d_currentStage == 0);
261
2568
    Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..."
262
1284
                         << std::endl;
263
1284
    return true;
264
  }
265
266
  // try to find the next combination
267
2652
  return d_hasNext = nextCombination();
268
}
269
270
2644
void TermTupleEnumeratorBase::failureReason(const std::vector<bool>& mask)
271
{
272
2644
  if (Trace.isOn("inst-alg"))
273
  {
274
    traceMaskedVector("inst-alg", "failureReason", mask, d_termIndex);
275
  }
276
2644
  d_disabledCombinations.add(mask, d_termIndex);  // record failure
277
  // update change prefix accordingly
278
4753
  for (d_changePrefix = mask.size();
279
4753
       d_changePrefix && !mask[d_changePrefix - 1];
280
2109
       d_changePrefix--)
281
    ;
282
2644
}
283
284
3449
void TermTupleEnumeratorBase::next(/*out*/ std::vector<Node>& terms)
285
{
286
3449
  Trace("inst-alg-rd") << "Try instantiation: " << d_termIndex << std::endl;
287
3449
  terms.resize(d_variableCount);
288
24033
  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
289
  {
290
20584
    const Node t = d_termsSizes[variableIx] == 0
291
                       ? Node::null()
292
41168
                       : getTerm(variableIx, d_termIndex[variableIx]);
293
20584
    terms[variableIx] = t;
294
20584
    Trace("inst-alg-rd") << t << "  ";
295
20584
    Assert(terms[variableIx].isNull()
296
           || terms[variableIx].getType().isComparableTo(
297
               d_quantifier[0][variableIx].getType()));
298
  }
299
3449
  Trace("inst-alg-rd") << std::endl;
300
3449
}
301
302
bool TermTupleEnumeratorBase::increaseStageSum()
303
{
304
  const size_t lowerBound = d_currentStage + 1;
305
  Trace("inst-alg-rd") << "Try sum " << lowerBound << "..." << std::endl;
306
  d_currentStage = 0;
307
  for (size_t digit = d_termIndex.size();
308
       d_currentStage < lowerBound && digit--;)
309
  {
310
    const size_t missing = lowerBound - d_currentStage;
311
    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
312
    d_termIndex[digit] = std::min(missing, maxValue);
313
    d_currentStage += d_termIndex[digit];
314
  }
315
  return d_currentStage >= lowerBound;
316
}
317
318
1622
bool TermTupleEnumeratorBase::increaseStage()
319
{
320
1622
  d_changePrefix = d_variableCount;  // simply reset upon increase stage
321
1622
  return d_env->d_increaseSum ? increaseStageSum() : increaseStageMax();
322
}
323
324
1622
bool TermTupleEnumeratorBase::increaseStageMax()
325
{
326
1622
  d_currentStage++;
327
1622
  if (d_currentStage >= d_stageCount)
328
  {
329
487
    return false;
330
  }
331
1135
  Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..." << std::endl;
332
  // skipping some elements that have already been definitely seen
333
  // find the least significant digit that can be set to the current stage
334
  // TODO: should we skip all?
335
1135
  std::fill(d_termIndex.begin(), d_termIndex.end(), 0);
336
1135
  bool found = false;
337
2813
  for (size_t digit = d_termIndex.size(); !found && digit--;)
338
  {
339
1678
    if (d_termsSizes[digit] > d_currentStage)
340
    {
341
1135
      found = true;
342
1135
      d_termIndex[digit] = d_currentStage;
343
    }
344
  }
345
1135
  Assert(found);
346
1135
  return found;
347
}
348
349
2932
bool TermTupleEnumeratorBase::nextCombination()
350
{
351
  while (true)
352
  {
353
3212
    Trace("inst-alg-rd") << "changePrefix " << d_changePrefix << std::endl;
354
2932
    if (!nextCombinationInternal() && !increaseStage())
355
    {
356
487
      return false;  // ran out of combinations
357
    }
358
2445
    if (!d_disabledCombinations.find(d_termIndex, d_changePrefix))
359
    {
360
2165
      return true;  // current combination vetted by disabled combinations
361
    }
362
  }
363
}
364
365
/** Move onto the next combination, depending on the strategy. */
366
2932
bool TermTupleEnumeratorBase::nextCombinationInternal()
367
{
368
2932
  return d_env->d_increaseSum ? nextCombinationSum() : nextCombinationMax();
369
}
370
371
/** Find the next lexicographically smallest combination of terms that change
372
 * on the change prefix and their sum is equal to d_currentStage. */
373
2932
bool TermTupleEnumeratorBase::nextCombinationMax()
374
{
375
  // look for the least significant digit, within change prefix,
376
  // that can be increased
377
2932
  bool found = false;
378
2932
  size_t increaseDigit = d_changePrefix;
379
18974
  while (!found && increaseDigit--)
380
  {
381
8021
    const size_t new_value = d_termIndex[increaseDigit] + 1;
382
8021
    if (new_value < d_termsSizes[increaseDigit] && new_value <= d_currentStage)
383
    {
384
1310
      d_termIndex[increaseDigit] = new_value;
385
      // send everything after the increased digit to 0
386
1310
      std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
387
1310
      found = true;
388
    }
389
  }
390
2932
  if (!found)
391
  {
392
1622
    return false;  // nothing to increase
393
  }
394
  // check if the combination has at least one digit in the current stage,
395
  // since at least one digit was increased, no need for this in stage 1
396
1310
  bool inStage = d_currentStage <= 1;
397
1909
  for (size_t i = increaseDigit + 1; !inStage && i--;)
398
  {
399
599
    inStage = d_termIndex[i] >= d_currentStage;
400
  }
401
1310
  if (!inStage)  // look for a digit that can increase to current stage
402
  {
403
204
    for (increaseDigit = d_variableCount, found = false;
404
204
         !found && increaseDigit--;)
405
    {
406
102
      found = d_termsSizes[increaseDigit] > d_currentStage;
407
    }
408
102
    if (!found)
409
    {
410
      return false;  // nothing to increase to the current stage
411
    }
412
102
    Assert(d_termsSizes[increaseDigit] > d_currentStage
413
           && d_termIndex[increaseDigit] < d_currentStage);
414
102
    d_termIndex[increaseDigit] = d_currentStage;
415
    // send everything after the increased digit to 0
416
102
    std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
417
  }
418
1310
  return true;
419
}
420
421
/** Find the next lexicographically smallest combination of terms that change
422
 * on the change prefix, each digit is within the current state,  and there is
423
 * at least one digit not in the previous stage. */
424
bool TermTupleEnumeratorBase::nextCombinationSum()
425
{
426
  size_t suffixSum = 0;
427
  bool found = false;
428
  size_t increaseDigit = d_termIndex.size();
429
  while (increaseDigit--)
430
  {
431
    const size_t newValue = d_termIndex[increaseDigit] + 1;
432
    found = suffixSum > 0 && newValue < d_termsSizes[increaseDigit]
433
            && increaseDigit < d_changePrefix;
434
    if (found)
435
    {
436
      // digit can be increased and suffix can be decreased
437
      d_termIndex[increaseDigit] = newValue;
438
      break;
439
    }
440
    suffixSum += d_termIndex[increaseDigit];
441
    d_termIndex[increaseDigit] = 0;
442
  }
443
  if (!found)
444
  {
445
    return false;
446
  }
447
  Assert(suffixSum > 0);
448
  // increaseDigit went up by one, hence, distribute (suffixSum - 1) in the
449
  // least significant digits
450
  suffixSum--;
451
  for (size_t digit = d_termIndex.size(); suffixSum > 0 && digit--;)
452
  {
453
    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
454
    d_termIndex[digit] = std::min(suffixSum, maxValue);
455
    suffixSum -= d_termIndex[digit];
456
  }
457
  Assert(suffixSum == 0);  // everything should have been distributed
458
  return true;
459
}
460
461
342
size_t TermTupleEnumeratorBasic::prepareTerms(size_t variableIx)
462
{
463
684
  const TypeNode type_node = d_typeCache[variableIx];
464
342
  if (!ContainsKey(d_termDbList, type_node))
465
  {
466
248
    const size_t ground_terms_count = d_tdb->getNumTypeGroundTerms(type_node);
467
496
    std::map<Node, Node> repsFound;
468
10494
    for (size_t j = 0; j < ground_terms_count; j++)
469
    {
470
20492
      Node gt = d_tdb->getTypeGroundTerm(type_node, j);
471
10246
      if (!options::cegqi() || !quantifiers::TermUtil::hasInstConstAttr(gt))
472
      {
473
20492
        Node rep = d_qs.getRepresentative(gt);
474
10246
        if (repsFound.find(rep) == repsFound.end())
475
        {
476
891
          repsFound[rep] = gt;
477
891
          d_termDbList[type_node].push_back(gt);
478
        }
479
      }
480
    }
481
  }
482
483
684
  Trace("inst-alg-rd") << "Instantiation Terms for child " << variableIx << ": "
484
342
                       << d_termDbList[type_node] << std::endl;
485
684
  return d_termDbList[type_node].size();
486
}
487
488
1831
/**
 * Return the term_index-th candidate term for the given variable, taken from
 * the per-type list built by prepareTerms. Uses a non-inserting lookup, so a
 * bad call cannot silently add an empty list to the cache.
 */
Node TermTupleEnumeratorBasic::getTerm(size_t variableIx, size_t term_index)
{
  const TypeNode& type_node = d_typeCache[variableIx];
  std::map<TypeNode, std::vector<Node> >::const_iterator it =
      d_termDbList.find(type_node);
  Assert(it != d_termDbList.end());
  Assert(term_index < it->second.size());
  return it->second[term_index];
}
494
495
/**
496
 * Enumerate ground terms as they come from a user-provided term pool
497
 */
498
class TermTupleEnumeratorPool : public TermTupleEnumeratorBase
499
{
500
 public:
501
4
  TermTupleEnumeratorPool(Node quantifier,
502
                          const TermTupleEnumeratorEnv* env,
503
                          TermPools* tp,
504
                          Node pool)
505
4
      : TermTupleEnumeratorBase(quantifier, env), d_tp(tp), d_pool(pool)
506
  {
507
4
    Assert(d_pool.getKind() == kind::INST_POOL);
508
4
  }
509
510
8
  virtual ~TermTupleEnumeratorPool() = default;
511
512
 protected:
513
  /** Pointer to the term pool utility */
514
  TermPools* d_tp;
515
  /** The pool annotation */
516
  Node d_pool;
517
  /**  a list of terms for each id */
518
  std::map<size_t, std::vector<Node> > d_poolList;
519
  /** gets the terms from the pool */
520
4
  size_t prepareTerms(size_t variableIx) override
521
  {
522
4
    Assert(d_pool.getNumChildren() > variableIx);
523
    // prepare terms from pool
524
4
    d_poolList[variableIx].clear();
525
4
    d_tp->getTermsForPool(d_pool[variableIx], d_poolList[variableIx]);
526
8
    Trace("pool-inst") << "Instantiation Terms for child " << variableIx << ": "
527
4
                       << d_poolList[variableIx] << std::endl;
528
4
    return d_poolList[variableIx].size();
529
  }
530
8
  Node getTerm(size_t variableIx, size_t term_index) override
531
  {
532
8
    Assert(term_index < d_poolList[variableIx].size());
533
8
    return d_poolList[variableIx][term_index];
534
  }
535
};
536
537
220
/**
 * Create an enumerator that draws terms from the term database.
 * The result is heap-allocated with new; presumably the caller owns it and
 * must delete it — confirm against callers.
 */
TermTupleEnumeratorInterface* mkTermTupleEnumerator(
    Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs, TermDb* td)
{
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorBasic(q, env, qs, td);
  return enumerator;
}
543
1060
/**
 * Create an enumerator that draws terms from the relevant domain.
 * The result is heap-allocated with new; presumably the caller owns it and
 * must delete it — confirm against callers.
 */
TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd(
    Node q, const TermTupleEnumeratorEnv* env, RelevantDomain* rd)
{
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorRD(q, env, rd);
  return enumerator;
}
549
550
4
/**
 * Create an enumerator that draws terms from the given INST_POOL annotation.
 * The result is heap-allocated with new; presumably the caller owns it and
 * must delete it — confirm against callers.
 */
TermTupleEnumeratorInterface* mkTermTupleEnumeratorPool(
    Node q, const TermTupleEnumeratorEnv* env, TermPools* tp, Node pool)
{
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorPool(q, env, tp, pool);
  return enumerator;
}
556
557
}  // namespace quantifiers
558
}  // namespace theory
559
29340
}  // namespace cvc5