GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/smt/expand_definitions.cpp Lines: 140 157 89.2 %
Date: 2021-03-23 Branches: 285 672 42.4 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file expand_definitions.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Morgan Deters, Andres Noetzli
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Implementation of expand definitions for an SMT engine.
13
 **/
14
15
#include "smt/expand_definitions.h"
16
17
#include <stack>
18
#include <utility>
19
20
#include "expr/node_manager_attributes.h"
21
#include "preprocessing/assertion_pipeline.h"
22
#include "smt/defined_function.h"
23
#include "smt/smt_engine.h"
24
#include "smt/smt_engine_stats.h"
25
#include "theory/theory_engine.h"
26
27
using namespace CVC4::preprocessing;
28
using namespace CVC4::theory;
29
using namespace CVC4::kind;
30
31
namespace CVC4 {
32
namespace smt {
33
34
9621
ExpandDefs::ExpandDefs(SmtEngine& smt,
35
                       ResourceManager& rm,
36
9621
                       SmtEngineStatistics& stats)
37
9621
    : d_smt(smt), d_resourceManager(rm), d_smtStats(stats), d_tpg(nullptr)
38
{
39
9621
}
40
41
9600
ExpandDefs::~ExpandDefs() {}
42
43
30469
Node ExpandDefs::expandDefinitions(
44
    TNode n,
45
    std::unordered_map<Node, Node, NodeHashFunction>& cache,
46
    bool expandOnly)
47
{
48
60938
  TrustNode trn = expandDefinitions(n, cache, expandOnly, nullptr);
49
60938
  return trn.isNull() ? Node(n) : trn.getNode();
50
}
51
52
131944
TrustNode ExpandDefs::expandDefinitions(
53
    TNode n,
54
    std::unordered_map<Node, Node, NodeHashFunction>& cache,
55
    bool expandOnly,
56
    TConvProofGenerator* tpg)
57
{
58
263888
  const TNode orig = n;
59
131944
  NodeManager* nm = d_smt.getNodeManager();
60
263888
  std::stack<std::tuple<Node, Node, bool>> worklist;
61
263888
  std::stack<Node> result;
62
131944
  worklist.push(std::make_tuple(Node(n), Node(n), false));
63
  // The worklist is made of triples, each is input / original node then the
64
  // output / rewritten node and finally a flag tracking whether the children
65
  // have been explored (i.e. if this is a downward or upward pass).
66
67
8305072
  do
68
  {
69
8437016
    d_resourceManager.spendResource(ResourceManager::Resource::PreprocessStep);
70
71
    // n is the input / original
72
    // node is the output / result
73
13361970
    Node node;
74
    bool childrenPushed;
75
8437016
    std::tie(n, node, childrenPushed) = worklist.top();
76
8437016
    worklist.pop();
77
78
    // Working downwards
79
8437016
    if (!childrenPushed)
80
    {
81
5974539
      Kind k = n.getKind();
82
83
      // we can short circuit (variable) leaves
84
5974539
      if (n.isVar())
85
      {
86
1865852
        SmtEngine::DefinedFunctionMap* dfuns = d_smt.getDefinedFunctionMap();
87
1865852
        SmtEngine::DefinedFunctionMap::const_iterator i = dfuns->find(n);
88
1865852
        if (i != dfuns->end())
89
        {
90
4964
          Node f = (*i).second.getFormula();
91
2482
          const std::vector<Node>& formals = (*i).second.getFormals();
92
          // replacement must be closed
93
2482
          if (!formals.empty())
94
          {
95
8
            f = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, formals), f);
96
          }
97
          // are we are producing proofs for this call?
98
2482
          if (tpg != nullptr)
99
          {
100
            // if this is a variable, we can simply assume it
101
            // ------- ASSUME
102
            // n = f
103
498
            Node conc = n.eqNode(f);
104
249
            tpg->addRewriteStep(n, f, PfRule::ASSUME, {}, {conc}, true);
105
          }
106
          // must recursively expand its definition
107
4964
          TrustNode tfe = expandDefinitions(f, cache, expandOnly, tpg);
108
4964
          Node fe = tfe.isNull() ? f : tfe.getNode();
109
          // don't bother putting in the cache
110
2482
          result.push(fe);
111
2482
          continue;
112
        }
113
        // don't bother putting in the cache
114
1863370
        result.push(n);
115
1863370
        continue;
116
      }
117
118
      // maybe it's in the cache
119
      std::unordered_map<Node, Node, NodeHashFunction>::iterator cacheHit =
120
4108687
          cache.find(n);
121
4108687
      if (cacheHit != cache.end())
122
      {
123
3284530
        TNode ret = (*cacheHit).second;
124
1642265
        result.push(ret.isNull() ? n : ret);
125
1642265
        continue;
126
      }
127
128
      // otherwise expand it
129
2466422
      bool doExpand = false;
130
2466422
      if (k == APPLY_UF)
131
      {
132
        // Always do beta-reduction here. The reason is that there may be
133
        // operators such as INTS_MODULUS in the body of the lambda that would
134
        // otherwise be introduced by beta-reduction via the rewriter, but are
135
        // not expanded here since the traversal in this function does not
136
        // traverse the operators of nodes. Hence, we beta-reduce here to
137
        // ensure terms in the body of the lambda are expanded during this
138
        // call.
139
139361
        if (n.getOperator().getKind() == LAMBDA)
140
        {
141
627
          doExpand = true;
142
        }
143
        else
144
        {
145
          // We always check if this operator corresponds to a defined function.
146
138734
          doExpand = d_smt.isDefinedFunction(n.getOperator());
147
        }
148
      }
149
      // the premise of the proof of expansion (if we are expanding a definition
150
      // and proofs are enabled)
151
4928899
      std::vector<Node> pfExpChildren;
152
2466422
      if (doExpand)
153
      {
154
7890
        std::vector<Node> formals;
155
7890
        TNode fm;
156
3945
        if (n.getOperator().getKind() == LAMBDA)
157
        {
158
1254
          TNode op = n.getOperator();
159
          // lambda
160
1959
          for (unsigned i = 0; i < op[0].getNumChildren(); i++)
161
          {
162
1332
            formals.push_back(op[0][i]);
163
          }
164
627
          fm = op[1];
165
        }
166
        else
167
        {
168
          // application of a user-defined symbol
169
6636
          TNode func = n.getOperator();
170
3318
          SmtEngine::DefinedFunctionMap* dfuns = d_smt.getDefinedFunctionMap();
171
3318
          SmtEngine::DefinedFunctionMap::const_iterator i = dfuns->find(func);
172
3318
          if (i == dfuns->end())
173
          {
174
            throw TypeCheckingExceptionPrivate(
175
                n,
176
                std::string("Undefined function: `") + func.toString() + "'");
177
          }
178
6636
          DefinedFunction def = (*i).second;
179
3318
          formals = def.getFormals();
180
181
3318
          if (Debug.isOn("expand"))
182
          {
183
            Debug("expand") << "found: " << n << std::endl;
184
            Debug("expand") << " func: " << func << std::endl;
185
            std::string name = func.getAttribute(expr::VarNameAttr());
186
            Debug("expand") << "     : \"" << name << "\"" << std::endl;
187
          }
188
3318
          if (Debug.isOn("expand"))
189
          {
190
            Debug("expand") << " defn: " << def.getFunction() << std::endl
191
                            << "       [";
192
            if (formals.size() > 0)
193
            {
194
              copy(formals.begin(),
195
                   formals.end() - 1,
196
                   std::ostream_iterator<Node>(Debug("expand"), ", "));
197
              Debug("expand") << formals.back();
198
            }
199
            Debug("expand")
200
                << "]" << std::endl
201
                << "       " << def.getFunction().getType() << std::endl
202
                << "       " << def.getFormula() << std::endl;
203
          }
204
205
3318
          fm = def.getFormula();
206
          // are we producing proofs for this call?
207
3318
          if (tpg != nullptr)
208
          {
209
768
            Node pfRhs = fm;
210
384
            if (!formals.empty())
211
            {
212
768
              Node bvl = nm->mkNode(BOUND_VAR_LIST, formals);
213
384
              pfRhs = nm->mkNode(LAMBDA, bvl, pfRhs);
214
            }
215
384
            Assert(func.getType().isComparableTo(pfRhs.getType()));
216
384
            pfExpChildren.push_back(func.eqNode(pfRhs));
217
          }
218
        }
219
220
        Node instance = fm.substitute(formals.begin(),
221
                                      formals.end(),
222
                                      n.begin(),
223
7890
                                      n.begin() + formals.size());
224
3945
        Debug("expand") << "made : " << instance << std::endl;
225
        // are we producing proofs for this call?
226
3945
        if (tpg != nullptr)
227
        {
228
384
          if (n != instance)
229
          {
230
            // This code is run both when we are doing expand definitions and
231
            // simple beta reduction.
232
            //
233
            // f = (lambda ((x T)) t)  [if we are expanding a definition]
234
            // ---------------------- MACRO_SR_PRED_INTRO
235
            // n = instance
236
768
            Node conc = n.eqNode(instance);
237
768
            tpg->addRewriteStep(n,
238
                                instance,
239
                                PfRule::MACRO_SR_PRED_INTRO,
240
                                pfExpChildren,
241
                                {conc},
242
384
                                true);
243
          }
244
        }
245
        // now, call expand definitions again on the result
246
7890
        TrustNode texp = expandDefinitions(instance, cache, expandOnly, tpg);
247
7890
        Node expanded = texp.isNull() ? instance : texp.getNode();
248
3945
        cache[n] = n == expanded ? Node::null() : expanded;
249
3945
        result.push(expanded);
250
3945
        continue;
251
      }
252
2462477
      else if (!expandOnly)
253
      {
254
        // do not do any theory stuff if expandOnly is true
255
256
9773
        theory::Theory* t = d_smt.getTheoryEngine()->theoryOf(node);
257
258
9773
        Assert(t != NULL);
259
19546
        TrustNode trn = t->expandDefinition(n);
260
9773
        if (!trn.isNull())
261
        {
262
254
          node = trn.getNode();
263
254
          if (tpg != nullptr)
264
          {
265
            tpg->addRewriteStep(
266
                n, node, trn.getGenerator(), true, PfRule::THEORY_EXPAND_DEF);
267
          }
268
        }
269
        else
270
        {
271
9519
          node = n;
272
        }
273
      }
274
275
      // the partial functions can fall through, in which case we still
276
      // consider their children
277
7387431
      worklist.push(std::make_tuple(
278
4924954
          Node(n), node, true));  // Original and rewritten result
279
280
8305072
      for (size_t i = 0; i < node.getNumChildren(); ++i)
281
      {
282
5842595
        worklist.push(
283
11685190
            std::make_tuple(node[i],
284
                            node[i],
285
                            false));  // Rewrite the children of the result only
286
      }
287
    }
288
    else
289
    {
290
      // Working upwards
291
      // Reconstruct the node from it's (now rewritten) children on the stack
292
293
2462477
      Debug("expand") << "cons : " << node << std::endl;
294
2462477
      if (node.getNumChildren() > 0)
295
      {
296
        // cout << "cons : " << node << std::endl;
297
4782208
        NodeBuilder<> nb(node.getKind());
298
2391104
        if (node.getMetaKind() == metakind::PARAMETERIZED)
299
        {
300
182850
          Debug("expand") << "op   : " << node.getOperator() << std::endl;
301
          // cout << "op   : " << node.getOperator() << std::endl;
302
182850
          nb << node.getOperator();
303
        }
304
8233699
        for (size_t i = 0, nchild = node.getNumChildren(); i < nchild; ++i)
305
        {
306
5842595
          Assert(!result.empty());
307
11685190
          Node expanded = result.top();
308
5842595
          result.pop();
309
          // cout << "exchld : " << expanded << std::endl;
310
5842595
          Debug("expand") << "exchld : " << expanded << std::endl;
311
5842595
          nb << expanded;
312
        }
313
2391104
        node = nb;
314
      }
315
      // Only cache once all subterms are expanded
316
2462477
      cache[n] = n == node ? Node::null() : node;
317
2462477
      result.push(node);
318
    }
319
8437016
  } while (!worklist.empty());
320
321
131944
  AlwaysAssert(result.size() == 1);
322
323
263888
  Node res = result.top();
324
325
131944
  if (res == orig)
326
  {
327
127699
    return TrustNode::null();
328
  }
329
4245
  return TrustNode::mkTrustRewrite(orig, res, tpg);
330
}
331
332
11325
void ExpandDefs::expandAssertions(AssertionPipeline& assertions,
333
                                  bool expandOnly)
334
{
335
11325
  Chat() << "expanding definitions in assertions..." << std::endl;
336
22650
  Trace("exp-defs") << "ExpandDefs::simplify(): expanding definitions"
337
11325
                    << std::endl;
338
22650
  TimerStat::CodeTimer codeTimer(d_smtStats.d_definitionExpansionTime);
339
22650
  std::unordered_map<Node, Node, NodeHashFunction> cache;
340
106373
  for (size_t i = 0, nasserts = assertions.size(); i < nasserts; ++i)
341
  {
342
190096
    Node assert = assertions[i];
343
    // notice we call this method with only one value of expandOnly currently,
344
    // hence we maintain only a single set of proof steps in d_tpg.
345
190096
    TrustNode expd = expandDefinitions(assert, cache, expandOnly, d_tpg.get());
346
95048
    if (!expd.isNull())
347
    {
348
2434
      Trace("exp-defs") << "ExpandDefs::expandAssertions: " << assert << " -> "
349
1217
                        << expd.getNode() << std::endl;
350
1217
      assertions.replaceTrusted(i, expd);
351
    }
352
  }
353
11325
}
354
355
962
void ExpandDefs::setProofNodeManager(ProofNodeManager* pnm)
356
{
357
962
  if (d_tpg == nullptr)
358
  {
359
2886
    d_tpg.reset(new TConvProofGenerator(pnm,
360
962
                                        d_smt.getUserContext(),
361
                                        TConvPolicy::FIXPOINT,
362
                                        TConvCachePolicy::NEVER,
363
                                        "ExpandDefs::TConvProofGenerator",
364
                                        nullptr,
365
962
                                        true));
366
  }
367
962
}
368
369
}  // namespace smt
370
26685
}  // namespace CVC4