GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/ext_theory.cpp Lines: 183 250 73.2 %
Date: 2021-05-22 Branches: 312 713 43.8 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Tim King
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Extended theory interface.
14
 *
15
 * This implements a generic module, used by theory solvers, for performing
16
 * "context-dependent simplification", as described in Reynolds et al
17
 * "Designing Theory Solvers with Extensions", FroCoS 2017.
18
 */
19
20
#include "theory/ext_theory.h"
21
22
#include "base/check.h"
23
#include "smt/smt_statistics_registry.h"
24
#include "theory/output_channel.h"
25
#include "theory/quantifiers_engine.h"
26
#include "theory/rewriter.h"
27
#include "theory/substitutions.h"
28
29
using namespace std;
30
31
namespace cvc5 {
32
namespace theory {
33
34
/**
 * Returns a human-readable name for an ExtReducedId value.
 *
 * @param id the identifier to convert
 * @return a string literal naming `id`, or "?ExtReducedId?" for any value
 *         that has no explicit case below
 */
const char* toString(ExtReducedId id)
{
  switch (id)
  {
    // UNKNOWN is the "no reason recorded" value (used as the initial value
    // of out-parameters in this file); give it a name instead of falling
    // through to the generic "?ExtReducedId?" string.
    case ExtReducedId::UNKNOWN: return "UNKNOWN";
    case ExtReducedId::SR_CONST: return "SR_CONST";
    case ExtReducedId::REDUCTION: return "REDUCTION";
    case ExtReducedId::ARITH_SR_ZERO: return "ARITH_SR_ZERO";
    case ExtReducedId::ARITH_SR_LINEAR: return "ARITH_SR_LINEAR";
    case ExtReducedId::STRINGS_SR_CONST: return "STRINGS_SR_CONST";
    case ExtReducedId::STRINGS_NEG_CTN_DEQ: return "STRINGS_NEG_CTN_DEQ";
    case ExtReducedId::STRINGS_POS_CTN: return "STRINGS_POS_CTN";
    case ExtReducedId::STRINGS_CTN_DECOMPOSE: return "STRINGS_CTN_DECOMPOSE";
    case ExtReducedId::STRINGS_REGEXP_INTER: return "STRINGS_REGEXP_INTER";
    case ExtReducedId::STRINGS_REGEXP_INTER_SUBSUME:
      return "STRINGS_REGEXP_INTER_SUBSUME";
    case ExtReducedId::STRINGS_REGEXP_INCLUDE: return "STRINGS_REGEXP_INCLUDE";
    case ExtReducedId::STRINGS_REGEXP_INCLUDE_NEG:
      return "STRINGS_REGEXP_INCLUDE_NEG";
    default: return "?ExtReducedId?";
  }
}
55
56
std::ostream& operator<<(std::ostream& out, ExtReducedId id)
57
{
58
  out << toString(id);
59
  return out;
60
}
61
62
/**
 * Default callback: no substitution is available.
 *
 * Leaves `subs` and `exp` untouched and reports false, so callers apply no
 * substitution to the variables in `vars`.
 */
bool ExtTheoryCallback::getCurrentSubstitution(
    int effort,
    const std::vector<Node>& vars,
    std::vector<Node>& subs,
    std::map<Node, std::vector<Node> >& exp)
{
  constexpr bool kHasSubstitution = false;
  return kHasSubstitution;
}
70
bool ExtTheoryCallback::isExtfReduced(
71
    int effort, Node n, Node on, std::vector<Node>& exp, ExtReducedId& id)
72
{
73
  id = ExtReducedId::SR_CONST;
74
  return n.isConst();
75
}
76
bool ExtTheoryCallback::getReduction(int effort,
77
                                    Node n,
78
                                    Node& nr,
79
                                    bool& isSatDep)
80
{
81
  return false;
82
}
83
84
23778
/**
 * Constructs the extended theory module.
 *
 * @param p   callback used to query substitutions, reductions and
 *            reducedness of extended function terms
 * @param c   context for d_ext_func_terms, d_extfExtReducedIdMap and
 *            d_has_extf
 * @param u   user context for d_ci_inactive and the lemma caches
 *            (d_lemmas, d_pp_lemmas)
 * @param out output channel on which lemmas are sent
 */
ExtTheory::ExtTheory(ExtTheoryCallback& p,
                     context::Context* c,
                     context::UserContext* u,
                     OutputChannel& out)
    : d_parent(p),
      d_out(out),
      d_ext_func_terms(c),
      d_extfExtReducedIdMap(c),
      d_ci_inactive(u),
      d_has_extf(c),
      d_lemmas(u),
      d_pp_lemmas(u)
{
  NodeManager* nm = NodeManager::currentNM();
  d_true = nm->mkConst(true);
}
99
100
// Gets all leaf terms in n.
101
13903
std::vector<Node> ExtTheory::collectVars(Node n)
102
{
103
13903
  std::vector<Node> vars;
104
27806
  std::set<Node> visited;
105
27806
  std::vector<Node> worklist;
106
13903
  worklist.push_back(n);
107
177941
  while (!worklist.empty())
108
  {
109
137169
    Node current = worklist.back();
110
82019
    worklist.pop_back();
111
82019
    if (current.isConst() || visited.count(current) > 0)
112
    {
113
26869
      continue;
114
    }
115
55150
    visited.insert(current);
116
    // Treat terms not belonging to this theory as leaf
117
    // note : chould include terms not belonging to this theory
118
    // (commented below)
119
55150
    if (current.getNumChildren() > 0)
120
    {
121
34205
      worklist.insert(worklist.end(), current.begin(), current.end());
122
    }
123
    else
124
    {
125
20945
      vars.push_back(current);
126
    }
127
  }
128
27806
  return vars;
129
}
130
131
/**
 * Single-term convenience wrapper around getSubstitutedTerms.
 *
 * @param effort passed through to getSubstitutedTerms
 * @param term   the (registered) term to substitute
 * @param exp    the explanation for the substitution is appended here;
 *               existing entries are kept
 * @return the substituted form of `term`
 */
Node ExtTheory::getSubstitutedTerm(int effort,
                                   Node term,
                                   std::vector<Node>& exp)
{
  const std::vector<Node> single{term};
  std::vector<Node> substituted;
  std::vector<std::vector<Node> > explanations;
  getSubstitutedTerms(effort, single, substituted, explanations);
  Assert(substituted.size() == 1);
  Assert(explanations.size() == 1);
  const std::vector<Node>& termExp = explanations.front();
  exp.insert(exp.end(), termExp.begin(), termExp.end());
  return substituted.front();
}
145
146
// do inferences
147
52206
void ExtTheory::getSubstitutedTerms(int effort,
148
                                    const std::vector<Node>& terms,
149
                                    std::vector<Node>& sterms,
150
                                    std::vector<std::vector<Node> >& exp)
151
{
152
104412
  Trace("extt-debug") << "getSubstitutedTerms for " << terms.size() << " / "
153
104412
                      << d_ext_func_terms.size() << " extended functions."
154
52206
                      << std::endl;
155
52206
  if (!terms.empty())
156
  {
157
    // all variables we need to find a substitution for
158
8244
    std::vector<Node> vars;
159
8244
    std::vector<Node> sub;
160
8244
    std::map<Node, std::vector<Node> > expc;
161
35471
    for (const Node& n : terms)
162
    {
163
      // do substitution, rewrite
164
31349
      std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
165
31349
      Assert(iti != d_extf_info.end());
166
82509
      for (const Node& v : iti->second.d_vars)
167
      {
168
51160
        if (std::find(vars.begin(), vars.end(), v) == vars.end())
169
        {
170
22918
          vars.push_back(v);
171
        }
172
      }
173
    }
174
4122
    bool useSubs = d_parent.getCurrentSubstitution(effort, vars, sub, expc);
175
    // get the current substitution for all variables
176
4122
    Assert(!useSubs || vars.size() == sub.size());
177
35471
    for (const Node& n : terms)
178
    {
179
62698
      Node ns = n;
180
62698
      std::vector<Node> expn;
181
31349
      if (useSubs)
182
      {
183
        // do substitution
184
18833
        ns = n.substitute(vars.begin(), vars.end(), sub.begin(), sub.end());
185
18833
        if (ns != n)
186
        {
187
          // build explanation: explanation vars = sub for each vars in FV(n)
188
5798
          std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
189
5798
          Assert(iti != d_extf_info.end());
190
16752
          for (const Node& v : iti->second.d_vars)
191
          {
192
10954
            std::map<Node, std::vector<Node> >::iterator itx = expc.find(v);
193
10954
            if (itx != expc.end())
194
            {
195
13356
              for (const Node& e : itx->second)
196
              {
197
6678
                if (std::find(expn.begin(), expn.end(), e) == expn.end())
198
                {
199
6678
                  expn.push_back(e);
200
                }
201
              }
202
            }
203
          }
204
        }
205
37666
        Trace("extt-debug") << "  have " << n << " == " << ns
206
18833
                            << ", exp size=" << expn.size() << "." << std::endl;
207
      }
208
      // add to vector
209
31349
      sterms.push_back(ns);
210
31349
      exp.push_back(expn);
211
    }
212
  }
213
52206
}
214
215
52263
bool ExtTheory::doInferencesInternal(int effort,
216
                                     const std::vector<Node>& terms,
217
                                     std::vector<Node>& nred,
218
                                     bool batch,
219
                                     bool isRed)
220
{
221
52263
  if (batch)
222
  {
223
52263
    bool addedLemma = false;
224
52263
    if (isRed)
225
    {
226
231
      for (const Node& n : terms)
227
      {
228
348
        Node nr;
229
        // note: could do reduction with substitution here
230
174
        bool satDep = false;
231
174
        if (!d_parent.getReduction(effort, n, nr, satDep))
232
        {
233
          nred.push_back(n);
234
        }
235
        else
236
        {
237
174
          if (!nr.isNull() && n != nr)
238
          {
239
348
            Node lem = NodeManager::currentNM()->mkNode(kind::EQUAL, n, nr);
240
174
            if (sendLemma(lem, true))
241
            {
242
348
              Trace("extt-lemma")
243
174
                  << "ExtTheory : reduction lemma : " << lem << std::endl;
244
174
              addedLemma = true;
245
            }
246
          }
247
174
          markReduced(n, ExtReducedId::REDUCTION, satDep);
248
        }
249
      }
250
    }
251
    else
252
    {
253
104412
      std::vector<Node> sterms;
254
104412
      std::vector<std::vector<Node> > exp;
255
52206
      getSubstitutedTerms(effort, terms, sterms, exp);
256
104412
      std::map<Node, unsigned> sterm_index;
257
52206
      NodeManager* nm = NodeManager::currentNM();
258
83555
      for (unsigned i = 0, size = terms.size(); i < size; i++)
259
      {
260
31349
        bool processed = false;
261
31349
        if (sterms[i] != terms[i])
262
        {
263
11596
          Node sr = Rewriter::rewrite(sterms[i]);
264
          // ask the theory if this term is reduced, e.g. is it constant or it
265
          // is a non-extf term.
266
          ExtReducedId id;
267
5798
          if (d_parent.isExtfReduced(effort, sr, terms[i], exp[i], id))
268
          {
269
4158
            processed = true;
270
4158
            markReduced(terms[i], id);
271
            // We have exp[i] => terms[i] = sr, convert this to a clause.
272
            // This ensures the proof infrastructure can process this as a
273
            // normal theory lemma.
274
8316
            Node eq = terms[i].eqNode(sr);
275
8316
            Node lem = eq;
276
4158
            if (!exp[i].empty())
277
            {
278
8316
              std::vector<Node> eei;
279
8849
              for (const Node& e : exp[i])
280
              {
281
4691
                eei.push_back(e.negate());
282
              }
283
4158
              eei.push_back(eq);
284
4158
              lem = nm->mkNode(kind::OR, eei);
285
            }
286
287
8316
            Trace("extt-debug") << "ExtTheory::doInferences : infer : " << eq
288
4158
                                << " by " << exp[i] << std::endl;
289
4158
            Trace("extt-debug") << "...send lemma " << lem << std::endl;
290
4158
            if (sendLemma(lem))
291
            {
292
2404
              Trace("extt-lemma")
293
1202
                  << "ExtTheory : substitution + rewrite lemma : " << lem
294
1202
                  << std::endl;
295
1202
              addedLemma = true;
296
            }
297
          }
298
          else
299
          {
300
            // check if we have already reduced this
301
1640
            std::map<Node, unsigned>::iterator itsi = sterm_index.find(sr);
302
1640
            if (itsi == sterm_index.end())
303
            {
304
1372
              sterm_index[sr] = i;
305
            }
306
            else
307
            {
308
              // unsigned j = itsi->second;
309
              // note : can add (non-reducing) lemma :
310
              //   exp[j] ^ exp[i] => sterms[i] = sterms[j]
311
            }
312
313
1640
            Trace("extt-nred") << "Non-reduced term : " << sr << std::endl;
314
          }
315
        }
316
        else
317
        {
318
25551
          Trace("extt-nred") << "Non-reduced term : " << sterms[i] << std::endl;
319
        }
320
31349
        if (!processed)
321
        {
322
27191
          nred.push_back(terms[i]);
323
        }
324
      }
325
    }
326
52263
    return addedLemma;
327
  }
328
  // non-batch
329
  std::vector<Node> nnred;
330
  if (terms.empty())
331
  {
332
    for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
333
         it != d_ext_func_terms.end();
334
         ++it)
335
    {
336
      if ((*it).second && !isContextIndependentInactive((*it).first))
337
      {
338
        std::vector<Node> nterms;
339
        nterms.push_back((*it).first);
340
        if (doInferencesInternal(effort, nterms, nnred, true, isRed))
341
        {
342
          return true;
343
        }
344
      }
345
    }
346
  }
347
  else
348
  {
349
    for (const Node& n : terms)
350
    {
351
      std::vector<Node> nterms;
352
      nterms.push_back(n);
353
      if (doInferencesInternal(effort, nterms, nnred, true, isRed))
354
      {
355
        return true;
356
      }
357
    }
358
  }
359
  return false;
360
}
361
362
4332
/**
 * Sends `lem` on the output channel unless it was already sent.
 *
 * Lemmas with `preprocess` set are deduplicated against d_pp_lemmas, all
 * others against d_lemmas; both caches live in the user context.
 *
 * @return true if the lemma was new and has been sent, false otherwise
 */
bool ExtTheory::sendLemma(Node lem, bool preprocess)
{
  if (preprocess)
  {
    if (d_pp_lemmas.find(lem) != d_pp_lemmas.end())
    {
      return false;
    }
    d_pp_lemmas.insert(lem);
  }
  else
  {
    if (d_lemmas.find(lem) != d_lemmas.end())
    {
      return false;
    }
    d_lemmas.insert(lem);
  }
  d_out.lemma(lem);
  return true;
}
384
385
bool ExtTheory::doInferences(int effort,
386
                             const std::vector<Node>& terms,
387
                             std::vector<Node>& nred,
388
                             bool batch)
389
{
390
  if (!terms.empty())
391
  {
392
    return doInferencesInternal(effort, terms, nred, batch, false);
393
  }
394
  return false;
395
}
396
397
52206
bool ExtTheory::doInferences(int effort, std::vector<Node>& nred, bool batch)
398
{
399
104412
  std::vector<Node> terms = getActive();
400
104412
  return doInferencesInternal(effort, terms, nred, batch, false);
401
}
402
403
57
bool ExtTheory::doReductions(int effort,
404
                             const std::vector<Node>& terms,
405
                             std::vector<Node>& nred,
406
                             bool batch)
407
{
408
57
  if (!terms.empty())
409
  {
410
57
    return doInferencesInternal(effort, terms, nred, batch, true);
411
  }
412
  return false;
413
}
414
415
bool ExtTheory::doReductions(int effort, std::vector<Node>& nred, bool batch)
416
{
417
  const std::vector<Node> terms = getActive();
418
  return doInferencesInternal(effort, terms, nred, batch, true);
419
}
420
421
// Register term.
422
724868
void ExtTheory::registerTerm(Node n)
423
{
424
724868
  if (d_extf_kind.find(n.getKind()) != d_extf_kind.end())
425
  {
426
87470
    if (d_ext_func_terms.find(n) == d_ext_func_terms.end())
427
    {
428
13903
      Trace("extt-debug") << "Found extended function : " << n << std::endl;
429
13903
      d_ext_func_terms[n] = true;
430
13903
      d_has_extf = n;
431
13903
      d_extf_info[n].d_vars = collectVars(n);
432
    }
433
  }
434
724868
}
435
436
// mark reduced
437
73567
void ExtTheory::markReduced(Node n, ExtReducedId rid, bool satDep)
438
{
439
73567
  Trace("extt-debug") << "Mark reduced " << n << std::endl;
440
73567
  registerTerm(n);
441
73567
  Assert(d_ext_func_terms.find(n) != d_ext_func_terms.end());
442
73567
  d_ext_func_terms[n] = false;
443
73567
  d_extfExtReducedIdMap[n] = rid;
444
73567
  if (!satDep)
445
  {
446
174
    d_ci_inactive[n] = rid;
447
  }
448
449
  // update has_extf
450
73567
  if (d_has_extf.get() == n)
451
  {
452
99604
    for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
453
99604
         it != d_ext_func_terms.end();
454
         ++it)
455
    {
456
      // if not already reduced
457
94545
      if ((*it).second && !isContextIndependentInactive((*it).first))
458
      {
459
30291
        d_has_extf = (*it).first;
460
      }
461
    }
462
  }
463
73567
}
464
465
426134
/**
 * Returns true if `n` is recorded in d_ci_inactive, i.e. it was reduced in
 * a way that does not depend on the current context. The reason is
 * discarded.
 */
bool ExtTheory::isContextIndependentInactive(Node n) const
{
  ExtReducedId ignoredReason = ExtReducedId::UNKNOWN;
  return isContextIndependentInactive(n, ignoredReason);
}
470
471
596561
/**
 * Returns true if `n` is recorded in d_ci_inactive; in that case `rid` is
 * set to the recorded reason. Otherwise `rid` is left unchanged.
 */
bool ExtTheory::isContextIndependentInactive(Node n, ExtReducedId& rid) const
{
  auto entry = d_ci_inactive.find(n);
  const bool found = entry != d_ci_inactive.end();
  if (found)
  {
    rid = entry->second;
  }
  return found;
}
481
482
3031
void ExtTheory::getTerms(std::vector<Node>& terms)
483
{
484
26486
  for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
485
26486
       it != d_ext_func_terms.end();
486
       ++it)
487
  {
488
23455
    terms.push_back((*it).first);
489
  }
490
3031
}
491
492
bool ExtTheory::hasActiveTerm() const { return !d_has_extf.get().isNull(); }
493
494
170427
/** Returns whether `n` is an active extended function term (reason ignored). */
bool ExtTheory::isActive(Node n) const
{
  ExtReducedId ignoredReason = ExtReducedId::UNKNOWN;
  return isActive(n, ignoredReason);
}
499
500
170427
/**
 * Returns whether `n` is an active extended function term.
 *
 * - Unregistered terms are not active; `rid` is left unchanged.
 * - A term marked reduced in the current context is not active; `rid` is set
 *   to the reason recorded in d_extfExtReducedIdMap.
 * - Otherwise the term is active unless it is context-independently
 *   inactive, in which case `rid` is set to that recorded reason.
 */
bool ExtTheory::isActive(Node n, ExtReducedId& rid) const
{
  NodeBoolMap::const_iterator it = d_ext_func_terms.find(n);
  if (it == d_ext_func_terms.end())
  {
    return false;
  }
  const bool activeInContext = (*it).second;
  if (activeInContext)
  {
    return !isContextIndependentInactive(n, rid);
  }
  NodeExtReducedIdMap::const_iterator reason = d_extfExtReducedIdMap.find(n);
  Assert(reason != d_extfExtReducedIdMap.end());
  rid = reason->second;
  return false;
}
516
517
// get active
518
104456
std::vector<Node> ExtTheory::getActive() const
519
{
520
104456
  std::vector<Node> active;
521
713439
  for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
522
713439
       it != d_ext_func_terms.end();
523
       ++it)
524
  {
525
    // if not already reduced
526
608983
    if ((*it).second && !isContextIndependentInactive((*it).first))
527
    {
528
393593
      active.push_back((*it).first);
529
    }
530
  }
531
104456
  return active;
532
}
533
534
7526
/**
 * Returns the currently active extended function terms whose kind is `k`.
 * Active has the same meaning as for getActive().
 */
std::vector<Node> ExtTheory::getActive(Kind k) const
{
  std::vector<Node> result;
  for (const auto& entry : d_ext_func_terms)
  {
    // skip terms of another kind and terms reduced in the current context
    if (entry.first.getKind() != k || !entry.second)
    {
      continue;
    }
    if (!isContextIndependentInactive(entry.first))
    {
      result.push_back(entry.first);
    }
  }
  return result;
}
550
551
}  // namespace theory
552
28194
}  // namespace cvc5