GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/term_tuple_enumerator.cpp Lines: 157 207 75.8 %
Date: 2021-09-17 Branches: 241 666 36.2 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   MikolasJanota, Andrew Reynolds
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of an enumeration of tuples of terms for the purpose of
14
 * quantifier instantiation.
15
 */
16
#include "theory/quantifiers/term_tuple_enumerator.h"
17
18
#include <algorithm>
19
#include <functional>
20
#include <iterator>
21
#include <map>
22
#include <vector>
23
24
#include "base/map_util.h"
25
#include "base/output.h"
26
#include "options/quantifiers_options.h"
27
#include "smt/smt_statistics_registry.h"
28
#include "theory/quantifiers/index_trie.h"
29
#include "theory/quantifiers/quant_module.h"
30
#include "theory/quantifiers/relevant_domain.h"
31
#include "theory/quantifiers/term_pools.h"
32
#include "theory/quantifiers/term_registry.h"
33
#include "theory/quantifiers/term_util.h"
34
#include "util/statistics_stats.h"
35
36
namespace cvc5 {
37
38
template <typename T>
39
3792
static Cvc5ostream& operator<<(Cvc5ostream& out, const std::vector<T>& v)
40
{
41
3792
  out << "[ ";
42
3792
  std::copy(v.begin(), v.end(), std::ostream_iterator<T>(out, " "));
43
3792
  return out << "]";
44
}
45
46
/** Tracing purposes, printing a masked vector of indices. */
47
static void traceMaskedVector(const char* trace,
48
                              const char* name,
49
                              const std::vector<bool>& mask,
50
                              const std::vector<size_t>& values)
51
{
52
  Assert(mask.size() == values.size());
53
  Trace(trace) << name << " [ ";
54
  for (size_t variableIx = 0; variableIx < mask.size(); variableIx++)
55
  {
56
    if (mask[variableIx])
57
    {
58
      Trace(trace) << values[variableIx] << " ";
59
    }
60
    else
61
    {
62
      Trace(trace) << "_ ";
63
    }
64
  }
65
  Trace(trace) << "]" << std::endl;
66
}
67
68
namespace theory {
69
namespace quantifiers {
70
/**
71
 * Base class for enumerators of tuples of terms for the purpose of
72
 * quantification instantiation. The tuples are represented as tuples of
73
 * indices of  terms, where the tuple has as many elements as there are
74
 * quantified variables in the considered quantifier.
75
 *
76
 * Like so, we see a tuple as a number, where the digits may have different
77
 * ranges. The most significant digits are stored first.
78
 *
79
 * Tuples are enumerated  in a lexicographic order in stages. There are 2
80
 * possible strategies, either  all tuples in a given stage have the same sum of
81
 * digits, or, the maximum  over these digits is the same.
82
 * */
83
class TermTupleEnumeratorBase : public TermTupleEnumeratorInterface
84
{
85
 public:
86
  /** Initialize the class with the quantifier to be instantiated. */
87
1294
  TermTupleEnumeratorBase(Node quantifier, const TermTupleEnumeratorEnv* env)
88
1294
      : d_quantifier(quantifier),
89
2588
        d_variableCount(d_quantifier[0].getNumChildren()),
90
        d_env(env),
91
        d_stepCounter(0),
92
        d_disabledCombinations(
93
3882
            true)  // do not record combinations with no blanks
94
95
  {
96
1294
    d_changePrefix = d_variableCount;
97
1294
  }
98
99
1294
  virtual ~TermTupleEnumeratorBase() = default;
100
101
  // implementation of the TermTupleEnumeratorInterface
102
  virtual void init() override;
103
  virtual bool hasNext() override;
104
  virtual void next(/*out*/ std::vector<Node>& terms) override;
105
  virtual void failureReason(const std::vector<bool>& mask) override;
106
  // end of implementation of the TermTupleEnumeratorInterface
107
108
 protected:
109
  /** the quantifier whose variables are being instantiated */
110
  const Node d_quantifier;
111
  /** number of variables in the quantifier */
112
  const size_t d_variableCount;
113
  /** env of structures with a longer lifespan */
114
  const TermTupleEnumeratorEnv* const d_env;
115
  /** type for each variable */
116
  std::vector<TypeNode> d_typeCache;
117
  /** number of candidate terms for each variable */
118
  std::vector<size_t> d_termsSizes;
119
  /** tuple of indices of the current terms */
120
  std::vector<size_t> d_termIndex;
121
  /** total number of steps of the enumerator */
122
  uint32_t d_stepCounter;
123
124
  /** a data structure storing disabled combinations of terms */
125
  IndexTrie d_disabledCombinations;
126
127
  /** current sum/max  of digits, depending on the strategy */
128
  size_t d_currentStage;
129
  /**total number of stages*/
130
  size_t d_stageCount;
131
  /**becomes false once the enumerator runs out of options*/
132
  bool d_hasNext;
133
  /** the length of the prefix that has to be changed in the next
134
  combination, i.e.  the number of the most significant digits that need to be
135
  changed in order to escape a  useless instantiation */
136
  size_t d_changePrefix;
137
  /** Move onto the next stage */
138
  bool increaseStage();
139
  /** Move onto the next stage, sum strategy. */
140
  bool increaseStageSum();
141
  /** Move onto the next stage, max strategy. */
142
  bool increaseStageMax();
143
  /** Move on in the current stage */
144
  bool nextCombination();
145
  /** Move onto the next combination. */
146
  bool nextCombinationInternal();
147
  /** Find the next lexicographically smallest combination of terms that change
148
   * on the change prefix, each digit is within the current state,  and there is
149
   * at least one digit not in the previous stage. */
150
  bool nextCombinationSum();
151
  /** Find the next lexicographically smallest combination of terms that change
152
   * on the change prefix and their sum is equal to d_currentStage. */
153
  bool nextCombinationMax();
154
  /** Set up terms for given variable.  */
155
  virtual size_t prepareTerms(size_t variableIx) = 0;
156
  /** Get a given term for a given variable.  */
157
  CVC5_WARN_UNUSED_RESULT virtual Node getTerm(size_t variableIx,
158
                                               size_t term_index) = 0;
159
};
160
161
/**
162
 * Enumerate ground terms as they come from the term database.
163
 */
164
class TermTupleEnumeratorBasic : public TermTupleEnumeratorBase
165
{
166
 public:
167
229
  TermTupleEnumeratorBasic(Node quantifier,
168
                           const TermTupleEnumeratorEnv* env,
169
                           QuantifiersState& qs,
170
                           TermDb* td)
171
229
      : TermTupleEnumeratorBase(quantifier, env), d_qs(qs), d_tdb(td)
172
  {
173
229
  }
174
175
458
  virtual ~TermTupleEnumeratorBasic() = default;
176
177
 protected:
178
  /**  a list of terms for each type */
179
  std::map<TypeNode, std::vector<Node> > d_termDbList;
180
  virtual size_t prepareTerms(size_t variableIx) override;
181
  virtual Node getTerm(size_t variableIx, size_t term_index) override;
182
  /** Reference to quantifiers state */
183
  QuantifiersState& d_qs;
184
  /** Pointer to term database */
185
  TermDb* d_tdb;
186
};
187
188
/**
189
 * Enumerate ground terms according to the relevant domain class.
190
 */
191
class TermTupleEnumeratorRD : public TermTupleEnumeratorBase
192
{
193
 public:
194
1061
  TermTupleEnumeratorRD(Node quantifier,
195
                        const TermTupleEnumeratorEnv* env,
196
                        RelevantDomain* rd)
197
1061
      : TermTupleEnumeratorBase(quantifier, env), d_rd(rd)
198
  {
199
1061
  }
200
2122
  virtual ~TermTupleEnumeratorRD() = default;
201
202
 protected:
203
9386
  virtual size_t prepareTerms(size_t variableIx) override
204
  {
205
9386
    return d_rd->getRDomain(d_quantifier, variableIx)->d_terms.size();
206
  }
207
16648
  virtual Node getTerm(size_t variableIx, size_t term_index) override
208
  {
209
16648
    return d_rd->getRDomain(d_quantifier, variableIx)->d_terms[term_index];
210
  }
211
  /** The relevant domain */
212
  RelevantDomain* d_rd;
213
};
214
215
1294
void TermTupleEnumeratorBase::init()
216
{
217
2588
  Trace("inst-alg-rd") << "Initializing enumeration " << d_quantifier
218
1294
                       << std::endl;
219
1294
  d_currentStage = 0;
220
1294
  d_hasNext = true;
221
1294
  d_stageCount = 1;  // in the case of full effort we do at least one stage
222
223
1294
  if (d_variableCount == 0)
224
  {
225
    d_hasNext = false;
226
    return;
227
  }
228
229
  // prepare a sequence of terms for each quantified variable
230
  // additionally initialize the cache for variable types
231
11032
  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
232
  {
233
9738
    d_typeCache.push_back(d_quantifier[0][variableIx].getType());
234
9738
    const size_t termsSize = prepareTerms(variableIx);
235
19476
    Trace("inst-alg-rd") << "Variable " << variableIx << " has " << termsSize
236
9738
                         << " in relevant domain." << std::endl;
237
9738
    if (termsSize == 0 && !d_env->d_fullEffort)
238
    {
239
      d_hasNext = false;
240
      return;  // give up on an empty dommain
241
    }
242
9738
    d_termsSizes.push_back(termsSize);
243
9738
    d_stageCount = std::max(d_stageCount, termsSize);
244
  }
245
246
2588
  Trace("inst-alg-rd") << "Will do " << d_stageCount
247
1294
                       << " stages of instantiation." << std::endl;
248
1294
  d_termIndex.resize(d_variableCount, 0);
249
}
250
251
3930
bool TermTupleEnumeratorBase::hasNext()
252
{
253
3930
  if (!d_hasNext)
254
  {
255
    return false;
256
  }
257
258
3930
  if (d_stepCounter++ == 0)
259
  {  // TODO:any (nice)  way of avoiding this special if?
260
1294
    Assert(d_currentStage == 0);
261
2588
    Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..."
262
1294
                         << std::endl;
263
1294
    return true;
264
  }
265
266
  // try to find the next combination
267
2636
  return d_hasNext = nextCombination();
268
}
269
270
2628
void TermTupleEnumeratorBase::failureReason(const std::vector<bool>& mask)
271
{
272
2628
  if (Trace.isOn("inst-alg"))
273
  {
274
    traceMaskedVector("inst-alg", "failureReason", mask, d_termIndex);
275
  }
276
2628
  d_disabledCombinations.add(mask, d_termIndex);  // record failure
277
  // update change prefix accordingly
278
4737
  for (d_changePrefix = mask.size();
279
4737
       d_changePrefix && !mask[d_changePrefix - 1];
280
2109
       d_changePrefix--)
281
    ;
282
2628
}
283
284
3440
void TermTupleEnumeratorBase::next(/*out*/ std::vector<Node>& terms)
285
{
286
3440
  Trace("inst-alg-rd") << "Try instantiation: " << d_termIndex << std::endl;
287
3440
  terms.resize(d_variableCount);
288
24008
  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
289
  {
290
20568
    const Node t = d_termsSizes[variableIx] == 0
291
                       ? Node::null()
292
41136
                       : getTerm(variableIx, d_termIndex[variableIx]);
293
20568
    terms[variableIx] = t;
294
20568
    Trace("inst-alg-rd") << t << "  ";
295
41136
    Assert(t.isNull()
296
           || t.getType().isComparableTo(d_quantifier[0][variableIx].getType()))
297
20568
        << "Bad type: " << t << " " << t.getType() << " "
298
20568
        << d_quantifier[0][variableIx].getType();
299
  }
300
3440
  Trace("inst-alg-rd") << std::endl;
301
3440
}
302
303
bool TermTupleEnumeratorBase::increaseStageSum()
304
{
305
  const size_t lowerBound = d_currentStage + 1;
306
  Trace("inst-alg-rd") << "Try sum " << lowerBound << "..." << std::endl;
307
  d_currentStage = 0;
308
  for (size_t digit = d_termIndex.size();
309
       d_currentStage < lowerBound && digit--;)
310
  {
311
    const size_t missing = lowerBound - d_currentStage;
312
    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
313
    d_termIndex[digit] = std::min(missing, maxValue);
314
    d_currentStage += d_termIndex[digit];
315
  }
316
  return d_currentStage >= lowerBound;
317
}
318
319
1608
bool TermTupleEnumeratorBase::increaseStage()
320
{
321
1608
  d_changePrefix = d_variableCount;  // simply reset upon increase stage
322
1608
  return d_env->d_increaseSum ? increaseStageSum() : increaseStageMax();
323
}
324
325
1608
bool TermTupleEnumeratorBase::increaseStageMax()
326
{
327
1608
  d_currentStage++;
328
1608
  if (d_currentStage >= d_stageCount)
329
  {
330
490
    return false;
331
  }
332
1118
  Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..." << std::endl;
333
  // skipping some elements that have already been definitely seen
334
  // find the least significant digit that can be set to the current stage
335
  // TODO: should we skip all?
336
1118
  std::fill(d_termIndex.begin(), d_termIndex.end(), 0);
337
1118
  bool found = false;
338
2794
  for (size_t digit = d_termIndex.size(); !found && digit--;)
339
  {
340
1676
    if (d_termsSizes[digit] > d_currentStage)
341
    {
342
1118
      found = true;
343
1118
      d_termIndex[digit] = d_currentStage;
344
    }
345
  }
346
1118
  Assert(found);
347
1118
  return found;
348
}
349
350
2907
bool TermTupleEnumeratorBase::nextCombination()
351
{
352
  while (true)
353
  {
354
3178
    Trace("inst-alg-rd") << "changePrefix " << d_changePrefix << std::endl;
355
2907
    if (!nextCombinationInternal() && !increaseStage())
356
    {
357
490
      return false;  // ran out of combinations
358
    }
359
2417
    if (!d_disabledCombinations.find(d_termIndex, d_changePrefix))
360
    {
361
2146
      return true;  // current combination vetted by disabled combinations
362
    }
363
  }
364
}
365
366
/** Move onto the next combination, depending on the strategy. */
367
2907
bool TermTupleEnumeratorBase::nextCombinationInternal()
368
{
369
2907
  return d_env->d_increaseSum ? nextCombinationSum() : nextCombinationMax();
370
}
371
372
/** Find the next lexicographically smallest combination of terms that change
373
 * on the change prefix and their sum is equal to d_currentStage. */
374
2907
bool TermTupleEnumeratorBase::nextCombinationMax()
375
{
376
  // look for the least significant digit, within change prefix,
377
  // that can be increased
378
2907
  bool found = false;
379
2907
  size_t increaseDigit = d_changePrefix;
380
18881
  while (!found && increaseDigit--)
381
  {
382
7987
    const size_t new_value = d_termIndex[increaseDigit] + 1;
383
7987
    if (new_value < d_termsSizes[increaseDigit] && new_value <= d_currentStage)
384
    {
385
1299
      d_termIndex[increaseDigit] = new_value;
386
      // send everything after the increased digit to 0
387
1299
      std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
388
1299
      found = true;
389
    }
390
  }
391
2907
  if (!found)
392
  {
393
1608
    return false;  // nothing to increase
394
  }
395
  // check if the combination has at least one digit in the current stage,
396
  // since at least one digit was increased, no need for this in stage 1
397
1299
  bool inStage = d_currentStage <= 1;
398
1898
  for (size_t i = increaseDigit + 1; !inStage && i--;)
399
  {
400
599
    inStage = d_termIndex[i] >= d_currentStage;
401
  }
402
1299
  if (!inStage)  // look for a digit that can increase to current stage
403
  {
404
204
    for (increaseDigit = d_variableCount, found = false;
405
204
         !found && increaseDigit--;)
406
    {
407
102
      found = d_termsSizes[increaseDigit] > d_currentStage;
408
    }
409
102
    if (!found)
410
    {
411
      return false;  // nothing to increase to the current stage
412
    }
413
102
    Assert(d_termsSizes[increaseDigit] > d_currentStage
414
           && d_termIndex[increaseDigit] < d_currentStage);
415
102
    d_termIndex[increaseDigit] = d_currentStage;
416
    // send everything after the increased digit to 0
417
102
    std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
418
  }
419
1299
  return true;
420
}
421
422
/** Find the next lexicographically smallest combination of terms that change
423
 * on the change prefix, each digit is within the current state,  and there is
424
 * at least one digit not in the previous stage. */
425
bool TermTupleEnumeratorBase::nextCombinationSum()
426
{
427
  size_t suffixSum = 0;
428
  bool found = false;
429
  size_t increaseDigit = d_termIndex.size();
430
  while (increaseDigit--)
431
  {
432
    const size_t newValue = d_termIndex[increaseDigit] + 1;
433
    found = suffixSum > 0 && newValue < d_termsSizes[increaseDigit]
434
            && increaseDigit < d_changePrefix;
435
    if (found)
436
    {
437
      // digit can be increased and suffix can be decreased
438
      d_termIndex[increaseDigit] = newValue;
439
      break;
440
    }
441
    suffixSum += d_termIndex[increaseDigit];
442
    d_termIndex[increaseDigit] = 0;
443
  }
444
  if (!found)
445
  {
446
    return false;
447
  }
448
  Assert(suffixSum > 0);
449
  // increaseDigit went up by one, hence, distribute (suffixSum - 1) in the
450
  // least significant digits
451
  suffixSum--;
452
  for (size_t digit = d_termIndex.size(); suffixSum > 0 && digit--;)
453
  {
454
    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
455
    d_termIndex[digit] = std::min(suffixSum, maxValue);
456
    suffixSum -= d_termIndex[digit];
457
  }
458
  Assert(suffixSum == 0);  // everything should have been distributed
459
  return true;
460
}
461
462
348
size_t TermTupleEnumeratorBasic::prepareTerms(size_t variableIx)
463
{
464
696
  const TypeNode type_node = d_typeCache[variableIx];
465
348
  if (!ContainsKey(d_termDbList, type_node))
466
  {
467
256
    const size_t ground_terms_count = d_tdb->getNumTypeGroundTerms(type_node);
468
512
    std::map<Node, Node> repsFound;
469
10469
    for (size_t j = 0; j < ground_terms_count; j++)
470
    {
471
20426
      Node gt = d_tdb->getTypeGroundTerm(type_node, j);
472
10213
      if (!options::cegqi() || !quantifiers::TermUtil::hasInstConstAttr(gt))
473
      {
474
20426
        Node rep = d_qs.getRepresentative(gt);
475
10213
        if (repsFound.find(rep) == repsFound.end())
476
        {
477
929
          repsFound[rep] = gt;
478
929
          d_termDbList[type_node].push_back(gt);
479
        }
480
      }
481
    }
482
  }
483
484
696
  Trace("inst-alg-rd") << "Instantiation Terms for child " << variableIx << ": "
485
348
                       << d_termDbList[type_node] << std::endl;
486
696
  return d_termDbList[type_node].size();
487
}
488
489
1838
/** Return the term_index-th candidate term collected for the variable's type. */
Node TermTupleEnumeratorBasic::getTerm(size_t variableIx, size_t term_index)
{
  const std::vector<Node>& candidates = d_termDbList[d_typeCache[variableIx]];
  Assert(term_index < candidates.size());
  return candidates[term_index];
}
495
496
/**
497
 * Enumerate ground terms as they come from a user-provided term pool
498
 */
499
class TermTupleEnumeratorPool : public TermTupleEnumeratorBase
500
{
501
 public:
502
4
  TermTupleEnumeratorPool(Node quantifier,
503
                          const TermTupleEnumeratorEnv* env,
504
                          TermPools* tp,
505
                          Node pool)
506
4
      : TermTupleEnumeratorBase(quantifier, env), d_tp(tp), d_pool(pool)
507
  {
508
4
    Assert(d_pool.getKind() == kind::INST_POOL);
509
4
  }
510
511
8
  virtual ~TermTupleEnumeratorPool() = default;
512
513
 protected:
514
  /** Pointer to the term pool utility */
515
  TermPools* d_tp;
516
  /** The pool annotation */
517
  Node d_pool;
518
  /**  a list of terms for each id */
519
  std::map<size_t, std::vector<Node> > d_poolList;
520
  /** gets the terms from the pool */
521
4
  size_t prepareTerms(size_t variableIx) override
522
  {
523
4
    Assert(d_pool.getNumChildren() > variableIx);
524
    // prepare terms from pool
525
4
    d_poolList[variableIx].clear();
526
4
    d_tp->getTermsForPool(d_pool[variableIx], d_poolList[variableIx]);
527
8
    Trace("pool-inst") << "Instantiation Terms for child " << variableIx << ": "
528
4
                       << d_poolList[variableIx] << std::endl;
529
4
    return d_poolList[variableIx].size();
530
  }
531
8
  Node getTerm(size_t variableIx, size_t term_index) override
532
  {
533
8
    Assert(term_index < d_poolList[variableIx].size());
534
8
    return d_poolList[variableIx][term_index];
535
  }
536
};
537
538
229
/**
 * Create an enumerator over term-database ground terms. The result is
 * heap-allocated and never freed in this file, so ownership presumably passes
 * to the caller (confirm at call sites).
 */
TermTupleEnumeratorInterface* mkTermTupleEnumerator(
    Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs, TermDb* td)
{
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorBasic(q, env, qs, td);  // implicit upcast
  return enumerator;
}
544
1061
/**
 * Create an enumerator over the relevant domain. The result is heap-allocated
 * and never freed in this file, so ownership presumably passes to the caller
 * (confirm at call sites).
 */
TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd(
    Node q, const TermTupleEnumeratorEnv* env, RelevantDomain* rd)
{
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorRD(q, env, rd);  // implicit upcast
  return enumerator;
}
550
551
4
/**
 * Create an enumerator over the terms of a user-provided pool annotation. The
 * result is heap-allocated and never freed in this file, so ownership
 * presumably passes to the caller (confirm at call sites).
 */
TermTupleEnumeratorInterface* mkTermTupleEnumeratorPool(
    Node q, const TermTupleEnumeratorEnv* env, TermPools* tp, Node pool)
{
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorPool(q, env, tp, pool);  // implicit upcast
  return enumerator;
}
557
558
}  // namespace quantifiers
559
}  // namespace theory
560
29577
}  // namespace cvc5