GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/term_tuple_enumerator.cpp Lines: 157 207 75.8 %
Date: 2021-11-06 Branches: 241 666 36.2 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   MikolasJanota, Andrew Reynolds
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of an enumeration of tuples of terms for the purpose of
14
 * quantifier instantiation.
15
 */
16
#include "theory/quantifiers/term_tuple_enumerator.h"
17
18
#include <algorithm>
19
#include <functional>
20
#include <iterator>
21
#include <map>
22
#include <vector>
23
24
#include "base/map_util.h"
25
#include "base/output.h"
26
#include "options/quantifiers_options.h"
27
#include "smt/smt_statistics_registry.h"
28
#include "theory/quantifiers/index_trie.h"
29
#include "theory/quantifiers/quant_module.h"
30
#include "theory/quantifiers/relevant_domain.h"
31
#include "theory/quantifiers/term_pools.h"
32
#include "theory/quantifiers/term_registry.h"
33
#include "theory/quantifiers/term_util.h"
34
#include "util/statistics_stats.h"
35
36
namespace cvc5 {
37
38
template <typename T>
39
5729
static Cvc5ostream& operator<<(Cvc5ostream& out, const std::vector<T>& v)
40
{
41
5729
  out << "[ ";
42
5729
  std::copy(v.begin(), v.end(), std::ostream_iterator<T>(out, " "));
43
5729
  return out << "]";
44
}
45
46
/** Tracing purposes, printing a masked vector of indices. */
47
static void traceMaskedVector(const char* trace,
48
                              const char* name,
49
                              const std::vector<bool>& mask,
50
                              const std::vector<size_t>& values)
51
{
52
  Assert(mask.size() == values.size());
53
  Trace(trace) << name << " [ ";
54
  for (size_t variableIx = 0; variableIx < mask.size(); variableIx++)
55
  {
56
    if (mask[variableIx])
57
    {
58
      Trace(trace) << values[variableIx] << " ";
59
    }
60
    else
61
    {
62
      Trace(trace) << "_ ";
63
    }
64
  }
65
  Trace(trace) << "]" << std::endl;
66
}
67
68
namespace theory {
69
namespace quantifiers {
70
/**
71
 * Base class for enumerators of tuples of terms for the purpose of
72
 * quantification instantiation. The tuples are represented as tuples of
73
 * indices of  terms, where the tuple has as many elements as there are
74
 * quantified variables in the considered quantifier.
75
 *
76
 * Like so, we see a tuple as a number, where the digits may have different
77
 * ranges. The most significant digits are stored first.
78
 *
79
 * Tuples are enumerated  in a lexicographic order in stages. There are 2
80
 * possible strategies, either  all tuples in a given stage have the same sum of
81
 * digits, or, the maximum  over these digits is the same.
82
 * */
83
class TermTupleEnumeratorBase : public TermTupleEnumeratorInterface
84
{
85
 public:
86
  /** Initialize the class with the quantifier to be instantiated. */
87
1723
  TermTupleEnumeratorBase(Node quantifier, const TermTupleEnumeratorEnv* env)
88
1723
      : d_quantifier(quantifier),
89
3446
        d_variableCount(d_quantifier[0].getNumChildren()),
90
        d_env(env),
91
        d_stepCounter(0),
92
        d_disabledCombinations(
93
5169
            true)  // do not record combinations with no blanks
94
95
  {
96
1723
    d_changePrefix = d_variableCount;
97
1723
  }
98
99
1723
  virtual ~TermTupleEnumeratorBase() = default;
100
101
  // implementation of the TermTupleEnumeratorInterface
102
  virtual void init() override;
103
  virtual bool hasNext() override;
104
  virtual void next(/*out*/ std::vector<Node>& terms) override;
105
  virtual void failureReason(const std::vector<bool>& mask) override;
106
  // end of implementation of the TermTupleEnumeratorInterface
107
108
 protected:
109
  /** the quantifier whose variables are being instantiated */
110
  const Node d_quantifier;
111
  /** number of variables in the quantifier */
112
  const size_t d_variableCount;
113
  /** env of structures with a longer lifespan */
114
  const TermTupleEnumeratorEnv* const d_env;
115
  /** type for each variable */
116
  std::vector<TypeNode> d_typeCache;
117
  /** number of candidate terms for each variable */
118
  std::vector<size_t> d_termsSizes;
119
  /** tuple of indices of the current terms */
120
  std::vector<size_t> d_termIndex;
121
  /** total number of steps of the enumerator */
122
  uint32_t d_stepCounter;
123
124
  /** a data structure storing disabled combinations of terms */
125
  IndexTrie d_disabledCombinations;
126
127
  /** current sum/max  of digits, depending on the strategy */
128
  size_t d_currentStage;
129
  /**total number of stages*/
130
  size_t d_stageCount;
131
  /**becomes false once the enumerator runs out of options*/
132
  bool d_hasNext;
133
  /** the length of the prefix that has to be changed in the next
134
  combination, i.e.  the number of the most significant digits that need to be
135
  changed in order to escape a  useless instantiation */
136
  size_t d_changePrefix;
137
  /** Move onto the next stage */
138
  bool increaseStage();
139
  /** Move onto the next stage, sum strategy. */
140
  bool increaseStageSum();
141
  /** Move onto the next stage, max strategy. */
142
  bool increaseStageMax();
143
  /** Move on in the current stage */
144
  bool nextCombination();
145
  /** Move onto the next combination. */
146
  bool nextCombinationInternal();
147
  /** Find the next lexicographically smallest combination of terms that change
148
   * on the change prefix, each digit is within the current state,  and there is
149
   * at least one digit not in the previous stage. */
150
  bool nextCombinationSum();
151
  /** Find the next lexicographically smallest combination of terms that change
152
   * on the change prefix and their sum is equal to d_currentStage. */
153
  bool nextCombinationMax();
154
  /** Set up terms for given variable.  */
155
  virtual size_t prepareTerms(size_t variableIx) = 0;
156
  /** Get a given term for a given variable.  */
157
  CVC5_WARN_UNUSED_RESULT virtual Node getTerm(size_t variableIx,
158
                                               size_t term_index) = 0;
159
};
160
161
/**
162
 * Enumerate ground terms as they come from the term database.
163
 */
164
class TermTupleEnumeratorBasic : public TermTupleEnumeratorBase
165
{
166
 public:
167
356
  TermTupleEnumeratorBasic(Node quantifier,
168
                           const TermTupleEnumeratorEnv* env,
169
                           QuantifiersState& qs,
170
                           TermDb* td)
171
356
      : TermTupleEnumeratorBase(quantifier, env), d_qs(qs), d_tdb(td)
172
  {
173
356
  }
174
175
712
  virtual ~TermTupleEnumeratorBasic() = default;
176
177
 protected:
178
  /**  a list of terms for each type */
179
  std::map<TypeNode, std::vector<Node> > d_termDbList;
180
  virtual size_t prepareTerms(size_t variableIx) override;
181
  virtual Node getTerm(size_t variableIx, size_t term_index) override;
182
  /** Reference to quantifiers state */
183
  QuantifiersState& d_qs;
184
  /** Pointer to term database */
185
  TermDb* d_tdb;
186
};
187
188
/**
189
 * Enumerate ground terms according to the relevant domain class.
190
 */
191
class TermTupleEnumeratorRD : public TermTupleEnumeratorBase
192
{
193
 public:
194
1363
  TermTupleEnumeratorRD(Node quantifier,
195
                        const TermTupleEnumeratorEnv* env,
196
                        RelevantDomain* rd)
197
1363
      : TermTupleEnumeratorBase(quantifier, env), d_rd(rd)
198
  {
199
1363
  }
200
2726
  virtual ~TermTupleEnumeratorRD() = default;
201
202
 protected:
203
9858
  virtual size_t prepareTerms(size_t variableIx) override
204
  {
205
9858
    return d_rd->getRDomain(d_quantifier, variableIx)->d_terms.size();
206
  }
207
17801
  virtual Node getTerm(size_t variableIx, size_t term_index) override
208
  {
209
17801
    return d_rd->getRDomain(d_quantifier, variableIx)->d_terms[term_index];
210
  }
211
  /** The relevant domain */
212
  RelevantDomain* d_rd;
213
};
214
215
1723
void TermTupleEnumeratorBase::init()
216
{
217
3446
  Trace("inst-alg-rd") << "Initializing enumeration " << d_quantifier
218
1723
                       << std::endl;
219
1723
  d_currentStage = 0;
220
1723
  d_hasNext = true;
221
1723
  d_stageCount = 1;  // in the case of full effort we do at least one stage
222
223
1723
  if (d_variableCount == 0)
224
  {
225
    d_hasNext = false;
226
    return;
227
  }
228
229
  // prepare a sequence of terms for each quantified variable
230
  // additionally initialize the cache for variable types
231
12142
  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
232
  {
233
10419
    d_typeCache.push_back(d_quantifier[0][variableIx].getType());
234
10419
    const size_t termsSize = prepareTerms(variableIx);
235
20838
    Trace("inst-alg-rd") << "Variable " << variableIx << " has " << termsSize
236
10419
                         << " in relevant domain." << std::endl;
237
10419
    if (termsSize == 0 && !d_env->d_fullEffort)
238
    {
239
      d_hasNext = false;
240
      return;  // give up on an empty dommain
241
    }
242
10419
    d_termsSizes.push_back(termsSize);
243
10419
    d_stageCount = std::max(d_stageCount, termsSize);
244
  }
245
246
3446
  Trace("inst-alg-rd") << "Will do " << d_stageCount
247
1723
                       << " stages of instantiation." << std::endl;
248
1723
  d_termIndex.resize(d_variableCount, 0);
249
}
250
251
5847
bool TermTupleEnumeratorBase::hasNext()
252
{
253
5847
  if (!d_hasNext)
254
  {
255
    return false;
256
  }
257
258
5847
  if (d_stepCounter++ == 0)
259
  {  // TODO:any (nice)  way of avoiding this special if?
260
1723
    Assert(d_currentStage == 0);
261
3446
    Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..."
262
1723
                         << std::endl;
263
1723
    return true;
264
  }
265
266
  // try to find the next combination
267
4124
  return d_hasNext = nextCombination();
268
}
269
270
4116
void TermTupleEnumeratorBase::failureReason(const std::vector<bool>& mask)
271
{
272
4116
  if (Trace.isOn("inst-alg"))
273
  {
274
    traceMaskedVector("inst-alg", "failureReason", mask, d_termIndex);
275
  }
276
4116
  d_disabledCombinations.add(mask, d_termIndex);  // record failure
277
  // update change prefix accordingly
278
6232
  for (d_changePrefix = mask.size();
279
6232
       d_changePrefix && !mask[d_changePrefix - 1];
280
2116
       d_changePrefix--)
281
    ;
282
4116
}
283
284
5168
void TermTupleEnumeratorBase::next(/*out*/ std::vector<Node>& terms)
285
{
286
5168
  Trace("inst-alg-rd") << "Try instantiation: " << d_termIndex << std::endl;
287
5168
  terms.resize(d_variableCount);
288
28302
  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
289
  {
290
23134
    const Node t = d_termsSizes[variableIx] == 0
291
                       ? Node::null()
292
46268
                       : getTerm(variableIx, d_termIndex[variableIx]);
293
23134
    terms[variableIx] = t;
294
23134
    Trace("inst-alg-rd") << t << "  ";
295
46268
    Assert(t.isNull()
296
           || t.getType().isComparableTo(d_quantifier[0][variableIx].getType()))
297
23134
        << "Bad type: " << t << " " << t.getType() << " "
298
23134
        << d_quantifier[0][variableIx].getType();
299
  }
300
5168
  Trace("inst-alg-rd") << std::endl;
301
5168
}
302
303
bool TermTupleEnumeratorBase::increaseStageSum()
304
{
305
  const size_t lowerBound = d_currentStage + 1;
306
  Trace("inst-alg-rd") << "Try sum " << lowerBound << "..." << std::endl;
307
  d_currentStage = 0;
308
  for (size_t digit = d_termIndex.size();
309
       d_currentStage < lowerBound && digit--;)
310
  {
311
    const size_t missing = lowerBound - d_currentStage;
312
    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
313
    d_termIndex[digit] = std::min(missing, maxValue);
314
    d_currentStage += d_termIndex[digit];
315
  }
316
  return d_currentStage >= lowerBound;
317
}
318
319
2807
bool TermTupleEnumeratorBase::increaseStage()
320
{
321
2807
  d_changePrefix = d_variableCount;  // simply reset upon increase stage
322
2807
  return d_env->d_increaseSum ? increaseStageSum() : increaseStageMax();
323
}
324
325
2807
bool TermTupleEnumeratorBase::increaseStageMax()
326
{
327
2807
  d_currentStage++;
328
2807
  if (d_currentStage >= d_stageCount)
329
  {
330
679
    return false;
331
  }
332
2128
  Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..." << std::endl;
333
  // skipping some elements that have already been definitely seen
334
  // find the least significant digit that can be set to the current stage
335
  // TODO: should we skip all?
336
2128
  std::fill(d_termIndex.begin(), d_termIndex.end(), 0);
337
2128
  bool found = false;
338
4814
  for (size_t digit = d_termIndex.size(); !found && digit--;)
339
  {
340
2686
    if (d_termsSizes[digit] > d_currentStage)
341
    {
342
2128
      found = true;
343
2128
      d_termIndex[digit] = d_currentStage;
344
    }
345
  }
346
2128
  Assert(found);
347
2128
  return found;
348
}
349
350
4402
bool TermTupleEnumeratorBase::nextCombination()
351
{
352
  while (true)
353
  {
354
4680
    Trace("inst-alg-rd") << "changePrefix " << d_changePrefix << std::endl;
355
4402
    if (!nextCombinationInternal() && !increaseStage())
356
    {
357
679
      return false;  // ran out of combinations
358
    }
359
3723
    if (!d_disabledCombinations.find(d_termIndex, d_changePrefix))
360
    {
361
3445
      return true;  // current combination vetted by disabled combinations
362
    }
363
  }
364
}
365
366
/** Move onto the next combination, depending on the strategy. */
367
4402
bool TermTupleEnumeratorBase::nextCombinationInternal()
368
{
369
4402
  return d_env->d_increaseSum ? nextCombinationSum() : nextCombinationMax();
370
}
371
372
/** Find the next lexicographically smallest combination of terms that change
373
 * on the change prefix and their sum is equal to d_currentStage. */
374
4402
bool TermTupleEnumeratorBase::nextCombinationMax()
375
{
376
  // look for the least significant digit, within change prefix,
377
  // that can be increased
378
4402
  bool found = false;
379
4402
  size_t increaseDigit = d_changePrefix;
380
24274
  while (!found && increaseDigit--)
381
  {
382
9936
    const size_t new_value = d_termIndex[increaseDigit] + 1;
383
9936
    if (new_value < d_termsSizes[increaseDigit] && new_value <= d_currentStage)
384
    {
385
1595
      d_termIndex[increaseDigit] = new_value;
386
      // send everything after the increased digit to 0
387
1595
      std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
388
1595
      found = true;
389
    }
390
  }
391
4402
  if (!found)
392
  {
393
2807
    return false;  // nothing to increase
394
  }
395
  // check if the combination has at least one digit in the current stage,
396
  // since at least one digit was increased, no need for this in stage 1
397
1595
  bool inStage = d_currentStage <= 1;
398
2348
  for (size_t i = increaseDigit + 1; !inStage && i--;)
399
  {
400
753
    inStage = d_termIndex[i] >= d_currentStage;
401
  }
402
1595
  if (!inStage)  // look for a digit that can increase to current stage
403
  {
404
296
    for (increaseDigit = d_variableCount, found = false;
405
296
         !found && increaseDigit--;)
406
    {
407
148
      found = d_termsSizes[increaseDigit] > d_currentStage;
408
    }
409
148
    if (!found)
410
    {
411
      return false;  // nothing to increase to the current stage
412
    }
413
148
    Assert(d_termsSizes[increaseDigit] > d_currentStage
414
           && d_termIndex[increaseDigit] < d_currentStage);
415
148
    d_termIndex[increaseDigit] = d_currentStage;
416
    // send everything after the increased digit to 0
417
148
    std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
418
  }
419
1595
  return true;
420
}
421
422
/** Find the next lexicographically smallest combination of terms that change
423
 * on the change prefix, each digit is within the current state,  and there is
424
 * at least one digit not in the previous stage. */
425
bool TermTupleEnumeratorBase::nextCombinationSum()
426
{
427
  size_t suffixSum = 0;
428
  bool found = false;
429
  size_t increaseDigit = d_termIndex.size();
430
  while (increaseDigit--)
431
  {
432
    const size_t newValue = d_termIndex[increaseDigit] + 1;
433
    found = suffixSum > 0 && newValue < d_termsSizes[increaseDigit]
434
            && increaseDigit < d_changePrefix;
435
    if (found)
436
    {
437
      // digit can be increased and suffix can be decreased
438
      d_termIndex[increaseDigit] = newValue;
439
      break;
440
    }
441
    suffixSum += d_termIndex[increaseDigit];
442
    d_termIndex[increaseDigit] = 0;
443
  }
444
  if (!found)
445
  {
446
    return false;
447
  }
448
  Assert(suffixSum > 0);
449
  // increaseDigit went up by one, hence, distribute (suffixSum - 1) in the
450
  // least significant digits
451
  suffixSum--;
452
  for (size_t digit = d_termIndex.size(); suffixSum > 0 && digit--;)
453
  {
454
    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
455
    d_termIndex[digit] = std::min(suffixSum, maxValue);
456
    suffixSum -= d_termIndex[digit];
457
  }
458
  Assert(suffixSum == 0);  // everything should have been distributed
459
  return true;
460
}
461
462
557
size_t TermTupleEnumeratorBasic::prepareTerms(size_t variableIx)
463
{
464
1114
  const TypeNode type_node = d_typeCache[variableIx];
465
557
  if (!ContainsKey(d_termDbList, type_node))
466
  {
467
426
    const size_t ground_terms_count = d_tdb->getNumTypeGroundTerms(type_node);
468
852
    std::map<Node, Node> repsFound;
469
22355
    for (size_t j = 0; j < ground_terms_count; j++)
470
    {
471
43858
      Node gt = d_tdb->getTypeGroundTerm(type_node, j);
472
21929
      if (!options::cegqi() || !quantifiers::TermUtil::hasInstConstAttr(gt))
473
      {
474
43858
        Node rep = d_qs.getRepresentative(gt);
475
21929
        if (repsFound.find(rep) == repsFound.end())
476
        {
477
2766
          repsFound[rep] = gt;
478
2766
          d_termDbList[type_node].push_back(gt);
479
        }
480
      }
481
    }
482
  }
483
484
1114
  Trace("inst-alg-rd") << "Instantiation Terms for child " << variableIx << ": "
485
557
                       << d_termDbList[type_node] << std::endl;
486
1114
  return d_termDbList[type_node].size();
487
}
488
489
3095
/**
 * Return the term_index-th candidate term for the given variable, from the
 * per-type list built by prepareTerms. Uses find rather than operator[] so a
 * lookup never inserts into d_termDbList, and avoids copying the TypeNode.
 */
Node TermTupleEnumeratorBasic::getTerm(size_t variableIx, size_t term_index)
{
  const TypeNode& typeNode = d_typeCache[variableIx];
  const auto it = d_termDbList.find(typeNode);
  Assert(it != d_termDbList.end());  // prepareTerms must have run for the type
  Assert(term_index < it->second.size());
  return it->second[term_index];
}
496
/**
497
 * Enumerate ground terms as they come from a user-provided term pool
498
 */
499
class TermTupleEnumeratorPool : public TermTupleEnumeratorBase
500
{
501
 public:
502
4
  TermTupleEnumeratorPool(Node quantifier,
503
                          const TermTupleEnumeratorEnv* env,
504
                          TermPools* tp,
505
                          Node pool)
506
4
      : TermTupleEnumeratorBase(quantifier, env), d_tp(tp), d_pool(pool)
507
  {
508
4
    Assert(d_pool.getKind() == kind::INST_POOL);
509
4
  }
510
511
8
  virtual ~TermTupleEnumeratorPool() = default;
512
513
 protected:
514
  /** Pointer to the term pool utility */
515
  TermPools* d_tp;
516
  /** The pool annotation */
517
  Node d_pool;
518
  /**  a list of terms for each id */
519
  std::map<size_t, std::vector<Node> > d_poolList;
520
  /** gets the terms from the pool */
521
4
  size_t prepareTerms(size_t variableIx) override
522
  {
523
4
    Assert(d_pool.getNumChildren() > variableIx);
524
    // prepare terms from pool
525
4
    d_poolList[variableIx].clear();
526
4
    d_tp->getTermsForPool(d_pool[variableIx], d_poolList[variableIx]);
527
8
    Trace("pool-inst") << "Instantiation Terms for child " << variableIx << ": "
528
4
                       << d_poolList[variableIx] << std::endl;
529
4
    return d_poolList[variableIx].size();
530
  }
531
8
  Node getTerm(size_t variableIx, size_t term_index) override
532
  {
533
8
    Assert(term_index < d_poolList[variableIx].size());
534
8
    return d_poolList[variableIx][term_index];
535
  }
536
};
537
538
356
TermTupleEnumeratorInterface* mkTermTupleEnumerator(
    Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs, TermDb* td)
{
  // Heap-allocated; presumably the caller takes ownership — confirm in header.
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorBasic(q, env, qs, td);
  return enumerator;
}
544
1363
TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd(
    Node q, const TermTupleEnumeratorEnv* env, RelevantDomain* rd)
{
  // Heap-allocated; presumably the caller takes ownership — confirm in header.
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorRD(q, env, rd);
  return enumerator;
}
550
551
4
TermTupleEnumeratorInterface* mkTermTupleEnumeratorPool(
    Node q, const TermTupleEnumeratorEnv* env, TermPools* tp, Node pool)
{
  // Heap-allocated; presumably the caller takes ownership — confirm in header.
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorPool(q, env, tp, pool);
  return enumerator;
}
557
558
}  // namespace quantifiers
559
}  // namespace theory
560
31137
}  // namespace cvc5