GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/term_tuple_enumerator.cpp Lines: 155 205 75.6 %
Date: 2021-05-22 Branches: 235 632 37.2 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   MikolasJanota, Andrew Reynolds
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of an enumeration of tuples of terms for the purpose of
14
 * quantifier instantiation.
15
 */
16
#include "theory/quantifiers/term_tuple_enumerator.h"
17
18
#include <algorithm>
19
#include <functional>
20
#include <iterator>
21
#include <map>
22
#include <vector>
23
24
#include "base/map_util.h"
25
#include "base/output.h"
26
#include "options/quantifiers_options.h"
27
#include "smt/smt_statistics_registry.h"
28
#include "theory/quantifiers/index_trie.h"
29
#include "theory/quantifiers/quant_module.h"
30
#include "theory/quantifiers/relevant_domain.h"
31
#include "theory/quantifiers/term_pools.h"
32
#include "theory/quantifiers/term_registry.h"
33
#include "theory/quantifiers/term_util.h"
34
#include "util/statistics_stats.h"
35
36
namespace cvc5 {
37
38
template <typename T>
39
8783
static Cvc5ostream& operator<<(Cvc5ostream& out, const std::vector<T>& v)
40
{
41
8783
  out << "[ ";
42
8783
  std::copy(v.begin(), v.end(), std::ostream_iterator<T>(out, " "));
43
8783
  return out << "]";
44
}
45
46
/** Tracing purposes, printing a masked vector of indices. */
47
static void traceMaskedVector(const char* trace,
48
                              const char* name,
49
                              const std::vector<bool>& mask,
50
                              const std::vector<size_t>& values)
51
{
52
  Assert(mask.size() == values.size());
53
  Trace(trace) << name << " [ ";
54
  for (size_t variableIx = 0; variableIx < mask.size(); variableIx++)
55
  {
56
    if (mask[variableIx])
57
    {
58
      Trace(trace) << values[variableIx] << " ";
59
    }
60
    else
61
    {
62
      Trace(trace) << "_ ";
63
    }
64
  }
65
  Trace(trace) << "]" << std::endl;
66
}
67
68
namespace theory {
69
namespace quantifiers {
70
/**
71
 * Base class for enumerators of tuples of terms for the purpose of
72
 * quantification instantiation. The tuples are represented as tuples of
73
 * indices of  terms, where the tuple has as many elements as there are
74
 * quantified variables in the considered quantifier.
75
 *
76
 * Like so, we see a tuple as a number, where the digits may have different
77
 * ranges. The most significant digits are stored first.
78
 *
79
 * Tuples are enumerated  in a lexicographic order in stages. There are 2
80
 * possible strategies, either  all tuples in a given stage have the same sum of
81
 * digits, or, the maximum  over these digits is the same.
82
 * */
83
class TermTupleEnumeratorBase : public TermTupleEnumeratorInterface
84
{
85
 public:
86
  /** Initialize the class with the quantifier to be instantiated. */
87
1924
  TermTupleEnumeratorBase(Node quantifier, const TermTupleEnumeratorEnv* env)
88
1924
      : d_quantifier(quantifier),
89
3848
        d_variableCount(d_quantifier[0].getNumChildren()),
90
        d_env(env),
91
        d_stepCounter(0),
92
        d_disabledCombinations(
93
5772
            true)  // do not record combinations with no blanks
94
95
  {
96
1924
    d_changePrefix = d_variableCount;
97
1924
  }
98
99
1924
  virtual ~TermTupleEnumeratorBase() = default;
100
101
  // implementation of the TermTupleEnumeratorInterface
102
  virtual void init() override;
103
  virtual bool hasNext() override;
104
  virtual void next(/*out*/ std::vector<Node>& terms) override;
105
  virtual void failureReason(const std::vector<bool>& mask) override;
106
  // end of implementation of the TermTupleEnumeratorInterface
107
108
 protected:
109
  /** the quantifier whose variables are being instantiated */
110
  const Node d_quantifier;
111
  /** number of variables in the quantifier */
112
  const size_t d_variableCount;
113
  /** env of structures with a longer lifespan */
114
  const TermTupleEnumeratorEnv* const d_env;
115
  /** type for each variable */
116
  std::vector<TypeNode> d_typeCache;
117
  /** number of candidate terms for each variable */
118
  std::vector<size_t> d_termsSizes;
119
  /** tuple of indices of the current terms */
120
  std::vector<size_t> d_termIndex;
121
  /** total number of steps of the enumerator */
122
  uint32_t d_stepCounter;
123
124
  /** a data structure storing disabled combinations of terms */
125
  IndexTrie d_disabledCombinations;
126
127
  /** current sum/max  of digits, depending on the strategy */
128
  size_t d_currentStage;
129
  /**total number of stages*/
130
  size_t d_stageCount;
131
  /**becomes false once the enumerator runs out of options*/
132
  bool d_hasNext;
133
  /** the length of the prefix that has to be changed in the next
134
  combination, i.e.  the number of the most significant digits that need to be
135
  changed in order to escape a  useless instantiation */
136
  size_t d_changePrefix;
137
  /** Move onto the next stage */
138
  bool increaseStage();
139
  /** Move onto the next stage, sum strategy. */
140
  bool increaseStageSum();
141
  /** Move onto the next stage, max strategy. */
142
  bool increaseStageMax();
143
  /** Move on in the current stage */
144
  bool nextCombination();
145
  /** Move onto the next combination. */
146
  bool nextCombinationInternal();
147
  /** Find the next lexicographically smallest combination of terms that change
148
   * on the change prefix, each digit is within the current state,  and there is
149
   * at least one digit not in the previous stage. */
150
  bool nextCombinationSum();
151
  /** Find the next lexicographically smallest combination of terms that change
152
   * on the change prefix and their sum is equal to d_currentStage. */
153
  bool nextCombinationMax();
154
  /** Set up terms for given variable.  */
155
  virtual size_t prepareTerms(size_t variableIx) = 0;
156
  /** Get a given term for a given variable.  */
157
  virtual Node getTerm(size_t variableIx,
158
                       size_t term_index) CVC5_WARN_UNUSED_RESULT = 0;
159
};
160
161
/**
162
 * Enumerate ground terms as they come from the term database.
163
 */
164
class TermTupleEnumeratorBasic : public TermTupleEnumeratorBase
165
{
166
 public:
167
414
  TermTupleEnumeratorBasic(Node quantifier,
168
                           const TermTupleEnumeratorEnv* env,
169
                           QuantifiersState& qs,
170
                           TermDb* td)
171
414
      : TermTupleEnumeratorBase(quantifier, env), d_qs(qs), d_tdb(td)
172
  {
173
414
  }
174
175
828
  virtual ~TermTupleEnumeratorBasic() = default;
176
177
 protected:
178
  /**  a list of terms for each type */
179
  std::map<TypeNode, std::vector<Node> > d_termDbList;
180
  virtual size_t prepareTerms(size_t variableIx) override;
181
  virtual Node getTerm(size_t variableIx, size_t term_index) override;
182
  /** Reference to quantifiers state */
183
  QuantifiersState& d_qs;
184
  /** Pointer to term database */
185
  TermDb* d_tdb;
186
};
187
188
/**
189
 * Enumerate ground terms according to the relevant domain class.
190
 */
191
class TermTupleEnumeratorRD : public TermTupleEnumeratorBase
192
{
193
 public:
194
1506
  TermTupleEnumeratorRD(Node quantifier,
195
                        const TermTupleEnumeratorEnv* env,
196
                        RelevantDomain* rd)
197
1506
      : TermTupleEnumeratorBase(quantifier, env), d_rd(rd)
198
  {
199
1506
  }
200
3012
  virtual ~TermTupleEnumeratorRD() = default;
201
202
 protected:
203
10020
  virtual size_t prepareTerms(size_t variableIx) override
204
  {
205
10020
    return d_rd->getRDomain(d_quantifier, variableIx)->d_terms.size();
206
  }
207
20069
  virtual Node getTerm(size_t variableIx, size_t term_index) override
208
  {
209
20069
    return d_rd->getRDomain(d_quantifier, variableIx)->d_terms[term_index];
210
  }
211
  /** The relevant domain */
212
  RelevantDomain* d_rd;
213
};
214
215
1924
void TermTupleEnumeratorBase::init()
216
{
217
3848
  Trace("inst-alg-rd") << "Initializing enumeration " << d_quantifier
218
1924
                       << std::endl;
219
1924
  d_currentStage = 0;
220
1924
  d_hasNext = true;
221
1924
  d_stageCount = 1;  // in the case of full effort we do at least one stage
222
223
1924
  if (d_variableCount == 0)
224
  {
225
    d_hasNext = false;
226
    return;
227
  }
228
229
  // prepare a sequence of terms for each quantified variable
230
  // additionally initialize the cache for variable types
231
12611
  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
232
  {
233
10687
    d_typeCache.push_back(d_quantifier[0][variableIx].getType());
234
10687
    const size_t termsSize = prepareTerms(variableIx);
235
21374
    Trace("inst-alg-rd") << "Variable " << variableIx << " has " << termsSize
236
10687
                         << " in relevant domain." << std::endl;
237
10687
    if (termsSize == 0 && !d_env->d_fullEffort)
238
    {
239
      d_hasNext = false;
240
      return;  // give up on an empty dommain
241
    }
242
10687
    d_termsSizes.push_back(termsSize);
243
10687
    d_stageCount = std::max(d_stageCount, termsSize);
244
  }
245
246
3848
  Trace("inst-alg-rd") << "Will do " << d_stageCount
247
1924
                       << " stages of instantiation." << std::endl;
248
1924
  d_termIndex.resize(d_variableCount, 0);
249
}
250
251
8922
bool TermTupleEnumeratorBase::hasNext()
252
{
253
8922
  if (!d_hasNext)
254
  {
255
    return false;
256
  }
257
258
8922
  if (d_stepCounter++ == 0)
259
  {  // TODO:any (nice)  way of avoiding this special if?
260
1924
    Assert(d_currentStage == 0);
261
3848
    Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..."
262
1924
                         << std::endl;
263
1924
    return true;
264
  }
265
266
  // try to find the next combination
267
6998
  return d_hasNext = nextCombination();
268
}
269
270
6994
void TermTupleEnumeratorBase::failureReason(const std::vector<bool>& mask)
271
{
272
6994
  if (Trace.isOn("inst-alg"))
273
  {
274
    traceMaskedVector("inst-alg", "failureReason", mask, d_termIndex);
275
  }
276
6994
  d_disabledCombinations.add(mask, d_termIndex);  // record failure
277
  // update change prefix accordingly
278
9331
  for (d_changePrefix = mask.size();
279
9331
       d_changePrefix && !mask[d_changePrefix - 1];
280
2337
       d_changePrefix--)
281
    ;
282
6994
}
283
284
8116
void TermTupleEnumeratorBase::next(/*out*/ std::vector<Node>& terms)
285
{
286
8116
  Trace("inst-alg-rd") << "Try instantiation: " << d_termIndex << std::endl;
287
8116
  terms.resize(d_variableCount);
288
35125
  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
289
  {
290
27009
    const Node t = d_termsSizes[variableIx] == 0
291
                       ? Node::null()
292
54018
                       : getTerm(variableIx, d_termIndex[variableIx]);
293
27009
    terms[variableIx] = t;
294
27009
    Trace("inst-alg-rd") << t << "  ";
295
27009
    Assert(terms[variableIx].isNull()
296
           || terms[variableIx].getType().isComparableTo(
297
               d_quantifier[0][variableIx].getType()));
298
  }
299
8116
  Trace("inst-alg-rd") << std::endl;
300
8116
}
301
302
bool TermTupleEnumeratorBase::increaseStageSum()
303
{
304
  const size_t lowerBound = d_currentStage + 1;
305
  Trace("inst-alg-rd") << "Try sum " << lowerBound << "..." << std::endl;
306
  d_currentStage = 0;
307
  for (size_t digit = d_termIndex.size();
308
       d_currentStage < lowerBound && digit--;)
309
  {
310
    const size_t missing = lowerBound - d_currentStage;
311
    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
312
    d_termIndex[digit] = std::min(missing, maxValue);
313
    d_currentStage += d_termIndex[digit];
314
  }
315
  return d_currentStage >= lowerBound;
316
}
317
318
4934
bool TermTupleEnumeratorBase::increaseStage()
319
{
320
4934
  d_changePrefix = d_variableCount;  // simply reset upon increase stage
321
4934
  return d_env->d_increaseSum ? increaseStageSum() : increaseStageMax();
322
}
323
324
4934
bool TermTupleEnumeratorBase::increaseStageMax()
325
{
326
4934
  d_currentStage++;
327
4934
  if (d_currentStage >= d_stageCount)
328
  {
329
806
    return false;
330
  }
331
4128
  Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..." << std::endl;
332
  // skipping some elements that have already been definitely seen
333
  // find the least significant digit that can be set to the current stage
334
  // TODO: should we skip all?
335
4128
  std::fill(d_termIndex.begin(), d_termIndex.end(), 0);
336
4128
  bool found = false;
337
8811
  for (size_t digit = d_termIndex.size(); !found && digit--;)
338
  {
339
4683
    if (d_termsSizes[digit] > d_currentStage)
340
    {
341
4128
      found = true;
342
4128
      d_termIndex[digit] = d_currentStage;
343
    }
344
  }
345
4128
  Assert(found);
346
4128
  return found;
347
}
348
349
7281
bool TermTupleEnumeratorBase::nextCombination()
350
{
351
  while (true)
352
  {
353
7564
    Trace("inst-alg-rd") << "changePrefix " << d_changePrefix << std::endl;
354
7281
    if (!nextCombinationInternal() && !increaseStage())
355
    {
356
806
      return false;  // ran out of combinations
357
    }
358
6475
    if (!d_disabledCombinations.find(d_termIndex, d_changePrefix))
359
    {
360
6192
      return true;  // current combination vetted by disabled combinations
361
    }
362
  }
363
}
364
365
/** Move onto the next combination, depending on the strategy. */
366
7281
bool TermTupleEnumeratorBase::nextCombinationInternal()
367
{
368
7281
  return d_env->d_increaseSum ? nextCombinationSum() : nextCombinationMax();
369
}
370
371
/** Find the next lexicographically smallest combination of terms that change
372
 * on the change prefix and their sum is equal to d_currentStage. */
373
7281
bool TermTupleEnumeratorBase::nextCombinationMax()
374
{
375
  // look for the least significant digit, within change prefix,
376
  // that can be increased
377
7281
  bool found = false;
378
7281
  size_t increaseDigit = d_changePrefix;
379
34107
  while (!found && increaseDigit--)
380
  {
381
13413
    const size_t new_value = d_termIndex[increaseDigit] + 1;
382
13413
    if (new_value < d_termsSizes[increaseDigit] && new_value <= d_currentStage)
383
    {
384
2347
      d_termIndex[increaseDigit] = new_value;
385
      // send everything after the increased digit to 0
386
2347
      std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
387
2347
      found = true;
388
    }
389
  }
390
7281
  if (!found)
391
  {
392
4934
    return false;  // nothing to increase
393
  }
394
  // check if the combination has at least one digit in the current stage,
395
  // since at least one digit was increased, no need for this in stage 1
396
2347
  bool inStage = d_currentStage <= 1;
397
3873
  for (size_t i = increaseDigit + 1; !inStage && i--;)
398
  {
399
1526
    inStage = d_termIndex[i] >= d_currentStage;
400
  }
401
2347
  if (!inStage)  // look for a digit that can increase to current stage
402
  {
403
688
    for (increaseDigit = d_variableCount, found = false;
404
688
         !found && increaseDigit--;)
405
    {
406
344
      found = d_termsSizes[increaseDigit] > d_currentStage;
407
    }
408
344
    if (!found)
409
    {
410
      return false;  // nothing to increase to the current stage
411
    }
412
344
    Assert(d_termsSizes[increaseDigit] > d_currentStage
413
           && d_termIndex[increaseDigit] < d_currentStage);
414
344
    d_termIndex[increaseDigit] = d_currentStage;
415
    // send everything after the increased digit to 0
416
344
    std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
417
  }
418
2347
  return true;
419
}
420
421
/** Find the next lexicographically smallest combination of terms that change
422
 * on the change prefix, each digit is within the current state,  and there is
423
 * at least one digit not in the previous stage. */
424
bool TermTupleEnumeratorBase::nextCombinationSum()
425
{
426
  size_t suffixSum = 0;
427
  bool found = false;
428
  size_t increaseDigit = d_termIndex.size();
429
  while (increaseDigit--)
430
  {
431
    const size_t newValue = d_termIndex[increaseDigit] + 1;
432
    found = suffixSum > 0 && newValue < d_termsSizes[increaseDigit]
433
            && increaseDigit < d_changePrefix;
434
    if (found)
435
    {
436
      // digit can be increased and suffix can be decreased
437
      d_termIndex[increaseDigit] = newValue;
438
      break;
439
    }
440
    suffixSum += d_termIndex[increaseDigit];
441
    d_termIndex[increaseDigit] = 0;
442
  }
443
  if (!found)
444
  {
445
    return false;
446
  }
447
  Assert(suffixSum > 0);
448
  // increaseDigit went up by one, hence, distribute (suffixSum - 1) in the
449
  // least significant digits
450
  suffixSum--;
451
  for (size_t digit = d_termIndex.size(); suffixSum > 0 && digit--;)
452
  {
453
    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
454
    d_termIndex[digit] = std::min(suffixSum, maxValue);
455
    suffixSum -= d_termIndex[digit];
456
  }
457
  Assert(suffixSum == 0);  // everything should have been distributed
458
  return true;
459
}
460
461
663
size_t TermTupleEnumeratorBasic::prepareTerms(size_t variableIx)
462
{
463
1326
  const TypeNode type_node = d_typeCache[variableIx];
464
663
  if (!ContainsKey(d_termDbList, type_node))
465
  {
466
510
    const size_t ground_terms_count = d_tdb->getNumTypeGroundTerms(type_node);
467
1020
    std::map<Node, Node> repsFound;
468
43987
    for (size_t j = 0; j < ground_terms_count; j++)
469
    {
470
86954
      Node gt = d_tdb->getTypeGroundTerm(type_node, j);
471
43477
      if (!options::cegqi() || !quantifiers::TermUtil::hasInstConstAttr(gt))
472
      {
473
86954
        Node rep = d_qs.getRepresentative(gt);
474
43477
        if (repsFound.find(rep) == repsFound.end())
475
        {
476
4523
          repsFound[rep] = gt;
477
4523
          d_termDbList[type_node].push_back(gt);
478
        }
479
      }
480
    }
481
  }
482
483
1326
  Trace("inst-alg-rd") << "Instantiation Terms for child " << variableIx << ": "
484
663
                       << d_termDbList[type_node] << std::endl;
485
1326
  return d_termDbList[type_node].size();
486
}
487
488
4697
/**
 * Return candidate number term_index for the type of variable variableIx.
 * The list must have been built by prepareTerms. It is looked up with find()
 * so that this read path never inserts into d_termDbList or copies the type.
 */
Node TermTupleEnumeratorBasic::getTerm(size_t variableIx, size_t term_index)
{
  const TypeNode& type_node = d_typeCache[variableIx];
  const auto it = d_termDbList.find(type_node);
  Assert(it != d_termDbList.end());
  const std::vector<Node>& terms = it->second;
  Assert(term_index < terms.size());
  return terms[term_index];
}
494
495
/**
496
 * Enumerate ground terms as they come from a user-provided term pool
497
 */
498
class TermTupleEnumeratorPool : public TermTupleEnumeratorBase
499
{
500
 public:
501
4
  TermTupleEnumeratorPool(Node quantifier,
502
                          const TermTupleEnumeratorEnv* env,
503
                          TermPools* tp,
504
                          Node pool)
505
4
      : TermTupleEnumeratorBase(quantifier, env), d_tp(tp), d_pool(pool)
506
  {
507
4
    Assert(d_pool.getKind() == kind::INST_POOL);
508
4
  }
509
510
8
  virtual ~TermTupleEnumeratorPool() = default;
511
512
 protected:
513
  /** Pointer to the term pool utility */
514
  TermPools* d_tp;
515
  /** The pool annotation */
516
  Node d_pool;
517
  /**  a list of terms for each id */
518
  std::map<size_t, std::vector<Node> > d_poolList;
519
  /** gets the terms from the pool */
520
4
  size_t prepareTerms(size_t variableIx) override
521
  {
522
4
    Assert(d_pool.getNumChildren() > variableIx);
523
    // prepare terms from pool
524
4
    d_poolList[variableIx].clear();
525
4
    d_tp->getTermsForPool(d_pool[variableIx], d_poolList[variableIx]);
526
8
    Trace("pool-inst") << "Instantiation Terms for child " << variableIx << ": "
527
4
                       << d_poolList[variableIx] << std::endl;
528
4
    return d_poolList[variableIx].size();
529
  }
530
4
  Node getTerm(size_t variableIx, size_t term_index) override
531
  {
532
4
    Assert(term_index < d_poolList[variableIx].size());
533
4
    return d_poolList[variableIx][term_index];
534
  }
535
};
536
537
414
TermTupleEnumeratorInterface* mkTermTupleEnumerator(
    Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs, TermDb* td)
{
  // heap-allocated; the caller is responsible for deleting it
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorBasic(q, env, qs, td);
  return enumerator;
}
543
1506
TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd(
    Node q, const TermTupleEnumeratorEnv* env, RelevantDomain* rd)
{
  // heap-allocated; the caller is responsible for deleting it
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorRD(q, env, rd);
  return enumerator;
}
549
550
4
TermTupleEnumeratorInterface* mkTermTupleEnumeratorPool(
    Node q, const TermTupleEnumeratorEnv* env, TermPools* tp, Node pool)
{
  // heap-allocated; the caller is responsible for deleting it
  TermTupleEnumeratorInterface* enumerator =
      new TermTupleEnumeratorPool(q, env, tp, pool);
  return enumerator;
}
556
557
}  // namespace quantifiers
558
}  // namespace theory
559
28194
}  // namespace cvc5