GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/fun_def_fmf.cpp Lines: 219 231 94.8 %
Date: 2021-03-23 Branches: 462 972 47.5 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file fun_def_fmf.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Haniel Barbosa, Mathias Preiner
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Function definition processor for finite model finding
13
 **/
14
15
#include "preprocessing/passes/fun_def_fmf.h"
16
17
#include <sstream>
18
19
#include "options/smt_options.h"
20
#include "preprocessing/assertion_pipeline.h"
21
#include "preprocessing/preprocessing_pass_context.h"
22
#include "proof/proof_manager.h"
23
#include "theory/quantifiers/quantifiers_attributes.h"
24
#include "theory/quantifiers/term_database.h"
25
#include "theory/quantifiers/term_util.h"
26
#include "theory/rewriter.h"
27
28
using namespace std;
29
using namespace CVC4::kind;
30
using namespace CVC4::theory;
31
using namespace CVC4::theory::quantifiers;
32
33
namespace CVC4 {
34
namespace preprocessing {
35
namespace passes {
36
37
8997
FunDefFmf::FunDefFmf(PreprocessingPassContext* preprocContext)
38
    : PreprocessingPass(preprocContext, "fun-def-fmf"),
39
8997
      d_fmfRecFunctionsDefined(nullptr)
40
{
41
17994
  d_fmfRecFunctionsDefined =
42
17994
      new (true) NodeList(preprocContext->getUserContext());
43
8997
}
44
45
17988
FunDefFmf::~FunDefFmf() { d_fmfRecFunctionsDefined->deleteSelf(); }
46
47
106
PreprocessingPassResult FunDefFmf::applyInternal(
48
    AssertionPipeline* assertionsToPreprocess)
49
{
50
106
  Assert(d_fmfRecFunctionsDefined != nullptr);
51
  // reset
52
106
  d_sorts.clear();
53
106
  d_input_arg_inj.clear();
54
106
  d_funcs.clear();
55
  // must carry over current definitions (in case of incremental)
56
185
  for (context::CDList<Node>::const_iterator fit =
57
106
           d_fmfRecFunctionsDefined->begin();
58
291
       fit != d_fmfRecFunctionsDefined->end();
59
       ++fit)
60
  {
61
370
    Node f = (*fit);
62
185
    Assert(d_fmfRecFunctionsAbs.find(f) != d_fmfRecFunctionsAbs.end());
63
370
    TypeNode ft = d_fmfRecFunctionsAbs[f];
64
185
    d_sorts[f] = ft;
65
    std::map<Node, std::vector<Node>>::iterator fcit =
66
185
        d_fmfRecFunctionsConcrete.find(f);
67
185
    Assert(fcit != d_fmfRecFunctionsConcrete.end());
68
1254
    for (const Node& fcc : fcit->second)
69
    {
70
1069
      d_input_arg_inj[f].push_back(fcc);
71
    }
72
  }
73
106
  process(assertionsToPreprocess);
74
  // must store new definitions (in case of incremental)
75
187
  for (const Node& f : d_funcs)
76
  {
77
81
    d_fmfRecFunctionsAbs[f] = d_sorts[f];
78
81
    d_fmfRecFunctionsConcrete[f].clear();
79
328
    for (const Node& fcc : d_input_arg_inj[f])
80
    {
81
247
      d_fmfRecFunctionsConcrete[f].push_back(fcc);
82
    }
83
81
    d_fmfRecFunctionsDefined->push_back(f);
84
  }
85
106
  return PreprocessingPassResult::NO_CONFLICT;
86
}
87
88
106
void FunDefFmf::process(AssertionPipeline* assertionsToPreprocess)
89
{
90
106
  const std::vector<Node>& assertions = assertionsToPreprocess->ref();
91
212
  std::vector<int> fd_assertions;
92
212
  std::map<int, Node> subs_head;
93
  // first pass : find defined functions, transform quantifiers
94
106
  NodeManager* nm = NodeManager::currentNM();
95
590
  for (size_t i = 0, asize = assertions.size(); i < asize; i++)
96
  {
97
968
    Node n = QuantAttributes::getFunDefHead(assertions[i]);
98
484
    if (!n.isNull())
99
    {
100
81
      Assert(n.getKind() == APPLY_UF);
101
162
      Node f = n.getOperator();
102
103
      // check if already defined, if so, throw error
104
81
      if (d_sorts.find(f) != d_sorts.end())
105
      {
106
        Unhandled() << "Cannot define function " << f << " more than once.";
107
      }
108
109
162
      Node bd = QuantAttributes::getFunDefBody(assertions[i]);
110
162
      Trace("fmf-fun-def-debug")
111
81
          << "Process function " << n << ", body = " << bd << std::endl;
112
81
      if (!bd.isNull())
113
      {
114
81
        d_funcs.push_back(f);
115
81
        bd = nm->mkNode(EQUAL, n, bd);
116
117
        // create a sort S that represents the inputs of the function
118
162
        std::stringstream ss;
119
81
        ss << "I_" << f;
120
162
        TypeNode iType = nm->mkSort(ss.str());
121
        AbsTypeFunDefAttribute atfda;
122
81
        iType.setAttribute(atfda, true);
123
81
        d_sorts[f] = iType;
124
125
        // create functions f1...fn mapping from this sort to concrete elements
126
81
        size_t nchildn = n.getNumChildren();
127
328
        for (size_t j = 0; j < nchildn; j++)
128
        {
129
494
          TypeNode typ = nm->mkFunctionType(iType, n[j].getType());
130
494
          std::stringstream ssf;
131
247
          ssf << f << "_arg_" << j;
132
494
          d_input_arg_inj[f].push_back(
133
494
              nm->mkSkolem(ssf.str(), typ, "op created during fun def fmf"));
134
        }
135
136
        // construct new quantifier forall S. F[f1(S)/x1....fn(S)/xn]
137
162
        std::vector<Node> children;
138
162
        Node bv = nm->mkBoundVar("?i", iType);
139
162
        Node bvl = nm->mkNode(BOUND_VAR_LIST, bv);
140
162
        std::vector<Node> subs;
141
162
        std::vector<Node> vars;
142
328
        for (size_t j = 0; j < nchildn; j++)
143
        {
144
247
          vars.push_back(n[j]);
145
247
          subs.push_back(nm->mkNode(APPLY_UF, d_input_arg_inj[f][j], bv));
146
        }
147
81
        bd = bd.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
148
81
        subs_head[i] =
149
162
            n.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
150
151
162
        Trace("fmf-fun-def")
152
81
            << "FMF fun def: FUNCTION : rewrite " << assertions[i] << std::endl;
153
81
        Trace("fmf-fun-def") << "  to " << std::endl;
154
162
        Node new_q = nm->mkNode(FORALL, bvl, bd);
155
81
        new_q = Rewriter::rewrite(new_q);
156
81
        assertionsToPreprocess->replace(i, new_q);
157
81
        Trace("fmf-fun-def") << "  " << assertions[i] << std::endl;
158
81
        fd_assertions.push_back(i);
159
      }
160
      else
161
      {
162
        // can be, e.g. in corner cases forall x. f(x)=f(x), forall x.
163
        // f(x)=f(x)+1
164
      }
165
    }
166
  }
167
  // second pass : rewrite assertions
168
212
  std::map<int, std::map<Node, Node>> visited;
169
212
  std::map<int, std::map<Node, Node>> visited_cons;
170
590
  for (size_t i = 0, asize = assertions.size(); i < asize; i++)
171
  {
172
968
    bool is_fd = std::find(fd_assertions.begin(), fd_assertions.end(), i)
173
1452
                 != fd_assertions.end();
174
968
    std::vector<Node> constraints;
175
968
    Trace("fmf-fun-def-rewrite")
176
484
        << "Rewriting " << assertions[i]
177
484
        << ", is function definition = " << is_fd << std::endl;
178
484
    Node n = simplifyFormula(assertions[i],
179
                             true,
180
                             true,
181
                             constraints,
182
968
                             is_fd ? subs_head[i] : Node::null(),
183
                             is_fd,
184
                             visited,
185
1936
                             visited_cons);
186
484
    Assert(constraints.empty());
187
484
    if (n != assertions[i])
188
    {
189
103
      n = Rewriter::rewrite(n);
190
206
      Trace("fmf-fun-def-rewrite")
191
103
          << "FMF fun def : rewrite " << assertions[i] << std::endl;
192
103
      Trace("fmf-fun-def-rewrite") << "  to " << std::endl;
193
103
      Trace("fmf-fun-def-rewrite") << "  " << n << std::endl;
194
103
      assertionsToPreprocess->replace(i, n);
195
    }
196
  }
197
106
}
198
199
1366
Node FunDefFmf::simplifyFormula(
200
    Node n,
201
    bool pol,
202
    bool hasPol,
203
    std::vector<Node>& constraints,
204
    Node hd,
205
    bool is_fun_def,
206
    std::map<int, std::map<Node, Node>>& visited,
207
    std::map<int, std::map<Node, Node>>& visited_cons)
208
{
209
1366
  Assert(constraints.empty());
210
1366
  int index = (is_fun_def ? 1 : 0) + 2 * (hasPol ? (pol ? 1 : -1) : 0);
211
1366
  std::map<Node, Node>::iterator itv = visited[index].find(n);
212
1366
  if (itv != visited[index].end())
213
  {
214
    // constraints.insert( visited_cons[index]
215
155
    std::map<Node, Node>::iterator itvc = visited_cons[index].find(n);
216
155
    if (itvc != visited_cons[index].end())
217
    {
218
      constraints.push_back(itvc->second);
219
    }
220
155
    return itv->second;
221
  }
222
1211
  NodeManager* nm = NodeManager::currentNM();
223
2422
  Node ret;
224
2422
  Trace("fmf-fun-def-debug2") << "Simplify " << n << " " << pol << " " << hasPol
225
1211
                              << " " << is_fun_def << std::endl;
226
1211
  if (n.getKind() == FORALL)
227
  {
228
    Node c = simplifyFormula(
229
186
        n[1], pol, hasPol, constraints, hd, is_fun_def, visited, visited_cons);
230
    // append prenex to constraints
231
93
    for (unsigned i = 0; i < constraints.size(); i++)
232
    {
233
      constraints[i] = nm->mkNode(FORALL, n[0], constraints[i]);
234
      constraints[i] = Rewriter::rewrite(constraints[i]);
235
    }
236
93
    if (c != n[1])
237
    {
238
42
      ret = nm->mkNode(FORALL, n[0], c);
239
    }
240
    else
241
    {
242
51
      ret = n;
243
    }
244
  }
245
  else
246
  {
247
2236
    Node nn = n;
248
1118
    bool isBool = n.getType().isBoolean();
249
1118
    if (isBool && n.getKind() != APPLY_UF)
250
    {
251
1204
      std::vector<Node> children;
252
602
      bool childChanged = false;
253
      // are we at a branch position (not all children are necessarily
254
      // relevant)?
255
      bool branch_pos =
256
602
          (n.getKind() == ITE || n.getKind() == OR || n.getKind() == AND);
257
1204
      std::vector<Node> branch_constraints;
258
1471
      for (unsigned i = 0; i < n.getNumChildren(); i++)
259
      {
260
1738
        Node c = n[i];
261
        // do not process LHS of definition
262
869
        if (!is_fun_def || c != hd)
263
        {
264
          bool newHasPol;
265
          bool newPol;
266
789
          QuantPhaseReq::getPolarity(n, i, hasPol, pol, newHasPol, newPol);
267
          // get child constraints
268
1578
          std::vector<Node> cconstraints;
269
789
          c = simplifyFormula(n[i],
270
                              newPol,
271
                              newHasPol,
272
                              cconstraints,
273
                              hd,
274
                              false,
275
                              visited,
276
                              visited_cons);
277
789
          if (branch_pos)
278
          {
279
            // if at a branching position, the other constraints don't matter
280
            // if this is satisfied
281
268
            Node bcons = nm->mkAnd(cconstraints);
282
134
            branch_constraints.push_back(bcons);
283
268
            Trace("fmf-fun-def-debug2") << "Branching constraint at arg " << i
284
134
                                        << " is " << bcons << std::endl;
285
          }
286
789
          constraints.insert(
287
1578
              constraints.end(), cconstraints.begin(), cconstraints.end());
288
        }
289
869
        children.push_back(c);
290
869
        childChanged = c != n[i] || childChanged;
291
      }
292
602
      if (childChanged)
293
      {
294
18
        nn = nm->mkNode(n.getKind(), children);
295
      }
296
602
      if (branch_pos && !constraints.empty())
297
      {
298
        // if we are at a branching position in the formula, we can
299
        // minimize recursive constraints on recursively defined predicates if
300
        // we know one child forces the overall evaluation of this formula.
301
28
        Node branch_cond;
302
14
        if (n.getKind() == ITE)
303
        {
304
          // always care about constraints on the head of the ITE, but only
305
          // care about one of the children depending on how it evaluates
306
          branch_cond = nm->mkNode(
307
              AND,
308
              branch_constraints[0],
309
              nm->mkNode(
310
                  ITE, n[0], branch_constraints[1], branch_constraints[2]));
311
        }
312
        else
313
        {
314
          // in the default case, we care about all conditions
315
14
          branch_cond = nm->mkAnd(constraints);
316
59
          for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
317
          {
318
            // if this child holds with forcing polarity (true child of OR or
319
            // false child of AND), then we only care about its associated
320
            // recursive conditions
321
180
            branch_cond = nm->mkNode(ITE,
322
90
                                     (n.getKind() == OR ? n[i] : n[i].negate()),
323
45
                                     branch_constraints[i],
324
                                     branch_cond);
325
          }
326
        }
327
28
        Trace("fmf-fun-def-debug2")
328
14
            << "Made branching condition " << branch_cond << std::endl;
329
14
        constraints.clear();
330
14
        constraints.push_back(branch_cond);
331
      }
332
    }
333
    else
334
    {
335
      // simplify term
336
1032
      std::map<Node, Node> visitedT;
337
516
      getConstraints(n, constraints, visitedT);
338
    }
339
1118
    if (!constraints.empty() && isBool && hasPol)
340
    {
341
      // conjoin with current
342
208
      Node cons = nm->mkAnd(constraints);
343
104
      if (pol)
344
      {
345
90
        ret = nm->mkNode(AND, nn, cons);
346
      }
347
      else
348
      {
349
14
        ret = nm->mkNode(OR, nn, cons.negate());
350
      }
351
208
      Trace("fmf-fun-def-debug2")
352
104
          << "Add constraint to obtain " << ret << std::endl;
353
104
      constraints.clear();
354
    }
355
    else
356
    {
357
1014
      ret = nn;
358
    }
359
  }
360
1211
  if (!constraints.empty())
361
  {
362
226
    Node cons;
363
    // flatten to AND node for the purposes of caching
364
113
    if (constraints.size() > 1)
365
    {
366
      cons = nm->mkNode(AND, constraints);
367
      cons = Rewriter::rewrite(cons);
368
      constraints.clear();
369
      constraints.push_back(cons);
370
    }
371
    else
372
    {
373
113
      cons = constraints[0];
374
    }
375
113
    visited_cons[index][n] = cons;
376
113
    Assert(constraints.size() == 1 && constraints[0] == cons);
377
  }
378
1211
  visited[index][n] = ret;
379
1211
  return ret;
380
}
381
382
2289
void FunDefFmf::getConstraints(Node n,
383
                               std::vector<Node>& constraints,
384
                               std::map<Node, Node>& visited)
385
{
386
2289
  std::map<Node, Node>::iterator itv = visited.find(n);
387
2289
  if (itv != visited.end())
388
  {
389
    // already visited
390
393
    if (!itv->second.isNull())
391
    {
392
      // add the cached constraint if it does not already occur
393
579
      if (std::find(constraints.begin(), constraints.end(), itv->second)
394
579
          == constraints.end())
395
      {
396
193
        constraints.push_back(itv->second);
397
      }
398
    }
399
393
    return;
400
  }
401
1896
  visited[n] = Node::null();
402
3792
  std::vector<Node> currConstraints;
403
1896
  NodeManager* nm = NodeManager::currentNM();
404
1896
  if (n.getKind() == ITE)
405
  {
406
    // collect constraints for the condition
407
63
    getConstraints(n[0], currConstraints, visited);
408
    // collect constraints for each branch
409
126
    Node cs[2];
410
189
    for (unsigned i = 0; i < 2; i++)
411
    {
412
252
      std::vector<Node> ccons;
413
126
      getConstraints(n[i + 1], ccons, visited);
414
126
      cs[i] = nm->mkAnd(ccons);
415
    }
416
63
    if (!cs[0].isConst() || !cs[1].isConst())
417
    {
418
60
      Node itec = nm->mkNode(ITE, n[0], cs[0], cs[1]);
419
30
      currConstraints.push_back(itec);
420
60
      Trace("fmf-fun-def-debug")
421
30
          << "---> add constraint " << itec << " for " << n << std::endl;
422
    }
423
  }
424
  else
425
  {
426
1833
    if (n.getKind() == APPLY_UF)
427
    {
428
      // check if f is defined, if so, we must enforce domain constraints for
429
      // this f-application
430
470
      Node f = n.getOperator();
431
235
      std::map<Node, TypeNode>::iterator it = d_sorts.find(f);
432
235
      if (it != d_sorts.end())
433
      {
434
        // create existential
435
218
        Node z = nm->mkBoundVar("?z", it->second);
436
218
        Node bvl = nm->mkNode(BOUND_VAR_LIST, z);
437
218
        std::vector<Node> children;
438
271
        for (unsigned j = 0, size = n.getNumChildren(); j < size; j++)
439
        {
440
324
          Node uz = nm->mkNode(APPLY_UF, d_input_arg_inj[f][j], z);
441
162
          children.push_back(uz.eqNode(n[j]));
442
        }
443
218
        Node bd = nm->mkAnd(children);
444
109
        bd = bd.negate();
445
218
        Node ex = nm->mkNode(FORALL, bvl, bd);
446
109
        ex = ex.negate();
447
109
        currConstraints.push_back(ex);
448
218
        Trace("fmf-fun-def-debug")
449
109
            << "---> add constraint " << ex << " for " << n << std::endl;
450
      }
451
    }
452
3224
    for (const Node& cn : n)
453
    {
454
1391
      getConstraints(cn, currConstraints, visited);
455
    }
456
  }
457
  // set the visited cache
458
1896
  if (!currConstraints.empty())
459
  {
460
386
    Node finalc = nm->mkAnd(currConstraints);
461
193
    visited[n] = finalc;
462
    // add to constraints
463
193
    getConstraints(n, constraints, visited);
464
  }
465
}
466
467
}  // namespace passes
468
}  // namespace preprocessing
469
26685
}  // namespace CVC4