GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/ext_theory.cpp Lines: 175 246 71.1 %
Date: 2021-03-22 Branches: 307 854 35.9 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file ext_theory.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Tim King, Morgan Deters
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Extended theory interface.
13
 **
14
 ** This implements a generic module, used by theory solvers, for performing
15
 ** "context-dependent simplification", as described in Reynolds et al
16
 ** "Designing Theory Solvers with Extensions", FroCoS 2017.
17
 **/
18
19
#include "theory/ext_theory.h"
20
21
#include "base/check.h"
22
#include "smt/smt_statistics_registry.h"
23
#include "theory/output_channel.h"
24
#include "theory/quantifiers_engine.h"
25
#include "theory/rewriter.h"
26
#include "theory/substitutions.h"
27
28
using namespace std;
29
30
namespace CVC4 {
31
namespace theory {
32
33
// Default callback implementation: no substitution is provided.
// None of the arguments are read or written; the result is always false,
// which callers in this file interpret as "do not substitute".
bool ExtTheoryCallback::getCurrentSubstitution(
    int effort,
    const std::vector<Node>& vars,
    std::vector<Node>& subs,
    std::map<Node, std::vector<Node> >& exp)
{
  const bool haveSubstitution = false;
  return haveSubstitution;
}
41
bool ExtTheoryCallback::isExtfReduced(int effort,
42
                                      Node n,
43
                                      Node on,
44
                                      std::vector<Node>& exp)
45
{
46
  return n.isConst();
47
}
48
bool ExtTheoryCallback::getReduction(int effort,
49
                                    Node n,
50
                                    Node& nr,
51
                                    bool& isSatDep)
52
{
53
  return false;
54
}
55
56
22336
// Builds the extended theory utility.
//   p            : callback object consulted for substitutions/reductions
//   c            : context used for d_ext_func_terms and d_has_extf
//   u            : user context used for d_ci_inactive and both lemma caches
//   out          : channel that lemmas are sent on
//   cacheEnabled : whether getSubstitutedTerms records its results
ExtTheory::ExtTheory(ExtTheoryCallback& p,
                     context::Context* c,
                     context::UserContext* u,
                     OutputChannel& out,
                     bool cacheEnabled)
    : d_parent(p),
      d_out(out),
      d_ext_func_terms(c),
      d_ci_inactive(u),
      d_has_extf(c),
      d_lemmas(u),
      d_pp_lemmas(u),
      d_cacheEnabled(cacheEnabled)
{
  // Cache the Boolean constant "true" once at construction time.
  NodeManager* nodeManager = NodeManager::currentNM();
  d_true = nodeManager->mkConst(true);
}
72
73
// Gets all leaf terms in n.
74
12406
// Collects the non-constant leaves (terms without children) reachable from n.
// Each distinct term is visited at most once, so a leaf is reported once even
// if it occurs several times. The traversal is a depth-first walk driven by
// an explicit stack; the last child of a term is explored first.
std::vector<Node> ExtTheory::collectVars(Node n)
{
  std::vector<Node> leaves;
  std::set<Node> seen;
  std::vector<Node> stack(1, n);
  while (!stack.empty())
  {
    Node cur = stack.back();
    stack.pop_back();
    // Constants are skipped without being recorded as seen; for any other
    // term, a failed insertion means it was already processed.
    if (cur.isConst() || !seen.insert(cur).second)
    {
      continue;
    }
    if (cur.getNumChildren() == 0)
    {
      // Childless, non-constant term: this is a leaf.
      leaves.push_back(cur);
      continue;
    }
    // Schedule the children in their original order; the stack discipline
    // means the last child is popped first.
    for (const Node& child : cur)
    {
      stack.push_back(child);
    }
  }
  return leaves;
}
103
104
// Single-term variant of getSubstitutedTerms.
// Returns the substituted form of `term` at the given effort and appends the
// explanation of the substitution to `exp`.
// When useCache is true the answer is read from d_gst_cache, which must
// already contain an entry for (effort, term); otherwise it is computed via
// getSubstitutedTerms.
Node ExtTheory::getSubstitutedTerm(int effort,
                                   Node term,
                                   std::vector<Node>& exp,
                                   bool useCache)
{
  if (!useCache)
  {
    // Delegate to the batch routine with a singleton input.
    const std::vector<Node> singleton(1, term);
    std::vector<Node> substituted;
    std::vector<std::vector<Node> > explanations;
    getSubstitutedTerms(effort, singleton, substituted, explanations, useCache);
    Assert(substituted.size() == 1);
    Assert(explanations.size() == 1);
    const std::vector<Node>& first = explanations[0];
    exp.insert(exp.end(), first.begin(), first.end());
    return substituted[0];
  }

  // Cached path: the entry is expected to exist.
  Assert(d_gst_cache[effort].find(term) != d_gst_cache[effort].end());
  auto& cached = d_gst_cache[effort][term];
  exp.insert(exp.end(), cached.d_exp.begin(), cached.d_exp.end());
  return cached.d_sterm;
}
128
129
// do inferences
130
50333
void ExtTheory::getSubstitutedTerms(int effort,
131
                                    const std::vector<Node>& terms,
132
                                    std::vector<Node>& sterms,
133
                                    std::vector<std::vector<Node> >& exp,
134
                                    bool useCache)
135
{
136
50333
  if (useCache)
137
  {
138
    for (const Node& n : terms)
139
    {
140
      Assert(d_gst_cache[effort].find(n) != d_gst_cache[effort].end());
141
      sterms.push_back(d_gst_cache[effort][n].d_sterm);
142
      exp.push_back(std::vector<Node>());
143
      exp[0].insert(exp[0].end(),
144
                    d_gst_cache[effort][n].d_exp.begin(),
145
                    d_gst_cache[effort][n].d_exp.end());
146
    }
147
  }
148
  else
149
  {
150
100666
    Trace("extt-debug") << "getSubstitutedTerms for " << terms.size() << " / "
151
100666
                        << d_ext_func_terms.size() << " extended functions."
152
50333
                        << std::endl;
153
50333
    if (!terms.empty())
154
    {
155
      // all variables we need to find a substitution for
156
7974
      std::vector<Node> vars;
157
7974
      std::vector<Node> sub;
158
7974
      std::map<Node, std::vector<Node> > expc;
159
32096
      for (const Node& n : terms)
160
      {
161
        // do substitution, rewrite
162
28109
        std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
163
28109
        Assert(iti != d_extf_info.end());
164
74124
        for (const Node& v : iti->second.d_vars)
165
        {
166
46015
          if (std::find(vars.begin(), vars.end(), v) == vars.end())
167
          {
168
21163
            vars.push_back(v);
169
          }
170
        }
171
      }
172
3987
      bool useSubs = d_parent.getCurrentSubstitution(effort, vars, sub, expc);
173
      // get the current substitution for all variables
174
3987
      Assert(!useSubs || vars.size() == sub.size());
175
32096
      for (const Node& n : terms)
176
      {
177
56218
        Node ns = n;
178
56218
        std::vector<Node> expn;
179
28109
        if (useSubs)
180
        {
181
          // do substitution
182
14742
          ns = n.substitute(vars.begin(), vars.end(), sub.begin(), sub.end());
183
14742
          if (ns != n)
184
          {
185
            // build explanation: explanation vars = sub for each vars in FV(n)
186
4596
            std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
187
4596
            Assert(iti != d_extf_info.end());
188
13174
            for (const Node& v : iti->second.d_vars)
189
            {
190
8578
              std::map<Node, std::vector<Node> >::iterator itx = expc.find(v);
191
8578
              if (itx != expc.end())
192
              {
193
10660
                for (const Node& e : itx->second)
194
                {
195
5330
                  if (std::find(expn.begin(), expn.end(), e) == expn.end())
196
                  {
197
5330
                    expn.push_back(e);
198
                  }
199
                }
200
              }
201
            }
202
          }
203
29484
          Trace("extt-debug")
204
29484
              << "  have " << n << " == " << ns << ", exp size=" << expn.size()
205
14742
              << "." << std::endl;
206
        }
207
        // add to vector
208
28109
        sterms.push_back(ns);
209
28109
        exp.push_back(expn);
210
        // add to cache
211
28109
        if (d_cacheEnabled)
212
        {
213
          d_gst_cache[effort][n].d_sterm = ns;
214
          d_gst_cache[effort][n].d_exp.clear();
215
          d_gst_cache[effort][n].d_exp.insert(
216
              d_gst_cache[effort][n].d_exp.end(), expn.begin(), expn.end());
217
        }
218
      }
219
    }
220
  }
221
50333
}
222
223
50376
bool ExtTheory::doInferencesInternal(int effort,
224
                                     const std::vector<Node>& terms,
225
                                     std::vector<Node>& nred,
226
                                     bool batch,
227
                                     bool isRed)
228
{
229
50376
  if (batch)
230
  {
231
50376
    bool addedLemma = false;
232
50376
    if (isRed)
233
    {
234
168
      for (const Node& n : terms)
235
      {
236
250
        Node nr;
237
        // note: could do reduction with substitution here
238
125
        bool satDep = false;
239
125
        if (!d_parent.getReduction(effort, n, nr, satDep))
240
        {
241
          nred.push_back(n);
242
        }
243
        else
244
        {
245
125
          if (!nr.isNull() && n != nr)
246
          {
247
250
            Node lem = NodeManager::currentNM()->mkNode(kind::EQUAL, n, nr);
248
125
            if (sendLemma(lem, true))
249
            {
250
250
              Trace("extt-lemma")
251
125
                  << "ExtTheory : reduction lemma : " << lem << std::endl;
252
125
              addedLemma = true;
253
            }
254
          }
255
125
          markReduced(n, satDep);
256
        }
257
      }
258
    }
259
    else
260
    {
261
100666
      std::vector<Node> sterms;
262
100666
      std::vector<std::vector<Node> > exp;
263
50333
      getSubstitutedTerms(effort, terms, sterms, exp);
264
100666
      std::map<Node, unsigned> sterm_index;
265
50333
      NodeManager* nm = NodeManager::currentNM();
266
78442
      for (unsigned i = 0, size = terms.size(); i < size; i++)
267
      {
268
28109
        bool processed = false;
269
28109
        if (sterms[i] != terms[i])
270
        {
271
9192
          Node sr = Rewriter::rewrite(sterms[i]);
272
          // ask the theory if this term is reduced, e.g. is it constant or it
273
          // is a non-extf term.
274
4596
          if (d_parent.isExtfReduced(effort, sr, terms[i], exp[i]))
275
          {
276
3650
            processed = true;
277
3650
            markReduced(terms[i]);
278
            // We have exp[i] => terms[i] = sr, convert this to a clause.
279
            // This ensures the proof infrastructure can process this as a
280
            // normal theory lemma.
281
7300
            Node eq = terms[i].eqNode(sr);
282
7300
            Node lem = eq;
283
3650
            if (!exp[i].empty())
284
            {
285
7300
              std::vector<Node> eei;
286
7687
              for (const Node& e : exp[i])
287
              {
288
4037
                eei.push_back(e.negate());
289
              }
290
3650
              eei.push_back(eq);
291
3650
              lem = nm->mkNode(kind::OR, eei);
292
            }
293
294
7300
            Trace("extt-debug") << "ExtTheory::doInferences : infer : " << eq
295
3650
                                << " by " << exp[i] << std::endl;
296
3650
            Trace("extt-debug") << "...send lemma " << lem << std::endl;
297
3650
            if (sendLemma(lem))
298
            {
299
1956
              Trace("extt-lemma")
300
978
                  << "ExtTheory : substitution + rewrite lemma : " << lem
301
978
                  << std::endl;
302
978
              addedLemma = true;
303
            }
304
          }
305
          else
306
          {
307
            // check if we have already reduced this
308
946
            std::map<Node, unsigned>::iterator itsi = sterm_index.find(sr);
309
946
            if (itsi == sterm_index.end())
310
            {
311
932
              sterm_index[sr] = i;
312
            }
313
            else
314
            {
315
              // unsigned j = itsi->second;
316
              // note : can add (non-reducing) lemma :
317
              //   exp[j] ^ exp[i] => sterms[i] = sterms[j]
318
            }
319
320
946
            Trace("extt-nred") << "Non-reduced term : " << sr << std::endl;
321
          }
322
        }
323
        else
324
        {
325
23513
          Trace("extt-nred") << "Non-reduced term : " << sterms[i] << std::endl;
326
        }
327
28109
        if (!processed)
328
        {
329
24459
          nred.push_back(terms[i]);
330
        }
331
      }
332
    }
333
50376
    return addedLemma;
334
  }
335
  // non-batch
336
  std::vector<Node> nnred;
337
  if (terms.empty())
338
  {
339
    for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
340
         it != d_ext_func_terms.end();
341
         ++it)
342
    {
343
      if ((*it).second && !isContextIndependentInactive((*it).first))
344
      {
345
        std::vector<Node> nterms;
346
        nterms.push_back((*it).first);
347
        if (doInferencesInternal(effort, nterms, nnred, true, isRed))
348
        {
349
          return true;
350
        }
351
      }
352
    }
353
  }
354
  else
355
  {
356
    for (const Node& n : terms)
357
    {
358
      std::vector<Node> nterms;
359
      nterms.push_back(n);
360
      if (doInferencesInternal(effort, nterms, nnred, true, isRed))
361
      {
362
        return true;
363
      }
364
    }
365
  }
366
  return false;
367
}
368
369
3775
// Sends `lem` on the output channel unless it has been sent before.
// Two separate caches are kept: d_pp_lemmas when `preprocess` is true and
// d_lemmas otherwise. Returns true iff the lemma was new and has been sent.
bool ExtTheory::sendLemma(Node lem, bool preprocess)
{
  if (preprocess)
  {
    if (d_pp_lemmas.find(lem) != d_pp_lemmas.end())
    {
      // Already sent as a preprocess lemma.
      return false;
    }
    d_pp_lemmas.insert(lem);
    d_out.lemma(lem);
    return true;
  }

  if (d_lemmas.find(lem) != d_lemmas.end())
  {
    // Already sent as a regular lemma.
    return false;
  }
  d_lemmas.insert(lem);
  d_out.lemma(lem);
  return true;
}
391
392
bool ExtTheory::doInferences(int effort,
393
                             const std::vector<Node>& terms,
394
                             std::vector<Node>& nred,
395
                             bool batch)
396
{
397
  if (!terms.empty())
398
  {
399
    return doInferencesInternal(effort, terms, nred, batch, false);
400
  }
401
  return false;
402
}
403
404
50333
bool ExtTheory::doInferences(int effort, std::vector<Node>& nred, bool batch)
405
{
406
100666
  std::vector<Node> terms = getActive();
407
100666
  return doInferencesInternal(effort, terms, nred, batch, false);
408
}
409
410
43
bool ExtTheory::doReductions(int effort,
411
                             const std::vector<Node>& terms,
412
                             std::vector<Node>& nred,
413
                             bool batch)
414
{
415
43
  if (!terms.empty())
416
  {
417
43
    return doInferencesInternal(effort, terms, nred, batch, true);
418
  }
419
  return false;
420
}
421
422
bool ExtTheory::doReductions(int effort, std::vector<Node>& nred, bool batch)
423
{
424
  const std::vector<Node> terms = getActive();
425
  return doInferencesInternal(effort, terms, nred, batch, true);
426
}
427
428
// Register term.
429
653089
// Registers n as an extended function term if its kind is one of the kinds
// in d_extf_kind and it has not been registered yet. A newly registered term
// is marked active, becomes the value of d_has_extf, and has its leaf
// variables computed and stored in d_extf_info.
void ExtTheory::registerTerm(Node n)
{
  if (d_extf_kind.find(n.getKind()) == d_extf_kind.end())
  {
    // Not an extended function kind: nothing to track.
    return;
  }
  if (d_ext_func_terms.find(n) != d_ext_func_terms.end())
  {
    // Already known.
    return;
  }
  Trace("extt-debug") << "Found extended function : " << n << std::endl;
  d_ext_func_terms[n] = true;
  d_has_extf = n;
  d_extf_info[n].d_vars = collectVars(n);
}
442
443
// mark reduced
444
74669
// Marks n as reduced (inactive).
// n is registered first if necessary and must be an extended function term.
// When satDep is false, n is additionally recorded in d_ci_inactive, the
// user-context set consulted by isContextIndependentInactive.
void ExtTheory::markReduced(Node n, bool satDep)
{
  Trace("extt-debug") << "Mark reduced " << n << std::endl;
  registerTerm(n);
  Assert(d_ext_func_terms.find(n) != d_ext_func_terms.end());
  d_ext_func_terms[n] = false;
  if (!satDep)
  {
    d_ci_inactive.insert(n);
  }

  // d_has_extf only needs updating if it currently refers to n.
  if (d_has_extf.get() != n)
  {
    return;
  }
  // Re-point d_has_extf at a term that is still active. Every active term
  // overwrites the previous choice, so the last active one in iteration
  // order wins.
  // NOTE(review): if no active term remains, d_has_extf keeps referring to
  // n, so hasActiveTerm() still reports true - confirm this is intended.
  for (const auto& entry : d_ext_func_terms)
  {
    const bool stillActive =
        entry.second && !isContextIndependentInactive(entry.first);
    if (stillActive)
    {
      d_has_extf = entry.first;
    }
  }
}
470
471
// mark congruent
472
// Marks b as congruent to a: b becomes inactive, and a stays active only if
// both a and b were active beforehand. Both terms are registered first and
// are expected to be extended function terms (checked by assertions).
void ExtTheory::markCongruent(Node a, Node b)
{
  Trace("extt-debug") << "Mark congruent : " << a << " " << b << std::endl;
  registerTerm(a);
  registerTerm(b);

  NodeBoolMap::const_iterator itb = d_ext_func_terms.find(b);
  if (itb == d_ext_func_terms.end())
  {
    // b is not an extended function term.
    Assert(false);
    return;
  }

  if (d_ext_func_terms.find(a) == d_ext_func_terms.end())
  {
    // a is not an extended function term.
    Assert(false);
  }
  else
  {
    d_ext_func_terms[a] = d_ext_func_terms[a] && (*itb).second;
  }
  d_ext_func_terms[b] = false;
}
495
496
596937
// Returns true iff n has been recorded in d_ci_inactive, i.e. it was marked
// reduced with satDep == false.
bool ExtTheory::isContextIndependentInactive(Node n) const
{
  const bool recorded = d_ci_inactive.find(n) != d_ci_inactive.end();
  return recorded;
}
500
501
2926
void ExtTheory::getTerms(std::vector<Node>& terms)
502
{
503
24781
  for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
504
24781
       it != d_ext_func_terms.end();
505
       ++it)
506
  {
507
21855
    terms.push_back((*it).first);
508
  }
509
2926
}
510
511
// Returns true iff d_has_extf currently holds a non-null term.
bool ExtTheory::hasActiveTerm() const
{
  const bool noTerm = d_has_extf.get().isNull();
  return !noTerm;
}
512
513
// is active
514
171067
// Returns true iff n is a registered extended function term whose active
// flag is set and which is not context-independently inactive.
bool ExtTheory::isActive(Node n) const
{
  NodeBoolMap::const_iterator pos = d_ext_func_terms.find(n);
  if (pos == d_ext_func_terms.end())
  {
    // Unregistered terms are never active.
    return false;
  }
  return (*pos).second && !isContextIndependentInactive(n);
}
523
524
// get active
525
101277
std::vector<Node> ExtTheory::getActive() const
526
{
527
101277
  std::vector<Node> active;
528
713278
  for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
529
713278
       it != d_ext_func_terms.end();
530
       ++it)
531
  {
532
    // if not already reduced
533
612001
    if ((*it).second && !isContextIndependentInactive((*it).first))
534
    {
535
392718
      active.push_back((*it).first);
536
    }
537
  }
538
101277
  return active;
539
}
540
541
7342
// Returns all registered terms of kind k that are currently active, i.e.
// whose active flag is set and which are not context-independently inactive.
std::vector<Node> ExtTheory::getActive(Kind k) const
{
  std::vector<Node> result;
  for (const auto& entry : d_ext_func_terms)
  {
    // The checks run in the same order as a short-circuit conjunction:
    // kind first, then the active flag, then the inactive set.
    if (entry.first.getKind() != k)
    {
      continue;
    }
    if (!entry.second)
    {
      continue;
    }
    if (isContextIndependentInactive(entry.first))
    {
      continue;
    }
    result.push_back(entry.first);
  }
  return result;
}
557
558
25957
// Discards every entry of the substituted-term cache (d_gst_cache), for all
// effort levels.
void ExtTheory::clearCache()
{
  d_gst_cache.clear();
}
559
560
} /* CVC4::theory namespace */
561
26676
} /* CVC4 namespace */