GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/term_tuple_enumerator.cpp Lines: 157 207 75.8 %
Date: 2021-09-29 Branches: 241 666 36.2 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   MikolasJanota, Andrew Reynolds
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of an enumeration of tuples of terms for the purpose of
14
 * quantifier instantiation.
15
 */
16
#include "theory/quantifiers/term_tuple_enumerator.h"
17
18
#include <algorithm>
19
#include <functional>
20
#include <iterator>
21
#include <map>
22
#include <vector>
23
24
#include "base/map_util.h"
25
#include "base/output.h"
26
#include "options/quantifiers_options.h"
27
#include "smt/smt_statistics_registry.h"
28
#include "theory/quantifiers/index_trie.h"
29
#include "theory/quantifiers/quant_module.h"
30
#include "theory/quantifiers/relevant_domain.h"
31
#include "theory/quantifiers/term_pools.h"
32
#include "theory/quantifiers/term_registry.h"
33
#include "theory/quantifiers/term_util.h"
34
#include "util/statistics_stats.h"
35
36
namespace cvc5 {
37
38
template <typename T>
39
1928
static Cvc5ostream& operator<<(Cvc5ostream& out, const std::vector<T>& v)
40
{
41
1928
  out << "[ ";
42
1928
  std::copy(v.begin(), v.end(), std::ostream_iterator<T>(out, " "));
43
1928
  return out << "]";
44
}
45
46
/** Tracing purposes, printing a masked vector of indices. */
47
static void traceMaskedVector(const char* trace,
48
                              const char* name,
49
                              const std::vector<bool>& mask,
50
                              const std::vector<size_t>& values)
51
{
52
  Assert(mask.size() == values.size());
53
  Trace(trace) << name << " [ ";
54
  for (size_t variableIx = 0; variableIx < mask.size(); variableIx++)
55
  {
56
    if (mask[variableIx])
57
    {
58
      Trace(trace) << values[variableIx] << " ";
59
    }
60
    else
61
    {
62
      Trace(trace) << "_ ";
63
    }
64
  }
65
  Trace(trace) << "]" << std::endl;
66
}
67
68
namespace theory {
69
namespace quantifiers {
70
/**
71
 * Base class for enumerators of tuples of terms for the purpose of
72
 * quantification instantiation. The tuples are represented as tuples of
73
 * indices of  terms, where the tuple has as many elements as there are
74
 * quantified variables in the considered quantifier.
75
 *
76
 * Like so, we see a tuple as a number, where the digits may have different
77
 * ranges. The most significant digits are stored first.
78
 *
79
 * Tuples are enumerated  in a lexicographic order in stages. There are 2
80
 * possible strategies, either  all tuples in a given stage have the same sum of
81
 * digits, or, the maximum  over these digits is the same.
82
 * */
83
class TermTupleEnumeratorBase : public TermTupleEnumeratorInterface
84
{
85
 public:
86
  /** Initialize the class with the quantifier to be instantiated. */
87
583
  TermTupleEnumeratorBase(Node quantifier, const TermTupleEnumeratorEnv* env)
88
583
      : d_quantifier(quantifier),
89
1166
        d_variableCount(d_quantifier[0].getNumChildren()),
90
        d_env(env),
91
        d_stepCounter(0),
92
        d_disabledCombinations(
93
1749
            true)  // do not record combinations with no blanks
94
95
  {
96
583
    d_changePrefix = d_variableCount;
97
583
  }
98
99
583
  virtual ~TermTupleEnumeratorBase() = default;
100
101
  // implementation of the TermTupleEnumeratorInterface
102
  virtual void init() override;
103
  virtual bool hasNext() override;
104
  virtual void next(/*out*/ std::vector<Node>& terms) override;
105
  virtual void failureReason(const std::vector<bool>& mask) override;
106
  // end of implementation of the TermTupleEnumeratorInterface
107
108
 protected:
109
  /** the quantifier whose variables are being instantiated */
110
  const Node d_quantifier;
111
  /** number of variables in the quantifier */
112
  const size_t d_variableCount;
113
  /** env of structures with a longer lifespan */
114
  const TermTupleEnumeratorEnv* const d_env;
115
  /** type for each variable */
116
  std::vector<TypeNode> d_typeCache;
117
  /** number of candidate terms for each variable */
118
  std::vector<size_t> d_termsSizes;
119
  /** tuple of indices of the current terms */
120
  std::vector<size_t> d_termIndex;
121
  /** total number of steps of the enumerator */
122
  uint32_t d_stepCounter;
123
124
  /** a data structure storing disabled combinations of terms */
125
  IndexTrie d_disabledCombinations;
126
127
  /** current sum/max  of digits, depending on the strategy */
128
  size_t d_currentStage;
129
  /**total number of stages*/
130
  size_t d_stageCount;
131
  /**becomes false once the enumerator runs out of options*/
132
  bool d_hasNext;
133
  /** the length of the prefix that has to be changed in the next
134
  combination, i.e.  the number of the most significant digits that need to be
135
  changed in order to escape a  useless instantiation */
136
  size_t d_changePrefix;
137
  /** Move onto the next stage */
138
  bool increaseStage();
139
  /** Move onto the next stage, sum strategy. */
140
  bool increaseStageSum();
141
  /** Move onto the next stage, max strategy. */
142
  bool increaseStageMax();
143
  /** Move on in the current stage */
144
  bool nextCombination();
145
  /** Move onto the next combination. */
146
  bool nextCombinationInternal();
147
  /** Find the next lexicographically smallest combination of terms that change
148
   * on the change prefix, each digit is within the current state,  and there is
149
   * at least one digit not in the previous stage. */
150
  bool nextCombinationSum();
151
  /** Find the next lexicographically smallest combination of terms that change
152
   * on the change prefix and their sum is equal to d_currentStage. */
153
  bool nextCombinationMax();
154
  /** Set up terms for given variable.  */
155
  virtual size_t prepareTerms(size_t variableIx) = 0;
156
  /** Get a given term for a given variable.  */
157
  CVC5_WARN_UNUSED_RESULT virtual Node getTerm(size_t variableIx,
158
                                               size_t term_index) = 0;
159
};
160
161
/**
162
 * Enumerate ground terms as they come from the term database.
163
 */
164
class TermTupleEnumeratorBasic : public TermTupleEnumeratorBase
165
{
166
 public:
167
166
  TermTupleEnumeratorBasic(Node quantifier,
168
                           const TermTupleEnumeratorEnv* env,
169
                           QuantifiersState& qs,
170
                           TermDb* td)
171
166
      : TermTupleEnumeratorBase(quantifier, env), d_qs(qs), d_tdb(td)
172
  {
173
166
  }
174
175
332
  virtual ~TermTupleEnumeratorBasic() = default;
176
177
 protected:
178
  /**  a list of terms for each type */
179
  std::map<TypeNode, std::vector<Node> > d_termDbList;
180
  virtual size_t prepareTerms(size_t variableIx) override;
181
  virtual Node getTerm(size_t variableIx, size_t term_index) override;
182
  /** Reference to quantifiers state */
183
  QuantifiersState& d_qs;
184
  /** Pointer to term database */
185
  TermDb* d_tdb;
186
};
187
188
/**
189
 * Enumerate ground terms according to the relevant domain class.
190
 */
191
class TermTupleEnumeratorRD : public TermTupleEnumeratorBase
192
{
193
 public:
194
416
  TermTupleEnumeratorRD(Node quantifier,
195
                        const TermTupleEnumeratorEnv* env,
196
                        RelevantDomain* rd)
197
416
      : TermTupleEnumeratorBase(quantifier, env), d_rd(rd)
198
  {
199
416
  }
200
832
  virtual ~TermTupleEnumeratorRD() = default;
201
202
 protected:
203
2980
  virtual size_t prepareTerms(size_t variableIx) override
204
  {
205
2980
    return d_rd->getRDomain(d_quantifier, variableIx)->d_terms.size();
206
  }
207
5489
  virtual Node getTerm(size_t variableIx, size_t term_index) override
208
  {
209
5489
    return d_rd->getRDomain(d_quantifier, variableIx)->d_terms[term_index];
210
  }
211
  /** The relevant domain */
212
  RelevantDomain* d_rd;
213
};
214
215
583
void TermTupleEnumeratorBase::init()
216
{
217
1166
  Trace("inst-alg-rd") << "Initializing enumeration " << d_quantifier
218
583
                       << std::endl;
219
583
  d_currentStage = 0;
220
583
  d_hasNext = true;
221
583
  d_stageCount = 1;  // in the case of full effort we do at least one stage
222
223
583
  if (d_variableCount == 0)
224
  {
225
    d_hasNext = false;
226
    return;
227
  }
228
229
  // prepare a sequence of terms for each quantified variable
230
  // additionally initialize the cache for variable types
231
3813
  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
232
  {
233
3230
    d_typeCache.push_back(d_quantifier[0][variableIx].getType());
234
3230
    const size_t termsSize = prepareTerms(variableIx);
235
6460
    Trace("inst-alg-rd") << "Variable " << variableIx << " has " << termsSize
236
3230
                         << " in relevant domain." << std::endl;
237
3230
    if (termsSize == 0 && !d_env->d_fullEffort)
238
    {
239
      d_hasNext = false;
240
      return;  // give up on an empty dommain
241
    }
242
3230
    d_termsSizes.push_back(termsSize);
243
3230
    d_stageCount = std::max(d_stageCount, termsSize);
244
  }
245
246
1166
  Trace("inst-alg-rd") << "Will do " << d_stageCount
247
583
                       << " stages of instantiation." << std::endl;
248
583
  d_termIndex.resize(d_variableCount, 0);
249
}
250
251
1989
bool TermTupleEnumeratorBase::hasNext()
252
{
253
1989
  if (!d_hasNext)
254
  {
255
    return false;
256
  }
257
258
1989
  if (d_stepCounter++ == 0)
259
  {  // TODO:any (nice)  way of avoiding this special if?
260
583
    Assert(d_currentStage == 0);
261
1166
    Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..."
262
583
                         << std::endl;
263
583
    return true;
264
  }
265
266
  // try to find the next combination
267
1406
  return d_hasNext = nextCombination();
268
}
269
270
1404
void TermTupleEnumeratorBase::failureReason(const std::vector<bool>& mask)
271
{
272
1404
  if (Trace.isOn("inst-alg"))
273
  {
274
    traceMaskedVector("inst-alg", "failureReason", mask, d_termIndex);
275
  }
276
1404
  d_disabledCombinations.add(mask, d_termIndex);  // record failure
277
  // update change prefix accordingly
278
2082
  for (d_changePrefix = mask.size();
279
2082
       d_changePrefix && !mask[d_changePrefix - 1];
280
678
       d_changePrefix--)
281
    ;
282
1404
}
283
284
1678
void TermTupleEnumeratorBase::next(/*out*/ std::vector<Node>& terms)
285
{
286
1678
  Trace("inst-alg-rd") << "Try instantiation: " << d_termIndex << std::endl;
287
1678
  terms.resize(d_variableCount);
288
9018
  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
289
  {
290
7340
    const Node t = d_termsSizes[variableIx] == 0
291
                       ? Node::null()
292
14680
                       : getTerm(variableIx, d_termIndex[variableIx]);
293
7340
    terms[variableIx] = t;
294
7340
    Trace("inst-alg-rd") << t << "  ";
295
14680
    Assert(t.isNull()
296
           || t.getType().isComparableTo(d_quantifier[0][variableIx].getType()))
297
7340
        << "Bad type: " << t << " " << t.getType() << " "
298
7340
        << d_quantifier[0][variableIx].getType();
299
  }
300
1678
  Trace("inst-alg-rd") << std::endl;
301
1678
}
302
303
bool TermTupleEnumeratorBase::increaseStageSum()
304
{
305
  const size_t lowerBound = d_currentStage + 1;
306
  Trace("inst-alg-rd") << "Try sum " << lowerBound << "..." << std::endl;
307
  d_currentStage = 0;
308
  for (size_t digit = d_termIndex.size();
309
       d_currentStage < lowerBound && digit--;)
310
  {
311
    const size_t missing = lowerBound - d_currentStage;
312
    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
313
    d_termIndex[digit] = std::min(missing, maxValue);
314
    d_currentStage += d_termIndex[digit];
315
  }
316
  return d_currentStage >= lowerBound;
317
}
318
319
943
bool TermTupleEnumeratorBase::increaseStage()
320
{
321
943
  d_changePrefix = d_variableCount;  // simply reset upon increase stage
322
943
  return d_env->d_increaseSum ? increaseStageSum() : increaseStageMax();
323
}
324
325
943
bool TermTupleEnumeratorBase::increaseStageMax()
326
{
327
943
  d_currentStage++;
328
943
  if (d_currentStage >= d_stageCount)
329
  {
330
311
    return false;
331
  }
332
632
  Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..." << std::endl;
333
  // skipping some elements that have already been definitely seen
334
  // find the least significant digit that can be set to the current stage
335
  // TODO: should we skip all?
336
632
  std::fill(d_termIndex.begin(), d_termIndex.end(), 0);
337
632
  bool found = false;
338
1463
  for (size_t digit = d_termIndex.size(); !found && digit--;)
339
  {
340
831
    if (d_termsSizes[digit] > d_currentStage)
341
    {
342
632
      found = true;
343
632
      d_termIndex[digit] = d_currentStage;
344
    }
345
  }
346
632
  Assert(found);
347
632
  return found;
348
}
349
350
1488
bool TermTupleEnumeratorBase::nextCombination()
351
{
352
  while (true)
353
  {
354
1570
    Trace("inst-alg-rd") << "changePrefix " << d_changePrefix << std::endl;
355
1488
    if (!nextCombinationInternal() && !increaseStage())
356
    {
357
311
      return false;  // ran out of combinations
358
    }
359
1177
    if (!d_disabledCombinations.find(d_termIndex, d_changePrefix))
360
    {
361
1095
      return true;  // current combination vetted by disabled combinations
362
    }
363
  }
364
}
365
366
/** Move onto the next combination, depending on the strategy. */
367
1488
bool TermTupleEnumeratorBase::nextCombinationInternal()
368
{
369
1488
  return d_env->d_increaseSum ? nextCombinationSum() : nextCombinationMax();
370
}
371
372
/** Find the next lexicographically smallest combination of terms that change
373
 * on the change prefix and their sum is equal to d_currentStage. */
374
1488
bool TermTupleEnumeratorBase::nextCombinationMax()
375
{
376
  // look for the least significant digit, within change prefix,
377
  // that can be increased
378
1488
  bool found = false;
379
1488
  size_t increaseDigit = d_changePrefix;
380
7960
  while (!found && increaseDigit--)
381
  {
382
3236
    const size_t new_value = d_termIndex[increaseDigit] + 1;
383
3236
    if (new_value < d_termsSizes[increaseDigit] && new_value <= d_currentStage)
384
    {
385
545
      d_termIndex[increaseDigit] = new_value;
386
      // send everything after the increased digit to 0
387
545
      std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
388
545
      found = true;
389
    }
390
  }
391
1488
  if (!found)
392
  {
393
943
    return false;  // nothing to increase
394
  }
395
  // check if the combination has at least one digit in the current stage,
396
  // since at least one digit was increased, no need for this in stage 1
397
545
  bool inStage = d_currentStage <= 1;
398
883
  for (size_t i = increaseDigit + 1; !inStage && i--;)
399
  {
400
338
    inStage = d_termIndex[i] >= d_currentStage;
401
  }
402
545
  if (!inStage)  // look for a digit that can increase to current stage
403
  {
404
90
    for (increaseDigit = d_variableCount, found = false;
405
90
         !found && increaseDigit--;)
406
    {
407
45
      found = d_termsSizes[increaseDigit] > d_currentStage;
408
    }
409
45
    if (!found)
410
    {
411
      return false;  // nothing to increase to the current stage
412
    }
413
45
    Assert(d_termsSizes[increaseDigit] > d_currentStage
414
           && d_termIndex[increaseDigit] < d_currentStage);
415
45
    d_termIndex[increaseDigit] = d_currentStage;
416
    // send everything after the increased digit to 0
417
45
    std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
418
  }
419
545
  return true;
420
}
421
422
/** Find the next lexicographically smallest combination of terms that change
423
 * on the change prefix, each digit is within the current state,  and there is
424
 * at least one digit not in the previous stage. */
425
bool TermTupleEnumeratorBase::nextCombinationSum()
426
{
427
  size_t suffixSum = 0;
428
  bool found = false;
429
  size_t increaseDigit = d_termIndex.size();
430
  while (increaseDigit--)
431
  {
432
    const size_t newValue = d_termIndex[increaseDigit] + 1;
433
    found = suffixSum > 0 && newValue < d_termsSizes[increaseDigit]
434
            && increaseDigit < d_changePrefix;
435
    if (found)
436
    {
437
      // digit can be increased and suffix can be decreased
438
      d_termIndex[increaseDigit] = newValue;
439
      break;
440
    }
441
    suffixSum += d_termIndex[increaseDigit];
442
    d_termIndex[increaseDigit] = 0;
443
  }
444
  if (!found)
445
  {
446
    return false;
447
  }
448
  Assert(suffixSum > 0);
449
  // increaseDigit went up by one, hence, distribute (suffixSum - 1) in the
450
  // least significant digits
451
  suffixSum--;
452
  for (size_t digit = d_termIndex.size(); suffixSum > 0 && digit--;)
453
  {
454
    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
455
    d_termIndex[digit] = std::min(suffixSum, maxValue);
456
    suffixSum -= d_termIndex[digit];
457
  }
458
  Assert(suffixSum == 0);  // everything should have been distributed
459
  return true;
460
}
461
462
249
size_t TermTupleEnumeratorBasic::prepareTerms(size_t variableIx)
463
{
464
498
  const TypeNode type_node = d_typeCache[variableIx];
465
249
  if (!ContainsKey(d_termDbList, type_node))
466
  {
467
193
    const size_t ground_terms_count = d_tdb->getNumTypeGroundTerms(type_node);
468
386
    std::map<Node, Node> repsFound;
469
10067
    for (size_t j = 0; j < ground_terms_count; j++)
470
    {
471
19748
      Node gt = d_tdb->getTypeGroundTerm(type_node, j);
472
9874
      if (!options::cegqi() || !quantifiers::TermUtil::hasInstConstAttr(gt))
473
      {
474
19748
        Node rep = d_qs.getRepresentative(gt);
475
9874
        if (repsFound.find(rep) == repsFound.end())
476
        {
477
602
          repsFound[rep] = gt;
478
602
          d_termDbList[type_node].push_back(gt);
479
        }
480
      }
481
    }
482
  }
483
484
498
  Trace("inst-alg-rd") << "Instantiation Terms for child " << variableIx << ": "
485
249
                       << d_termDbList[type_node] << std::endl;
486
498
  return d_termDbList[type_node].size();
487
}
488
489
1160
Node TermTupleEnumeratorBasic::getTerm(size_t variableIx, size_t term_index)
{
  // candidate list for the variable's type, as cached by prepareTerms
  const std::vector<Node>& terms = d_termDbList[d_typeCache[variableIx]];
  Assert(term_index < terms.size());
  return terms[term_index];
}
495
496
/**
497
 * Enumerate ground terms as they come from a user-provided term pool
498
 */
499
class TermTupleEnumeratorPool : public TermTupleEnumeratorBase
500
{
501
 public:
502
1
  TermTupleEnumeratorPool(Node quantifier,
503
                          const TermTupleEnumeratorEnv* env,
504
                          TermPools* tp,
505
                          Node pool)
506
1
      : TermTupleEnumeratorBase(quantifier, env), d_tp(tp), d_pool(pool)
507
  {
508
1
    Assert(d_pool.getKind() == kind::INST_POOL);
509
1
  }
510
511
2
  virtual ~TermTupleEnumeratorPool() = default;
512
513
 protected:
514
  /** Pointer to the term pool utility */
515
  TermPools* d_tp;
516
  /** The pool annotation */
517
  Node d_pool;
518
  /**  a list of terms for each id */
519
  std::map<size_t, std::vector<Node> > d_poolList;
520
  /** gets the terms from the pool */
521
1
  size_t prepareTerms(size_t variableIx) override
522
  {
523
1
    Assert(d_pool.getNumChildren() > variableIx);
524
    // prepare terms from pool
525
1
    d_poolList[variableIx].clear();
526
1
    d_tp->getTermsForPool(d_pool[variableIx], d_poolList[variableIx]);
527
2
    Trace("pool-inst") << "Instantiation Terms for child " << variableIx << ": "
528
1
                       << d_poolList[variableIx] << std::endl;
529
1
    return d_poolList[variableIx].size();
530
  }
531
2
  Node getTerm(size_t variableIx, size_t term_index) override
532
  {
533
2
    Assert(term_index < d_poolList[variableIx].size());
534
2
    return d_poolList[variableIx][term_index];
535
  }
536
};
537
538
166
TermTupleEnumeratorInterface* mkTermTupleEnumerator(
    Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs, TermDb* td)
{
  // freshly allocated; the caller is responsible for deleting it
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorBasic(q, env, qs, td);
  return enumerator;
}
544
416
TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd(
    Node q, const TermTupleEnumeratorEnv* env, RelevantDomain* rd)
{
  // freshly allocated; the caller is responsible for deleting it
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorRD(q, env, rd);
  return enumerator;
}
550
551
1
TermTupleEnumeratorInterface* mkTermTupleEnumeratorPool(
    Node q, const TermTupleEnumeratorEnv* env, TermPools* tp, Node pool)
{
  // freshly allocated; the caller is responsible for deleting it
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorPool(q, env, tp, pool);
  return enumerator;
}
557
558
}  // namespace quantifiers
559
}  // namespace theory
560
22746
}  // namespace cvc5