GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/ext_theory.cpp Lines: 162 250 64.8 %
Date: 2021-08-01 Branches: 276 713 38.7 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Tim King
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Extended theory interface.
14
 *
15
 * This implements a generic module, used by theory solvers, for performing
16
 * "context-dependent simplification", as described in Reynolds et al
17
 * "Designing Theory Solvers with Extensions", FroCoS 2017.
18
 */
19
20
#include "theory/ext_theory.h"
21
22
#include "base/check.h"
23
#include "smt/smt_statistics_registry.h"
24
#include "theory/output_channel.h"
25
#include "theory/quantifiers_engine.h"
26
#include "theory/rewriter.h"
27
#include "theory/substitutions.h"
28
29
using namespace std;
30
31
namespace cvc5 {
32
namespace theory {
33
34
/**
 * Returns a printable name for an ExtReducedId.
 *
 * UNKNOWN is listed explicitly so that it prints as "UNKNOWN" rather than the
 * catch-all "?ExtReducedId?". UNKNOWN is the value callers in this file use to
 * initialize an ExtReducedId before querying. Any value not listed below still
 * maps to "?ExtReducedId?".
 */
const char* toString(ExtReducedId id)
{
  switch (id)
  {
    case ExtReducedId::UNKNOWN: return "UNKNOWN";
    case ExtReducedId::SR_CONST: return "SR_CONST";
    case ExtReducedId::REDUCTION: return "REDUCTION";
    case ExtReducedId::ARITH_SR_ZERO: return "ARITH_SR_ZERO";
    case ExtReducedId::ARITH_SR_LINEAR: return "ARITH_SR_LINEAR";
    case ExtReducedId::STRINGS_SR_CONST: return "STRINGS_SR_CONST";
    case ExtReducedId::STRINGS_NEG_CTN_DEQ: return "STRINGS_NEG_CTN_DEQ";
    case ExtReducedId::STRINGS_POS_CTN: return "STRINGS_POS_CTN";
    case ExtReducedId::STRINGS_CTN_DECOMPOSE: return "STRINGS_CTN_DECOMPOSE";
    case ExtReducedId::STRINGS_REGEXP_INTER: return "STRINGS_REGEXP_INTER";
    case ExtReducedId::STRINGS_REGEXP_INTER_SUBSUME:
      return "STRINGS_REGEXP_INTER_SUBSUME";
    case ExtReducedId::STRINGS_REGEXP_INCLUDE: return "STRINGS_REGEXP_INCLUDE";
    case ExtReducedId::STRINGS_REGEXP_INCLUDE_NEG:
      return "STRINGS_REGEXP_INCLUDE_NEG";
    default: return "?ExtReducedId?";
  }
}
55
56
std::ostream& operator<<(std::ostream& out, ExtReducedId id)
57
{
58
  out << toString(id);
59
  return out;
60
}
61
62
// Default implementation: no substitution is provided. subs and exp are left
// untouched and false is returned.
bool ExtTheoryCallback::getCurrentSubstitution(
    int /* effort */,
    const std::vector<Node>& /* vars */,
    std::vector<Node>& /* subs */,
    std::map<Node, std::vector<Node> >& /* exp */)
{
  return false;
}
70
bool ExtTheoryCallback::isExtfReduced(
71
    int effort, Node n, Node on, std::vector<Node>& exp, ExtReducedId& id)
72
{
73
  id = ExtReducedId::SR_CONST;
74
  return n.isConst();
75
}
76
bool ExtTheoryCallback::getReduction(int effort,
77
                                    Node n,
78
                                    Node& nr,
79
                                    bool& isSatDep)
80
{
81
  return false;
82
}
83
84
14972
/**
 * Builds the extended-theory module for parent callback p.
 *
 * Term bookkeeping (registered terms, reduction reasons, the active-term
 * witness) is tied to context c. Context-independent inactivity and the
 * lemma caches are tied to user context u. Lemmas are sent on out.
 */
ExtTheory::ExtTheory(ExtTheoryCallback& p,
                     context::Context* c,
                     context::UserContext* u,
                     OutputChannel& out)
    : d_parent(p),
      d_out(out),
      d_ext_func_terms(c),
      d_extfExtReducedIdMap(c),
      d_ci_inactive(u),
      d_has_extf(c),
      d_lemmas(u),
      d_pp_lemmas(u)
{
  NodeManager* nm = NodeManager::currentNM();
  d_true = nm->mkConst(true);
}
99
100
// Gets all leaf terms in n.
101
15997
std::vector<Node> ExtTheory::collectVars(Node n)
102
{
103
15997
  std::vector<Node> vars;
104
31994
  std::set<Node> visited;
105
31994
  std::vector<Node> worklist;
106
15997
  worklist.push_back(n);
107
216289
  while (!worklist.empty())
108
  {
109
167422
    Node current = worklist.back();
110
100146
    worklist.pop_back();
111
100146
    if (current.isConst() || visited.count(current) > 0)
112
    {
113
32870
      continue;
114
    }
115
67276
    visited.insert(current);
116
    // Treat terms not belonging to this theory as leaf
117
    // note : chould include terms not belonging to this theory
118
    // (commented below)
119
67276
    if (current.getNumChildren() > 0)
120
    {
121
43221
      worklist.insert(worklist.end(), current.begin(), current.end());
122
    }
123
    else
124
    {
125
24055
      vars.push_back(current);
126
    }
127
  }
128
31994
  return vars;
129
}
130
131
// Single-term convenience wrapper around getSubstitutedTerms. Returns the
// substituted form of term and appends its explanation literals to exp.
Node ExtTheory::getSubstitutedTerm(int effort,
                                   Node term,
                                   std::vector<Node>& exp)
{
  const std::vector<Node> single(1, term);
  std::vector<Node> substituted;
  std::vector<std::vector<Node> > explanations;
  getSubstitutedTerms(effort, single, substituted, explanations);
  Assert(substituted.size() == 1);
  Assert(explanations.size() == 1);
  const std::vector<Node>& termExp = explanations.front();
  exp.insert(exp.end(), termExp.begin(), termExp.end());
  return substituted.front();
}
145
146
// do inferences
147
31878
void ExtTheory::getSubstitutedTerms(int effort,
148
                                    const std::vector<Node>& terms,
149
                                    std::vector<Node>& sterms,
150
                                    std::vector<std::vector<Node> >& exp)
151
{
152
63756
  Trace("extt-debug") << "getSubstitutedTerms for " << terms.size() << " / "
153
63756
                      << d_ext_func_terms.size() << " extended functions."
154
31878
                      << std::endl;
155
31878
  if (!terms.empty())
156
  {
157
    // all variables we need to find a substitution for
158
8118
    std::vector<Node> vars;
159
8118
    std::vector<Node> sub;
160
8118
    std::map<Node, std::vector<Node> > expc;
161
34074
    for (const Node& n : terms)
162
    {
163
      // do substitution, rewrite
164
30015
      std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
165
30015
      Assert(iti != d_extf_info.end());
166
78896
      for (const Node& v : iti->second.d_vars)
167
      {
168
48881
        if (std::find(vars.begin(), vars.end(), v) == vars.end())
169
        {
170
22307
          vars.push_back(v);
171
        }
172
      }
173
    }
174
4059
    bool useSubs = d_parent.getCurrentSubstitution(effort, vars, sub, expc);
175
    // get the current substitution for all variables
176
4059
    Assert(!useSubs || vars.size() == sub.size());
177
34074
    for (const Node& n : terms)
178
    {
179
60030
      Node ns = n;
180
60030
      std::vector<Node> expn;
181
30015
      if (useSubs)
182
      {
183
        // do substitution
184
17875
        ns = n.substitute(vars.begin(), vars.end(), sub.begin(), sub.end());
185
17875
        if (ns != n)
186
        {
187
          // build explanation: explanation vars = sub for each vars in FV(n)
188
5168
          std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
189
5168
          Assert(iti != d_extf_info.end());
190
14960
          for (const Node& v : iti->second.d_vars)
191
          {
192
9792
            std::map<Node, std::vector<Node> >::iterator itx = expc.find(v);
193
9792
            if (itx != expc.end())
194
            {
195
11612
              for (const Node& e : itx->second)
196
              {
197
5806
                if (std::find(expn.begin(), expn.end(), e) == expn.end())
198
                {
199
5806
                  expn.push_back(e);
200
                }
201
              }
202
            }
203
          }
204
        }
205
35750
        Trace("extt-debug") << "  have " << n << " == " << ns
206
17875
                            << ", exp size=" << expn.size() << "." << std::endl;
207
      }
208
      // add to vector
209
30015
      sterms.push_back(ns);
210
30015
      exp.push_back(expn);
211
    }
212
  }
213
31878
}
214
215
31878
bool ExtTheory::doInferencesInternal(int effort,
216
                                     const std::vector<Node>& terms,
217
                                     std::vector<Node>& nred,
218
                                     bool batch,
219
                                     bool isRed)
220
{
221
31878
  if (batch)
222
  {
223
31878
    bool addedLemma = false;
224
31878
    if (isRed)
225
    {
226
      for (const Node& n : terms)
227
      {
228
        Node nr;
229
        // note: could do reduction with substitution here
230
        bool satDep = false;
231
        if (!d_parent.getReduction(effort, n, nr, satDep))
232
        {
233
          nred.push_back(n);
234
        }
235
        else
236
        {
237
          if (!nr.isNull() && n != nr)
238
          {
239
            Node lem = NodeManager::currentNM()->mkNode(kind::EQUAL, n, nr);
240
            if (sendLemma(lem, true))
241
            {
242
              Trace("extt-lemma")
243
                  << "ExtTheory : reduction lemma : " << lem << std::endl;
244
              addedLemma = true;
245
            }
246
          }
247
          markReduced(n, ExtReducedId::REDUCTION, satDep);
248
        }
249
      }
250
    }
251
    else
252
    {
253
63756
      std::vector<Node> sterms;
254
63756
      std::vector<std::vector<Node> > exp;
255
31878
      getSubstitutedTerms(effort, terms, sterms, exp);
256
63756
      std::map<Node, unsigned> sterm_index;
257
31878
      NodeManager* nm = NodeManager::currentNM();
258
61893
      for (unsigned i = 0, size = terms.size(); i < size; i++)
259
      {
260
30015
        bool processed = false;
261
30015
        if (sterms[i] != terms[i])
262
        {
263
10336
          Node sr = Rewriter::rewrite(sterms[i]);
264
          // ask the theory if this term is reduced, e.g. is it constant or it
265
          // is a non-extf term.
266
          ExtReducedId id;
267
5168
          if (d_parent.isExtfReduced(effort, sr, terms[i], exp[i], id))
268
          {
269
3796
            processed = true;
270
3796
            markReduced(terms[i], id);
271
            // We have exp[i] => terms[i] = sr, convert this to a clause.
272
            // This ensures the proof infrastructure can process this as a
273
            // normal theory lemma.
274
7592
            Node eq = terms[i].eqNode(sr);
275
7592
            Node lem = eq;
276
3796
            if (!exp[i].empty())
277
            {
278
7592
              std::vector<Node> eei;
279
7997
              for (const Node& e : exp[i])
280
              {
281
4201
                eei.push_back(e.negate());
282
              }
283
3796
              eei.push_back(eq);
284
3796
              lem = nm->mkNode(kind::OR, eei);
285
            }
286
287
7592
            Trace("extt-debug") << "ExtTheory::doInferences : infer : " << eq
288
3796
                                << " by " << exp[i] << std::endl;
289
3796
            Trace("extt-debug") << "...send lemma " << lem << std::endl;
290
3796
            if (sendLemma(lem))
291
            {
292
2444
              Trace("extt-lemma")
293
1222
                  << "ExtTheory : substitution + rewrite lemma : " << lem
294
1222
                  << std::endl;
295
1222
              addedLemma = true;
296
            }
297
          }
298
          else
299
          {
300
            // check if we have already reduced this
301
1372
            std::map<Node, unsigned>::iterator itsi = sterm_index.find(sr);
302
1372
            if (itsi == sterm_index.end())
303
            {
304
1104
              sterm_index[sr] = i;
305
            }
306
            else
307
            {
308
              // unsigned j = itsi->second;
309
              // note : can add (non-reducing) lemma :
310
              //   exp[j] ^ exp[i] => sterms[i] = sterms[j]
311
            }
312
313
1372
            Trace("extt-nred") << "Non-reduced term : " << sr << std::endl;
314
          }
315
        }
316
        else
317
        {
318
24847
          Trace("extt-nred") << "Non-reduced term : " << sterms[i] << std::endl;
319
        }
320
30015
        if (!processed)
321
        {
322
26219
          nred.push_back(terms[i]);
323
        }
324
      }
325
    }
326
31878
    return addedLemma;
327
  }
328
  // non-batch
329
  std::vector<Node> nnred;
330
  if (terms.empty())
331
  {
332
    for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
333
         it != d_ext_func_terms.end();
334
         ++it)
335
    {
336
      if ((*it).second && !isContextIndependentInactive((*it).first))
337
      {
338
        std::vector<Node> nterms;
339
        nterms.push_back((*it).first);
340
        if (doInferencesInternal(effort, nterms, nnred, true, isRed))
341
        {
342
          return true;
343
        }
344
      }
345
    }
346
  }
347
  else
348
  {
349
    for (const Node& n : terms)
350
    {
351
      std::vector<Node> nterms;
352
      nterms.push_back(n);
353
      if (doInferencesInternal(effort, nterms, nnred, true, isRed))
354
      {
355
        return true;
356
      }
357
    }
358
  }
359
  return false;
360
}
361
362
3796
// Sends lem on the output channel unless it was already sent. Preprocess
// lemmas (preprocess == true) and ordinary lemmas are deduplicated in
// separate caches (d_pp_lemmas / d_lemmas). Returns true iff lem was sent now.
bool ExtTheory::sendLemma(Node lem, bool preprocess)
{
  auto sendIfNew = [&](auto& sentCache) {
    if (sentCache.find(lem) != sentCache.end())
    {
      return false;
    }
    sentCache.insert(lem);
    d_out.lemma(lem);
    return true;
  };
  return preprocess ? sendIfNew(d_pp_lemmas) : sendIfNew(d_lemmas);
}
384
385
bool ExtTheory::doInferences(int effort,
386
                             const std::vector<Node>& terms,
387
                             std::vector<Node>& nred,
388
                             bool batch)
389
{
390
  if (!terms.empty())
391
  {
392
    return doInferencesInternal(effort, terms, nred, batch, false);
393
  }
394
  return false;
395
}
396
397
31878
bool ExtTheory::doInferences(int effort, std::vector<Node>& nred, bool batch)
398
{
399
63756
  std::vector<Node> terms = getActive();
400
63756
  return doInferencesInternal(effort, terms, nred, batch, false);
401
}
402
403
bool ExtTheory::doReductions(int effort,
404
                             const std::vector<Node>& terms,
405
                             std::vector<Node>& nred,
406
                             bool batch)
407
{
408
  if (!terms.empty())
409
  {
410
    return doInferencesInternal(effort, terms, nred, batch, true);
411
  }
412
  return false;
413
}
414
415
bool ExtTheory::doReductions(int effort, std::vector<Node>& nred, bool batch)
416
{
417
  const std::vector<Node> terms = getActive();
418
  return doInferencesInternal(effort, terms, nred, batch, true);
419
}
420
421
// Register term.
422
646302
void ExtTheory::registerTerm(Node n)
423
{
424
646302
  if (d_extf_kind.find(n.getKind()) != d_extf_kind.end())
425
  {
426
142111
    if (d_ext_func_terms.find(n) == d_ext_func_terms.end())
427
    {
428
15997
      Trace("extt-debug") << "Found extended function : " << n << std::endl;
429
15997
      d_ext_func_terms[n] = true;
430
15997
      d_has_extf = n;
431
15997
      d_extf_info[n].d_vars = collectVars(n);
432
    }
433
  }
434
646302
}
435
436
// mark reduced
437
126114
void ExtTheory::markReduced(Node n, ExtReducedId rid, bool satDep)
438
{
439
126114
  Trace("extt-debug") << "Mark reduced " << n << std::endl;
440
126114
  registerTerm(n);
441
126114
  Assert(d_ext_func_terms.find(n) != d_ext_func_terms.end());
442
126114
  d_ext_func_terms[n] = false;
443
126114
  d_extfExtReducedIdMap[n] = rid;
444
126114
  if (!satDep)
445
  {
446
    d_ci_inactive[n] = rid;
447
  }
448
449
  // update has_extf
450
126114
  if (d_has_extf.get() == n)
451
  {
452
174716
    for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
453
174716
         it != d_ext_func_terms.end();
454
         ++it)
455
    {
456
      // if not already reduced
457
168110
      if ((*it).second && !isContextIndependentInactive((*it).first))
458
      {
459
53136
        d_has_extf = (*it).first;
460
      }
461
    }
462
  }
463
126114
}
464
465
644919
// Returns true iff n was marked inactive independently of the context. The
// recorded reason is discarded.
bool ExtTheory::isContextIndependentInactive(Node n) const
{
  ExtReducedId ignoredReason = ExtReducedId::UNKNOWN;
  const bool inactive = isContextIndependentInactive(n, ignoredReason);
  return inactive;
}
470
471
904503
// Returns true iff n is recorded in d_ci_inactive, in which case rid is set to
// the recorded reason. Otherwise rid is left untouched.
bool ExtTheory::isContextIndependentInactive(Node n, ExtReducedId& rid) const
{
  NodeExtReducedIdMap::iterator found = d_ci_inactive.find(n);
  if (found == d_ci_inactive.end())
  {
    return false;
  }
  rid = found->second;
  return true;
}
481
482
3015
void ExtTheory::getTerms(std::vector<Node>& terms)
483
{
484
25201
  for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
485
25201
       it != d_ext_func_terms.end();
486
       ++it)
487
  {
488
22186
    terms.push_back((*it).first);
489
  }
490
3015
}
491
492
bool ExtTheory::hasActiveTerm() const { return !d_has_extf.get().isNull(); }
493
494
259584
// Returns true iff n is a registered, still-active extended function term.
// Any inactivity reason is discarded.
bool ExtTheory::isActive(Node n) const
{
  ExtReducedId ignoredReason = ExtReducedId::UNKNOWN;
  const bool active = isActive(n, ignoredReason);
  return active;
}
499
500
259584
// Returns true iff n is registered, not reduced in the current context, and
// not inactive independently of the context. When n is registered but
// inactive, rid is set to the recorded reason. When n is not registered, rid
// is left untouched.
bool ExtTheory::isActive(Node n, ExtReducedId& rid) const
{
  NodeBoolMap::const_iterator entry = d_ext_func_terms.find(n);
  if (entry == d_ext_func_terms.end())
  {
    return false;  // not registered in the current context
  }
  if (!(*entry).second)
  {
    // reduced in the current context: report the recorded reason
    NodeExtReducedIdMap::const_iterator reason = d_extfExtReducedIdMap.find(n);
    Assert(reason != d_extfExtReducedIdMap.end());
    rid = reason->second;
    return false;
  }
  return !isContextIndependentInactive(n, rid);
}
516
517
// get active
518
96366
std::vector<Node> ExtTheory::getActive() const
519
{
520
96366
  std::vector<Node> active;
521
1022237
  for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
522
1022237
       it != d_ext_func_terms.end();
523
       ++it)
524
  {
525
    // if not already reduced
526
925871
    if ((*it).second && !isContextIndependentInactive((*it).first))
527
    {
528
590490
      active.push_back((*it).first);
529
    }
530
  }
531
96366
  return active;
532
}
533
534
8362
// Returns the still-active registered extended function terms of kind k.
std::vector<Node> ExtTheory::getActive(Kind k) const
{
  std::vector<Node> result;
  for (const auto& entry : d_ext_func_terms)
  {
    const Node& t = entry.first;
    if (t.getKind() != k || !entry.second)
    {
      continue;  // wrong kind, or reduced in the current context
    }
    if (!isContextIndependentInactive(t))
    {
      result.push_back(t);
    }
  }
  return result;
}
550
551
}  // namespace theory
552
29280
}  // namespace cvc5