GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/fun_def_fmf.cpp Lines: 219 231 94.8 %
Date: 2021-09-10 Branches: 462 970 47.6 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Haniel Barbosa, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Function definition processor for finite model finding.
14
 */
15
16
#include "preprocessing/passes/fun_def_fmf.h"
17
18
#include <sstream>
19
20
#include "expr/skolem_manager.h"
21
#include "options/smt_options.h"
22
#include "preprocessing/assertion_pipeline.h"
23
#include "preprocessing/preprocessing_pass_context.h"
24
#include "theory/quantifiers/quantifiers_attributes.h"
25
#include "theory/quantifiers/term_database.h"
26
#include "theory/quantifiers/term_util.h"
27
#include "theory/rewriter.h"
28
29
using namespace std;
30
using namespace cvc5::kind;
31
using namespace cvc5::theory;
32
using namespace cvc5::theory::quantifiers;
33
34
namespace cvc5 {
35
namespace preprocessing {
36
namespace passes {
37
38
9913
/**
 * Creates the "fun-def-fmf" preprocessing pass.
 *
 * The list of functions defined by earlier invocations of this pass is built
 * over the user context and is owned by this object: it is allocated here with
 * `new (true)` and released in the destructor via deleteSelf().
 */
FunDefFmf::FunDefFmf(PreprocessingPassContext* preprocContext)
    : PreprocessingPass(preprocContext, "fun-def-fmf"),
      d_fmfRecFunctionsDefined(new (true) NodeList(userContext()))
{
}

/** Releases the list of defined functions allocated by the constructor. */
FunDefFmf::~FunDefFmf() { d_fmfRecFunctionsDefined->deleteSelf(); }
46
47
121
/**
 * Entry point of the pass.
 *
 * The per-invocation state (d_sorts, d_input_arg_inj, d_funcs) is reset and
 * then re-seeded with the input sort and argument functions of every function
 * recorded in d_fmfRecFunctionsDefined, so that definitions made by earlier
 * invocations (in case of incremental) are still recognized. After
 * process() runs, the sort and argument functions of each newly defined
 * function are saved and the function is appended to
 * d_fmfRecFunctionsDefined.
 *
 * @param assertionsToPreprocess the assertions, transformed in place
 * @return PreprocessingPassResult::NO_CONFLICT (the only value returned here)
 */
PreprocessingPassResult FunDefFmf::applyInternal(
    AssertionPipeline* assertionsToPreprocess)
{
  Assert(d_fmfRecFunctionsDefined != nullptr);
  // start this invocation from a clean slate
  d_funcs.clear();
  d_input_arg_inj.clear();
  d_sorts.clear();

  // re-seed with definitions made by previous invocations
  for (const Node& f : *d_fmfRecFunctionsDefined)
  {
    Assert(d_fmfRecFunctionsAbs.find(f) != d_fmfRecFunctionsAbs.end());
    d_sorts[f] = d_fmfRecFunctionsAbs[f];
    const auto concreteIt = d_fmfRecFunctionsConcrete.find(f);
    Assert(concreteIt != d_fmfRecFunctionsConcrete.end());
    std::vector<Node>& argFuns = d_input_arg_inj[f];
    argFuns.insert(
        argFuns.end(), concreteIt->second.begin(), concreteIt->second.end());
  }

  process(assertionsToPreprocess);

  // remember the definitions made by this invocation for later ones
  for (const Node& f : d_funcs)
  {
    d_fmfRecFunctionsAbs[f] = d_sorts[f];
    d_fmfRecFunctionsConcrete[f] = d_input_arg_inj[f];
    d_fmfRecFunctionsDefined->push_back(f);
  }
  return PreprocessingPassResult::NO_CONFLICT;
}
87
88
121
/**
 * Transforms the function definitions among the assertions and then adds
 * domain constraints for applications of the defined functions.
 *
 * First pass: each assertion whose head f(x1, ..., xn) and body are reported
 * by QuantAttributes::getFunDefHead / getFunDefBody is processed as follows.
 * - It introduces a fresh sort I_f (marked with AbsTypeFunDefAttribute).
 * - It introduces one skolem f_arg_j : I_f -> T_j per argument.
 * - It replaces the definition by the rewritten quantifier
 *     forall ?i : I_f. (f(x1..xn) = body)[f_arg_1(?i)/x1, ..., f_arg_n(?i)/xn]
 * Defining the same function twice is reported via Unhandled().
 *
 * Second pass: every assertion is passed through simplifyFormula, and any
 * changed result is rewritten and stored back.
 *
 * @param assertionsToPreprocess the assertions, modified in place
 */
void FunDefFmf::process(AssertionPipeline* assertionsToPreprocess)
{
  const std::vector<Node>& assertions = assertionsToPreprocess->ref();
  // Maps the index of every transformed function-definition assertion to the
  // head of its new definition, f(f_arg_1(?i), ..., f_arg_n(?i)). An index is
  // a key of this map exactly when that assertion is a processed function
  // definition, so this map also serves as the set of such indices (a log-time
  // lookup instead of a linear search through a separate list of ints).
  std::map<size_t, Node> subs_head;
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();
  // first pass : find defined functions, transform quantifiers
  for (size_t i = 0, asize = assertions.size(); i < asize; i++)
  {
    Node n = QuantAttributes::getFunDefHead(assertions[i]);
    if (n.isNull())
    {
      continue;
    }
    Assert(n.getKind() == APPLY_UF);
    Node f = n.getOperator();

    // check if already defined, if so, throw error
    if (d_sorts.find(f) != d_sorts.end())
    {
      Unhandled() << "Cannot define function " << f << " more than once.";
    }

    Node bd = QuantAttributes::getFunDefBody(assertions[i]);
    Trace("fmf-fun-def-debug")
        << "Process function " << n << ", body = " << bd << std::endl;
    if (bd.isNull())
    {
      // can be, e.g. in corner cases forall x. f(x)=f(x), forall x.
      // f(x)=f(x)+1
      continue;
    }
    d_funcs.push_back(f);
    bd = nm->mkNode(EQUAL, n, bd);

    // create a sort S that represents the inputs of the function
    std::stringstream ss;
    ss << "I_" << f;
    TypeNode iType = nm->mkSort(ss.str());
    AbsTypeFunDefAttribute atfda;
    iType.setAttribute(atfda, true);
    d_sorts[f] = iType;

    // create functions f1...fn mapping from this sort to concrete elements
    size_t nchildn = n.getNumChildren();
    for (size_t j = 0; j < nchildn; j++)
    {
      TypeNode typ = nm->mkFunctionType(iType, n[j].getType());
      std::stringstream ssf;
      ssf << f << "_arg_" << j;
      d_input_arg_inj[f].push_back(sm->mkDummySkolem(
          ssf.str(), typ, "op created during fun def fmf"));
    }

    // construct new quantifier forall S. F[f1(S)/x1....fn(S)/xn]
    Node bv = nm->mkBoundVar("?i", iType);
    Node bvl = nm->mkNode(BOUND_VAR_LIST, bv);
    std::vector<Node> subs;
    std::vector<Node> vars;
    subs.reserve(nchildn);
    vars.reserve(nchildn);
    for (size_t j = 0; j < nchildn; j++)
    {
      vars.push_back(n[j]);
      subs.push_back(nm->mkNode(APPLY_UF, d_input_arg_inj[f][j], bv));
    }
    bd = bd.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
    subs_head[i] =
        n.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());

    Trace("fmf-fun-def")
        << "FMF fun def: FUNCTION : rewrite " << assertions[i] << std::endl;
    Trace("fmf-fun-def") << "  to " << std::endl;
    Node new_q = nm->mkNode(FORALL, bvl, bd);
    new_q = rewrite(new_q);
    assertionsToPreprocess->replace(i, new_q);
    Trace("fmf-fun-def") << "  " << assertions[i] << std::endl;
  }
  // second pass : rewrite assertions
  std::map<int, std::map<Node, Node>> visited;
  std::map<int, std::map<Node, Node>> visited_cons;
  for (size_t i = 0, asize = assertions.size(); i < asize; i++)
  {
    std::map<size_t, Node>::const_iterator itsh = subs_head.find(i);
    bool is_fd = itsh != subs_head.end();
    std::vector<Node> constraints;
    Trace("fmf-fun-def-rewrite")
        << "Rewriting " << assertions[i]
        << ", is function definition = " << is_fd << std::endl;
    Node n = simplifyFormula(assertions[i],
                             true,
                             true,
                             constraints,
                             is_fd ? itsh->second : Node::null(),
                             is_fd,
                             visited,
                             visited_cons);
    Assert(constraints.empty());
    if (n != assertions[i])
    {
      n = rewrite(n);
      Trace("fmf-fun-def-rewrite")
          << "FMF fun def : rewrite " << assertions[i] << std::endl;
      Trace("fmf-fun-def-rewrite") << "  to " << std::endl;
      Trace("fmf-fun-def-rewrite") << "  " << n << std::endl;
      assertionsToPreprocess->replace(i, n);
    }
  }
}
199
200
1342
/**
 * Simplifies formula n, collecting the domain constraints that applications
 * of defined functions inside n give rise to (see getConstraints).
 *
 * A Boolean subformula with a known polarity (hasPol) absorbs the
 * constraints C collected below it. With positive polarity the result is
 * (and n' C); with negative polarity it is (or n' (not C)). At ITE/OR/AND
 * positions, the collected constraints are first reduced to a single
 * condition. That condition only requires the constraints of a child that
 * forces the value of the whole formula. Constraints that are not absorbed
 * are returned to the caller; under a FORALL they are re-quantified over its
 * binders.
 *
 * @param n the formula (or term) to simplify
 * @param pol the polarity of n; only meaningful if hasPol is true
 * @param hasPol whether n occurs with a fixed polarity
 * @param constraints out: the constraints not absorbed into the result. It
 *   must be empty on entry and holds at most one node on return.
 * @param hd if is_fun_def, the head of the definition; that child is not
 *   processed
 * @param is_fun_def whether n is (the body of) a function definition
 * @param visited cache of results, keyed by a context index (see below),
 *   then by n
 * @param visited_cons cache of returned constraints, keyed like visited
 * @return the simplified formula
 */
Node FunDefFmf::simplifyFormula(
    Node n,
    bool pol,
    bool hasPol,
    std::vector<Node>& constraints,
    Node hd,
    bool is_fun_def,
    std::map<int, std::map<Node, Node>>& visited,
    std::map<int, std::map<Node, Node>>& visited_cons)
{
  Assert(constraints.empty());
  // The cache index combines is_fun_def (adds 1) with the polarity context
  // (adds 2 for positive, -2 for negative, 0 for no polarity).
  int index = (is_fun_def ? 1 : 0) + 2 * (hasPol ? (pol ? 1 : -1) : 0);
  std::map<Node, Node>::iterator itv = visited[index].find(n);
  if (itv != visited[index].end())
  {
    // reuse the cached result together with its cached constraint, if any
    std::map<Node, Node>::iterator itvc = visited_cons[index].find(n);
    if (itvc != visited_cons[index].end())
    {
      constraints.push_back(itvc->second);
    }
    return itv->second;
  }
  NodeManager* nm = NodeManager::currentNM();
  Node ret;
  Trace("fmf-fun-def-debug2") << "Simplify " << n << " " << pol << " " << hasPol
                              << " " << is_fun_def << std::endl;
  if (n.getKind() == FORALL)
  {
    Node c = simplifyFormula(
        n[1], pol, hasPol, constraints, hd, is_fun_def, visited, visited_cons);
    // append prenex to constraints
    for (unsigned i = 0; i < constraints.size(); i++)
    {
      constraints[i] = nm->mkNode(FORALL, n[0], constraints[i]);
      constraints[i] = rewrite(constraints[i]);
    }
    if (c != n[1])
    {
      ret = nm->mkNode(FORALL, n[0], c);
    }
    else
    {
      ret = n;
    }
  }
  else
  {
    Node nn = n;
    bool isBool = n.getType().isBoolean();
    if (isBool && n.getKind() != APPLY_UF)
    {
      std::vector<Node> children;
      bool childChanged = false;
      // are we at a branch position (not all children are necessarily
      // relevant)?
      bool branch_pos =
          (n.getKind() == ITE || n.getKind() == OR || n.getKind() == AND);
      // Constraints of each child, indexed by child position. The code below
      // reads branch_constraints[i] for the i^th child of n, so when
      // branch_pos holds an entry must be recorded for every child, including
      // a skipped definition head.
      std::vector<Node> branch_constraints;
      for (unsigned i = 0; i < n.getNumChildren(); i++)
      {
        Node c = n[i];
        // do not process LHS of definition
        if (!is_fun_def || c != hd)
        {
          bool newHasPol = false;
          bool newPol = false;
          QuantPhaseReq::getPolarity(n, i, hasPol, pol, newHasPol, newPol);
          // get child constraints
          std::vector<Node> cconstraints;
          c = simplifyFormula(n[i],
                              newPol,
                              newHasPol,
                              cconstraints,
                              hd,
                              false,
                              visited,
                              visited_cons);
          if (branch_pos)
          {
            // if at a branching position, the other constraints don't matter
            // if this is satisfied
            Node bcons = nm->mkAnd(cconstraints);
            branch_constraints.push_back(bcons);
            Trace("fmf-fun-def-debug2") << "Branching constraint at arg " << i
                                        << " is " << bcons << std::endl;
          }
          constraints.insert(
              constraints.end(), cconstraints.begin(), cconstraints.end());
        }
        else if (branch_pos)
        {
          // The skipped head contributes no constraints. Record a trivial one
          // so that branch_constraints stays aligned with the children of n.
          // Otherwise later entries would shift, and branch_constraints[i]
          // below could read the wrong entry or past the end of the vector.
          branch_constraints.push_back(nm->mkConst(true));
        }
        children.push_back(c);
        childChanged = c != n[i] || childChanged;
      }
      if (childChanged)
      {
        nn = nm->mkNode(n.getKind(), children);
      }
      if (branch_pos && !constraints.empty())
      {
        // if we are at a branching position in the formula, we can
        // minimize recursive constraints on recursively defined predicates if
        // we know one child forces the overall evaluation of this formula.
        Node branch_cond;
        if (n.getKind() == ITE)
        {
          // always care about constraints on the head of the ITE, but only
          // care about one of the children depending on how it evaluates
          branch_cond = nm->mkNode(
              AND,
              branch_constraints[0],
              nm->mkNode(
                  ITE, n[0], branch_constraints[1], branch_constraints[2]));
        }
        else
        {
          // in the default case, we care about all conditions
          branch_cond = nm->mkAnd(constraints);
          for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
          {
            // if this child holds with forcing polarity (true child of OR or
            // false child of AND), then we only care about its associated
            // recursive conditions
            branch_cond = nm->mkNode(ITE,
                                     (n.getKind() == OR ? n[i] : n[i].negate()),
                                     branch_constraints[i],
                                     branch_cond);
          }
        }
        Trace("fmf-fun-def-debug2")
            << "Made branching condition " << branch_cond << std::endl;
        constraints.clear();
        constraints.push_back(branch_cond);
      }
    }
    else
    {
      // simplify term
      std::map<Node, Node> visitedT;
      getConstraints(n, constraints, visitedT);
    }
    if (!constraints.empty() && isBool && hasPol)
    {
      // conjoin with current
      Node cons = nm->mkAnd(constraints);
      if (pol)
      {
        ret = nm->mkNode(AND, nn, cons);
      }
      else
      {
        ret = nm->mkNode(OR, nn, cons.negate());
      }
      Trace("fmf-fun-def-debug2")
          << "Add constraint to obtain " << ret << std::endl;
      constraints.clear();
    }
    else
    {
      ret = nn;
    }
  }
  if (!constraints.empty())
  {
    Node cons;
    // flatten to AND node for the purposes of caching
    if (constraints.size() > 1)
    {
      cons = nm->mkNode(AND, constraints);
      cons = rewrite(cons);
      constraints.clear();
      constraints.push_back(cons);
    }
    else
    {
      cons = constraints[0];
    }
    visited_cons[index][n] = cons;
    Assert(constraints.size() == 1 && constraints[0] == cons);
  }
  visited[index][n] = ret;
  return ret;
}
382
383
2422
void FunDefFmf::getConstraints(Node n,
384
                               std::vector<Node>& constraints,
385
                               std::map<Node, Node>& visited)
386
{
387
2422
  std::map<Node, Node>::iterator itv = visited.find(n);
388
2422
  if (itv != visited.end())
389
  {
390
    // already visited
391
451
    if (!itv->second.isNull())
392
    {
393
      // add the cached constraint if it does not already occur
394
660
      if (std::find(constraints.begin(), constraints.end(), itv->second)
395
660
          == constraints.end())
396
      {
397
220
        constraints.push_back(itv->second);
398
      }
399
    }
400
451
    return;
401
  }
402
1971
  visited[n] = Node::null();
403
3942
  std::vector<Node> currConstraints;
404
1971
  NodeManager* nm = NodeManager::currentNM();
405
1971
  if (n.getKind() == ITE)
406
  {
407
    // collect constraints for the condition
408
75
    getConstraints(n[0], currConstraints, visited);
409
    // collect constraints for each branch
410
150
    Node cs[2];
411
225
    for (unsigned i = 0; i < 2; i++)
412
    {
413
300
      std::vector<Node> ccons;
414
150
      getConstraints(n[i + 1], ccons, visited);
415
150
      cs[i] = nm->mkAnd(ccons);
416
    }
417
75
    if (!cs[0].isConst() || !cs[1].isConst())
418
    {
419
72
      Node itec = nm->mkNode(ITE, n[0], cs[0], cs[1]);
420
36
      currConstraints.push_back(itec);
421
72
      Trace("fmf-fun-def-debug")
422
36
          << "---> add constraint " << itec << " for " << n << std::endl;
423
    }
424
  }
425
  else
426
  {
427
1896
    if (n.getKind() == APPLY_UF)
428
    {
429
      // check if f is defined, if so, we must enforce domain constraints for
430
      // this f-application
431
550
      Node f = n.getOperator();
432
275
      std::map<Node, TypeNode>::iterator it = d_sorts.find(f);
433
275
      if (it != d_sorts.end())
434
      {
435
        // create existential
436
248
        Node z = nm->mkBoundVar("?z", it->second);
437
248
        Node bvl = nm->mkNode(BOUND_VAR_LIST, z);
438
248
        std::vector<Node> children;
439
292
        for (unsigned j = 0, size = n.getNumChildren(); j < size; j++)
440
        {
441
336
          Node uz = nm->mkNode(APPLY_UF, d_input_arg_inj[f][j], z);
442
168
          children.push_back(uz.eqNode(n[j]));
443
        }
444
248
        Node bd = nm->mkAnd(children);
445
124
        bd = bd.negate();
446
248
        Node ex = nm->mkNode(FORALL, bvl, bd);
447
124
        ex = ex.negate();
448
124
        currConstraints.push_back(ex);
449
248
        Trace("fmf-fun-def-debug")
450
124
            << "---> add constraint " << ex << " for " << n << std::endl;
451
      }
452
    }
453
3376
    for (const Node& cn : n)
454
    {
455
1480
      getConstraints(cn, currConstraints, visited);
456
    }
457
  }
458
  // set the visited cache
459
1971
  if (!currConstraints.empty())
460
  {
461
440
    Node finalc = nm->mkAnd(currConstraints);
462
220
    visited[n] = finalc;
463
    // add to constraints
464
220
    getConstraints(n, constraints, visited);
465
  }
466
}
467
468
}  // namespace passes
469
}  // namespace preprocessing
470
29502
}  // namespace cvc5