GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/fun_def_fmf.cpp Lines: 220 232 94.8 %
Date: 2021-05-22 Branches: 462 972 47.5 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Haniel Barbosa, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Function definition processor for finite model finding.
14
 */
15
16
#include "preprocessing/passes/fun_def_fmf.h"
17
18
#include <sstream>
19
20
#include "expr/skolem_manager.h"
21
#include "options/smt_options.h"
22
#include "preprocessing/assertion_pipeline.h"
23
#include "preprocessing/preprocessing_pass_context.h"
24
#include "theory/quantifiers/quantifiers_attributes.h"
25
#include "theory/quantifiers/term_database.h"
26
#include "theory/quantifiers/term_util.h"
27
#include "theory/rewriter.h"
28
29
using namespace std;
30
using namespace cvc5::kind;
31
using namespace cvc5::theory;
32
using namespace cvc5::theory::quantifiers;
33
34
namespace cvc5 {
35
namespace preprocessing {
36
namespace passes {
37
38
9459
FunDefFmf::FunDefFmf(PreprocessingPassContext* preprocContext)
39
    : PreprocessingPass(preprocContext, "fun-def-fmf"),
40
9459
      d_fmfRecFunctionsDefined(nullptr)
41
{
42
18918
  d_fmfRecFunctionsDefined =
43
18918
      new (true) NodeList(preprocContext->getUserContext());
44
9459
}
45
46
18918
FunDefFmf::~FunDefFmf() { d_fmfRecFunctionsDefined->deleteSelf(); }
47
48
117
PreprocessingPassResult FunDefFmf::applyInternal(
49
    AssertionPipeline* assertionsToPreprocess)
50
{
51
117
  Assert(d_fmfRecFunctionsDefined != nullptr);
52
  // reset
53
117
  d_sorts.clear();
54
117
  d_input_arg_inj.clear();
55
117
  d_funcs.clear();
56
  // must carry over current definitions (in case of incremental)
57
110
  for (context::CDList<Node>::const_iterator fit =
58
117
           d_fmfRecFunctionsDefined->begin();
59
227
       fit != d_fmfRecFunctionsDefined->end();
60
       ++fit)
61
  {
62
220
    Node f = (*fit);
63
110
    Assert(d_fmfRecFunctionsAbs.find(f) != d_fmfRecFunctionsAbs.end());
64
220
    TypeNode ft = d_fmfRecFunctionsAbs[f];
65
110
    d_sorts[f] = ft;
66
    std::map<Node, std::vector<Node>>::iterator fcit =
67
110
        d_fmfRecFunctionsConcrete.find(f);
68
110
    Assert(fcit != d_fmfRecFunctionsConcrete.end());
69
663
    for (const Node& fcc : fcit->second)
70
    {
71
553
      d_input_arg_inj[f].push_back(fcc);
72
    }
73
  }
74
117
  process(assertionsToPreprocess);
75
  // must store new definitions (in case of incremental)
76
201
  for (const Node& f : d_funcs)
77
  {
78
84
    d_fmfRecFunctionsAbs[f] = d_sorts[f];
79
84
    d_fmfRecFunctionsConcrete[f].clear();
80
288
    for (const Node& fcc : d_input_arg_inj[f])
81
    {
82
204
      d_fmfRecFunctionsConcrete[f].push_back(fcc);
83
    }
84
84
    d_fmfRecFunctionsDefined->push_back(f);
85
  }
86
117
  return PreprocessingPassResult::NO_CONFLICT;
87
}
88
89
117
void FunDefFmf::process(AssertionPipeline* assertionsToPreprocess)
90
{
91
117
  const std::vector<Node>& assertions = assertionsToPreprocess->ref();
92
234
  std::vector<int> fd_assertions;
93
234
  std::map<int, Node> subs_head;
94
  // first pass : find defined functions, transform quantifiers
95
117
  NodeManager* nm = NodeManager::currentNM();
96
117
  SkolemManager* sm = nm->getSkolemManager();
97
574
  for (size_t i = 0, asize = assertions.size(); i < asize; i++)
98
  {
99
914
    Node n = QuantAttributes::getFunDefHead(assertions[i]);
100
457
    if (!n.isNull())
101
    {
102
84
      Assert(n.getKind() == APPLY_UF);
103
168
      Node f = n.getOperator();
104
105
      // check if already defined, if so, throw error
106
84
      if (d_sorts.find(f) != d_sorts.end())
107
      {
108
        Unhandled() << "Cannot define function " << f << " more than once.";
109
      }
110
111
168
      Node bd = QuantAttributes::getFunDefBody(assertions[i]);
112
168
      Trace("fmf-fun-def-debug")
113
84
          << "Process function " << n << ", body = " << bd << std::endl;
114
84
      if (!bd.isNull())
115
      {
116
84
        d_funcs.push_back(f);
117
84
        bd = nm->mkNode(EQUAL, n, bd);
118
119
        // create a sort S that represents the inputs of the function
120
168
        std::stringstream ss;
121
84
        ss << "I_" << f;
122
168
        TypeNode iType = nm->mkSort(ss.str());
123
        AbsTypeFunDefAttribute atfda;
124
84
        iType.setAttribute(atfda, true);
125
84
        d_sorts[f] = iType;
126
127
        // create functions f1...fn mapping from this sort to concrete elements
128
84
        size_t nchildn = n.getNumChildren();
129
288
        for (size_t j = 0; j < nchildn; j++)
130
        {
131
408
          TypeNode typ = nm->mkFunctionType(iType, n[j].getType());
132
408
          std::stringstream ssf;
133
204
          ssf << f << "_arg_" << j;
134
612
          d_input_arg_inj[f].push_back(sm->mkDummySkolem(
135
408
              ssf.str(), typ, "op created during fun def fmf"));
136
        }
137
138
        // construct new quantifier forall S. F[f1(S)/x1....fn(S)/xn]
139
168
        std::vector<Node> children;
140
168
        Node bv = nm->mkBoundVar("?i", iType);
141
168
        Node bvl = nm->mkNode(BOUND_VAR_LIST, bv);
142
168
        std::vector<Node> subs;
143
168
        std::vector<Node> vars;
144
288
        for (size_t j = 0; j < nchildn; j++)
145
        {
146
204
          vars.push_back(n[j]);
147
204
          subs.push_back(nm->mkNode(APPLY_UF, d_input_arg_inj[f][j], bv));
148
        }
149
84
        bd = bd.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
150
84
        subs_head[i] =
151
168
            n.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
152
153
168
        Trace("fmf-fun-def")
154
84
            << "FMF fun def: FUNCTION : rewrite " << assertions[i] << std::endl;
155
84
        Trace("fmf-fun-def") << "  to " << std::endl;
156
168
        Node new_q = nm->mkNode(FORALL, bvl, bd);
157
84
        new_q = Rewriter::rewrite(new_q);
158
84
        assertionsToPreprocess->replace(i, new_q);
159
84
        Trace("fmf-fun-def") << "  " << assertions[i] << std::endl;
160
84
        fd_assertions.push_back(i);
161
      }
162
      else
163
      {
164
        // can be, e.g. in corner cases forall x. f(x)=f(x), forall x.
165
        // f(x)=f(x)+1
166
      }
167
    }
168
  }
169
  // second pass : rewrite assertions
170
234
  std::map<int, std::map<Node, Node>> visited;
171
234
  std::map<int, std::map<Node, Node>> visited_cons;
172
574
  for (size_t i = 0, asize = assertions.size(); i < asize; i++)
173
  {
174
914
    bool is_fd = std::find(fd_assertions.begin(), fd_assertions.end(), i)
175
1371
                 != fd_assertions.end();
176
914
    std::vector<Node> constraints;
177
914
    Trace("fmf-fun-def-rewrite")
178
457
        << "Rewriting " << assertions[i]
179
457
        << ", is function definition = " << is_fd << std::endl;
180
457
    Node n = simplifyFormula(assertions[i],
181
                             true,
182
                             true,
183
                             constraints,
184
914
                             is_fd ? subs_head[i] : Node::null(),
185
                             is_fd,
186
                             visited,
187
1828
                             visited_cons);
188
457
    Assert(constraints.empty());
189
457
    if (n != assertions[i])
190
    {
191
118
      n = Rewriter::rewrite(n);
192
236
      Trace("fmf-fun-def-rewrite")
193
118
          << "FMF fun def : rewrite " << assertions[i] << std::endl;
194
118
      Trace("fmf-fun-def-rewrite") << "  to " << std::endl;
195
118
      Trace("fmf-fun-def-rewrite") << "  " << n << std::endl;
196
118
      assertionsToPreprocess->replace(i, n);
197
    }
198
  }
199
117
}
200
201
1326
Node FunDefFmf::simplifyFormula(
202
    Node n,
203
    bool pol,
204
    bool hasPol,
205
    std::vector<Node>& constraints,
206
    Node hd,
207
    bool is_fun_def,
208
    std::map<int, std::map<Node, Node>>& visited,
209
    std::map<int, std::map<Node, Node>>& visited_cons)
210
{
211
1326
  Assert(constraints.empty());
212
1326
  int index = (is_fun_def ? 1 : 0) + 2 * (hasPol ? (pol ? 1 : -1) : 0);
213
1326
  std::map<Node, Node>::iterator itv = visited[index].find(n);
214
1326
  if (itv != visited[index].end())
215
  {
216
    // constraints.insert( visited_cons[index]
217
138
    std::map<Node, Node>::iterator itvc = visited_cons[index].find(n);
218
138
    if (itvc != visited_cons[index].end())
219
    {
220
      constraints.push_back(itvc->second);
221
    }
222
138
    return itv->second;
223
  }
224
1188
  NodeManager* nm = NodeManager::currentNM();
225
2376
  Node ret;
226
2376
  Trace("fmf-fun-def-debug2") << "Simplify " << n << " " << pol << " " << hasPol
227
1188
                              << " " << is_fun_def << std::endl;
228
1188
  if (n.getKind() == FORALL)
229
  {
230
    Node c = simplifyFormula(
231
208
        n[1], pol, hasPol, constraints, hd, is_fun_def, visited, visited_cons);
232
    // append prenex to constraints
233
104
    for (unsigned i = 0; i < constraints.size(); i++)
234
    {
235
      constraints[i] = nm->mkNode(FORALL, n[0], constraints[i]);
236
      constraints[i] = Rewriter::rewrite(constraints[i]);
237
    }
238
104
    if (c != n[1])
239
    {
240
49
      ret = nm->mkNode(FORALL, n[0], c);
241
    }
242
    else
243
    {
244
55
      ret = n;
245
    }
246
  }
247
  else
248
  {
249
2168
    Node nn = n;
250
1084
    bool isBool = n.getType().isBoolean();
251
1084
    if (isBool && n.getKind() != APPLY_UF)
252
    {
253
1190
      std::vector<Node> children;
254
595
      bool childChanged = false;
255
      // are we at a branch position (not all children are necessarily
256
      // relevant)?
257
      bool branch_pos =
258
595
          (n.getKind() == ITE || n.getKind() == OR || n.getKind() == AND);
259
1190
      std::vector<Node> branch_constraints;
260
1443
      for (unsigned i = 0; i < n.getNumChildren(); i++)
261
      {
262
1696
        Node c = n[i];
263
        // do not process LHS of definition
264
848
        if (!is_fun_def || c != hd)
265
        {
266
          bool newHasPol;
267
          bool newPol;
268
765
          QuantPhaseReq::getPolarity(n, i, hasPol, pol, newHasPol, newPol);
269
          // get child constraints
270
1530
          std::vector<Node> cconstraints;
271
765
          c = simplifyFormula(n[i],
272
                              newPol,
273
                              newHasPol,
274
                              cconstraints,
275
                              hd,
276
                              false,
277
                              visited,
278
                              visited_cons);
279
765
          if (branch_pos)
280
          {
281
            // if at a branching position, the other constraints don't matter
282
            // if this is satisfied
283
300
            Node bcons = nm->mkAnd(cconstraints);
284
150
            branch_constraints.push_back(bcons);
285
300
            Trace("fmf-fun-def-debug2") << "Branching constraint at arg " << i
286
150
                                        << " is " << bcons << std::endl;
287
          }
288
765
          constraints.insert(
289
1530
              constraints.end(), cconstraints.begin(), cconstraints.end());
290
        }
291
848
        children.push_back(c);
292
848
        childChanged = c != n[i] || childChanged;
293
      }
294
595
      if (childChanged)
295
      {
296
25
        nn = nm->mkNode(n.getKind(), children);
297
      }
298
595
      if (branch_pos && !constraints.empty())
299
      {
300
        // if we are at a branching position in the formula, we can
301
        // minimize recursive constraints on recursively defined predicates if
302
        // we know one child forces the overall evaluation of this formula.
303
32
        Node branch_cond;
304
16
        if (n.getKind() == ITE)
305
        {
306
          // always care about constraints on the head of the ITE, but only
307
          // care about one of the children depending on how it evaluates
308
          branch_cond = nm->mkNode(
309
              AND,
310
              branch_constraints[0],
311
              nm->mkNode(
312
                  ITE, n[0], branch_constraints[1], branch_constraints[2]));
313
        }
314
        else
315
        {
316
          // in the default case, we care about all conditions
317
16
          branch_cond = nm->mkAnd(constraints);
318
67
          for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
319
          {
320
            // if this child holds with forcing polarity (true child of OR or
321
            // false child of AND), then we only care about its associated
322
            // recursive conditions
323
204
            branch_cond = nm->mkNode(ITE,
324
102
                                     (n.getKind() == OR ? n[i] : n[i].negate()),
325
51
                                     branch_constraints[i],
326
                                     branch_cond);
327
          }
328
        }
329
32
        Trace("fmf-fun-def-debug2")
330
16
            << "Made branching condition " << branch_cond << std::endl;
331
16
        constraints.clear();
332
16
        constraints.push_back(branch_cond);
333
      }
334
    }
335
    else
336
    {
337
      // simplify term
338
978
      std::map<Node, Node> visitedT;
339
489
      getConstraints(n, constraints, visitedT);
340
    }
341
1084
    if (!constraints.empty() && isBool && hasPol)
342
    {
343
      // conjoin with current
344
238
      Node cons = nm->mkAnd(constraints);
345
119
      if (pol)
346
      {
347
101
        ret = nm->mkNode(AND, nn, cons);
348
      }
349
      else
350
      {
351
18
        ret = nm->mkNode(OR, nn, cons.negate());
352
      }
353
238
      Trace("fmf-fun-def-debug2")
354
119
          << "Add constraint to obtain " << ret << std::endl;
355
119
      constraints.clear();
356
    }
357
    else
358
    {
359
965
      ret = nn;
360
    }
361
  }
362
1188
  if (!constraints.empty())
363
  {
364
264
    Node cons;
365
    // flatten to AND node for the purposes of caching
366
132
    if (constraints.size() > 1)
367
    {
368
      cons = nm->mkNode(AND, constraints);
369
      cons = Rewriter::rewrite(cons);
370
      constraints.clear();
371
      constraints.push_back(cons);
372
    }
373
    else
374
    {
375
132
      cons = constraints[0];
376
    }
377
132
    visited_cons[index][n] = cons;
378
132
    Assert(constraints.size() == 1 && constraints[0] == cons);
379
  }
380
1188
  visited[index][n] = ret;
381
1188
  return ret;
382
}
383
384
2396
void FunDefFmf::getConstraints(Node n,
385
                               std::vector<Node>& constraints,
386
                               std::map<Node, Node>& visited)
387
{
388
2396
  std::map<Node, Node>::iterator itv = visited.find(n);
389
2396
  if (itv != visited.end())
390
  {
391
    // already visited
392
447
    if (!itv->second.isNull())
393
    {
394
      // add the cached constraint if it does not already occur
395
660
      if (std::find(constraints.begin(), constraints.end(), itv->second)
396
660
          == constraints.end())
397
      {
398
220
        constraints.push_back(itv->second);
399
      }
400
    }
401
447
    return;
402
  }
403
1949
  visited[n] = Node::null();
404
3898
  std::vector<Node> currConstraints;
405
1949
  NodeManager* nm = NodeManager::currentNM();
406
1949
  if (n.getKind() == ITE)
407
  {
408
    // collect constraints for the condition
409
75
    getConstraints(n[0], currConstraints, visited);
410
    // collect constraints for each branch
411
150
    Node cs[2];
412
225
    for (unsigned i = 0; i < 2; i++)
413
    {
414
300
      std::vector<Node> ccons;
415
150
      getConstraints(n[i + 1], ccons, visited);
416
150
      cs[i] = nm->mkAnd(ccons);
417
    }
418
75
    if (!cs[0].isConst() || !cs[1].isConst())
419
    {
420
72
      Node itec = nm->mkNode(ITE, n[0], cs[0], cs[1]);
421
36
      currConstraints.push_back(itec);
422
72
      Trace("fmf-fun-def-debug")
423
36
          << "---> add constraint " << itec << " for " << n << std::endl;
424
    }
425
  }
426
  else
427
  {
428
1874
    if (n.getKind() == APPLY_UF)
429
    {
430
      // check if f is defined, if so, we must enforce domain constraints for
431
      // this f-application
432
550
      Node f = n.getOperator();
433
275
      std::map<Node, TypeNode>::iterator it = d_sorts.find(f);
434
275
      if (it != d_sorts.end())
435
      {
436
        // create existential
437
248
        Node z = nm->mkBoundVar("?z", it->second);
438
248
        Node bvl = nm->mkNode(BOUND_VAR_LIST, z);
439
248
        std::vector<Node> children;
440
292
        for (unsigned j = 0, size = n.getNumChildren(); j < size; j++)
441
        {
442
336
          Node uz = nm->mkNode(APPLY_UF, d_input_arg_inj[f][j], z);
443
168
          children.push_back(uz.eqNode(n[j]));
444
        }
445
248
        Node bd = nm->mkAnd(children);
446
124
        bd = bd.negate();
447
248
        Node ex = nm->mkNode(FORALL, bvl, bd);
448
124
        ex = ex.negate();
449
124
        currConstraints.push_back(ex);
450
248
        Trace("fmf-fun-def-debug")
451
124
            << "---> add constraint " << ex << " for " << n << std::endl;
452
      }
453
    }
454
3336
    for (const Node& cn : n)
455
    {
456
1462
      getConstraints(cn, currConstraints, visited);
457
    }
458
  }
459
  // set the visited cache
460
1949
  if (!currConstraints.empty())
461
  {
462
440
    Node finalc = nm->mkAnd(currConstraints);
463
220
    visited[n] = finalc;
464
    // add to constraints
465
220
    getConstraints(n, constraints, visited);
466
  }
467
}
468
469
}  // namespace passes
470
}  // namespace preprocessing
471
28191
}  // namespace cvc5