GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/fun_def_fmf.cpp Lines: 219 231 94.8 %
Date: 2021-09-29 Branches: 462 970 47.6 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Haniel Barbosa, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Function definition processor for finite model finding.
14
 */
15
16
#include "preprocessing/passes/fun_def_fmf.h"
17
18
#include <sstream>
19
20
#include "expr/skolem_manager.h"
21
#include "options/smt_options.h"
22
#include "preprocessing/assertion_pipeline.h"
23
#include "preprocessing/preprocessing_pass_context.h"
24
#include "theory/quantifiers/quantifiers_attributes.h"
25
#include "theory/quantifiers/term_database.h"
26
#include "theory/quantifiers/term_util.h"
27
#include "theory/rewriter.h"
28
29
using namespace std;
30
using namespace cvc5::kind;
31
using namespace cvc5::theory;
32
using namespace cvc5::theory::quantifiers;
33
34
namespace cvc5 {
35
namespace preprocessing {
36
namespace passes {
37
38
6271
FunDefFmf::FunDefFmf(PreprocessingPassContext* preprocContext)
39
    : PreprocessingPass(preprocContext, "fun-def-fmf"),
40
6271
      d_fmfRecFunctionsDefined(nullptr)
41
{
42
6271
  d_fmfRecFunctionsDefined = new (true) NodeList(userContext());
43
6271
}
44
45
12536
FunDefFmf::~FunDefFmf() { d_fmfRecFunctionsDefined->deleteSelf(); }
46
47
73
PreprocessingPassResult FunDefFmf::applyInternal(
48
    AssertionPipeline* assertionsToPreprocess)
49
{
50
73
  Assert(d_fmfRecFunctionsDefined != nullptr);
51
  // reset
52
73
  d_sorts.clear();
53
73
  d_input_arg_inj.clear();
54
73
  d_funcs.clear();
55
  // must carry over current definitions (in case of incremental)
56
95
  for (context::CDList<Node>::const_iterator fit =
57
73
           d_fmfRecFunctionsDefined->begin();
58
168
       fit != d_fmfRecFunctionsDefined->end();
59
       ++fit)
60
  {
61
190
    Node f = (*fit);
62
95
    Assert(d_fmfRecFunctionsAbs.find(f) != d_fmfRecFunctionsAbs.end());
63
190
    TypeNode ft = d_fmfRecFunctionsAbs[f];
64
95
    d_sorts[f] = ft;
65
    std::map<Node, std::vector<Node>>::iterator fcit =
66
95
        d_fmfRecFunctionsConcrete.find(f);
67
95
    Assert(fcit != d_fmfRecFunctionsConcrete.end());
68
633
    for (const Node& fcc : fcit->second)
69
    {
70
538
      d_input_arg_inj[f].push_back(fcc);
71
    }
72
  }
73
73
  process(assertionsToPreprocess);
74
  // must store new definitions (in case of incremental)
75
131
  for (const Node& f : d_funcs)
76
  {
77
58
    d_fmfRecFunctionsAbs[f] = d_sorts[f];
78
58
    d_fmfRecFunctionsConcrete[f].clear();
79
233
    for (const Node& fcc : d_input_arg_inj[f])
80
    {
81
175
      d_fmfRecFunctionsConcrete[f].push_back(fcc);
82
    }
83
58
    d_fmfRecFunctionsDefined->push_back(f);
84
  }
85
73
  return PreprocessingPassResult::NO_CONFLICT;
86
}
87
88
73
void FunDefFmf::process(AssertionPipeline* assertionsToPreprocess)
89
{
90
73
  const std::vector<Node>& assertions = assertionsToPreprocess->ref();
91
146
  std::vector<int> fd_assertions;
92
146
  std::map<int, Node> subs_head;
93
  // first pass : find defined functions, transform quantifiers
94
73
  NodeManager* nm = NodeManager::currentNM();
95
73
  SkolemManager* sm = nm->getSkolemManager();
96
371
  for (size_t i = 0, asize = assertions.size(); i < asize; i++)
97
  {
98
596
    Node n = QuantAttributes::getFunDefHead(assertions[i]);
99
298
    if (!n.isNull())
100
    {
101
58
      Assert(n.getKind() == APPLY_UF);
102
116
      Node f = n.getOperator();
103
104
      // check if already defined, if so, throw error
105
58
      if (d_sorts.find(f) != d_sorts.end())
106
      {
107
        Unhandled() << "Cannot define function " << f << " more than once.";
108
      }
109
110
116
      Node bd = QuantAttributes::getFunDefBody(assertions[i]);
111
116
      Trace("fmf-fun-def-debug")
112
58
          << "Process function " << n << ", body = " << bd << std::endl;
113
58
      if (!bd.isNull())
114
      {
115
58
        d_funcs.push_back(f);
116
58
        bd = nm->mkNode(EQUAL, n, bd);
117
118
        // create a sort S that represents the inputs of the function
119
116
        std::stringstream ss;
120
58
        ss << "I_" << f;
121
116
        TypeNode iType = nm->mkSort(ss.str());
122
        AbsTypeFunDefAttribute atfda;
123
58
        iType.setAttribute(atfda, true);
124
58
        d_sorts[f] = iType;
125
126
        // create functions f1...fn mapping from this sort to concrete elements
127
58
        size_t nchildn = n.getNumChildren();
128
233
        for (size_t j = 0; j < nchildn; j++)
129
        {
130
350
          TypeNode typ = nm->mkFunctionType(iType, n[j].getType());
131
350
          std::stringstream ssf;
132
175
          ssf << f << "_arg_" << j;
133
525
          d_input_arg_inj[f].push_back(sm->mkDummySkolem(
134
350
              ssf.str(), typ, "op created during fun def fmf"));
135
        }
136
137
        // construct new quantifier forall S. F[f1(S)/x1....fn(S)/xn]
138
116
        std::vector<Node> children;
139
116
        Node bv = nm->mkBoundVar("?i", iType);
140
116
        Node bvl = nm->mkNode(BOUND_VAR_LIST, bv);
141
116
        std::vector<Node> subs;
142
116
        std::vector<Node> vars;
143
233
        for (size_t j = 0; j < nchildn; j++)
144
        {
145
175
          vars.push_back(n[j]);
146
175
          subs.push_back(nm->mkNode(APPLY_UF, d_input_arg_inj[f][j], bv));
147
        }
148
58
        bd = bd.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
149
58
        subs_head[i] =
150
116
            n.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
151
152
116
        Trace("fmf-fun-def")
153
58
            << "FMF fun def: FUNCTION : rewrite " << assertions[i] << std::endl;
154
58
        Trace("fmf-fun-def") << "  to " << std::endl;
155
116
        Node new_q = nm->mkNode(FORALL, bvl, bd);
156
58
        new_q = rewrite(new_q);
157
58
        assertionsToPreprocess->replace(i, new_q);
158
58
        Trace("fmf-fun-def") << "  " << assertions[i] << std::endl;
159
58
        fd_assertions.push_back(i);
160
      }
161
      else
162
      {
163
        // can be, e.g. in corner cases forall x. f(x)=f(x), forall x.
164
        // f(x)=f(x)+1
165
      }
166
    }
167
  }
168
  // second pass : rewrite assertions
169
146
  std::map<int, std::map<Node, Node>> visited;
170
146
  std::map<int, std::map<Node, Node>> visited_cons;
171
371
  for (size_t i = 0, asize = assertions.size(); i < asize; i++)
172
  {
173
596
    bool is_fd = std::find(fd_assertions.begin(), fd_assertions.end(), i)
174
894
                 != fd_assertions.end();
175
596
    std::vector<Node> constraints;
176
596
    Trace("fmf-fun-def-rewrite")
177
298
        << "Rewriting " << assertions[i]
178
298
        << ", is function definition = " << is_fd << std::endl;
179
298
    Node n = simplifyFormula(assertions[i],
180
                             true,
181
                             true,
182
                             constraints,
183
596
                             is_fd ? subs_head[i] : Node::null(),
184
                             is_fd,
185
                             visited,
186
1192
                             visited_cons);
187
298
    Assert(constraints.empty());
188
298
    if (n != assertions[i])
189
    {
190
58
      n = rewrite(n);
191
116
      Trace("fmf-fun-def-rewrite")
192
58
          << "FMF fun def : rewrite " << assertions[i] << std::endl;
193
58
      Trace("fmf-fun-def-rewrite") << "  to " << std::endl;
194
58
      Trace("fmf-fun-def-rewrite") << "  " << n << std::endl;
195
58
      assertionsToPreprocess->replace(i, n);
196
    }
197
  }
198
73
}
199
200
969
Node FunDefFmf::simplifyFormula(
201
    Node n,
202
    bool pol,
203
    bool hasPol,
204
    std::vector<Node>& constraints,
205
    Node hd,
206
    bool is_fun_def,
207
    std::map<int, std::map<Node, Node>>& visited,
208
    std::map<int, std::map<Node, Node>>& visited_cons)
209
{
210
969
  Assert(constraints.empty());
211
969
  int index = (is_fun_def ? 1 : 0) + 2 * (hasPol ? (pol ? 1 : -1) : 0);
212
969
  std::map<Node, Node>::iterator itv = visited[index].find(n);
213
969
  if (itv != visited[index].end())
214
  {
215
    // constraints.insert( visited_cons[index]
216
101
    std::map<Node, Node>::iterator itvc = visited_cons[index].find(n);
217
101
    if (itvc != visited_cons[index].end())
218
    {
219
      constraints.push_back(itvc->second);
220
    }
221
101
    return itv->second;
222
  }
223
868
  NodeManager* nm = NodeManager::currentNM();
224
1736
  Node ret;
225
1736
  Trace("fmf-fun-def-debug2") << "Simplify " << n << " " << pol << " " << hasPol
226
868
                              << " " << is_fun_def << std::endl;
227
868
  if (n.getKind() == FORALL)
228
  {
229
    Node c = simplifyFormula(
230
144
        n[1], pol, hasPol, constraints, hd, is_fun_def, visited, visited_cons);
231
    // append prenex to constraints
232
72
    for (unsigned i = 0; i < constraints.size(); i++)
233
    {
234
      constraints[i] = nm->mkNode(FORALL, n[0], constraints[i]);
235
      constraints[i] = rewrite(constraints[i]);
236
    }
237
72
    if (c != n[1])
238
    {
239
27
      ret = nm->mkNode(FORALL, n[0], c);
240
    }
241
    else
242
    {
243
45
      ret = n;
244
    }
245
  }
246
  else
247
  {
248
1592
    Node nn = n;
249
796
    bool isBool = n.getType().isBoolean();
250
796
    if (isBool && n.getKind() != APPLY_UF)
251
    {
252
860
      std::vector<Node> children;
253
430
      bool childChanged = false;
254
      // are we at a branch position (not all children are necessarily
255
      // relevant)?
256
      bool branch_pos =
257
430
          (n.getKind() == ITE || n.getKind() == OR || n.getKind() == AND);
258
860
      std::vector<Node> branch_constraints;
259
1086
      for (unsigned i = 0; i < n.getNumChildren(); i++)
260
      {
261
1312
        Node c = n[i];
262
        // do not process LHS of definition
263
656
        if (!is_fun_def || c != hd)
264
        {
265
          bool newHasPol;
266
          bool newPol;
267
599
          QuantPhaseReq::getPolarity(n, i, hasPol, pol, newHasPol, newPol);
268
          // get child constraints
269
1198
          std::vector<Node> cconstraints;
270
599
          c = simplifyFormula(n[i],
271
                              newPol,
272
                              newHasPol,
273
                              cconstraints,
274
                              hd,
275
                              false,
276
                              visited,
277
                              visited_cons);
278
599
          if (branch_pos)
279
          {
280
            // if at a branching position, the other constraints don't matter
281
            // if this is satisfied
282
284
            Node bcons = nm->mkAnd(cconstraints);
283
142
            branch_constraints.push_back(bcons);
284
284
            Trace("fmf-fun-def-debug2") << "Branching constraint at arg " << i
285
142
                                        << " is " << bcons << std::endl;
286
          }
287
599
          constraints.insert(
288
1198
              constraints.end(), cconstraints.begin(), cconstraints.end());
289
        }
290
656
        children.push_back(c);
291
656
        childChanged = c != n[i] || childChanged;
292
      }
293
430
      if (childChanged)
294
      {
295
11
        nn = nm->mkNode(n.getKind(), children);
296
      }
297
430
      if (branch_pos && !constraints.empty())
298
      {
299
        // if we are at a branching position in the formula, we can
300
        // minimize recursive constraints on recursively defined predicates if
301
        // we know one child forces the overall evaluation of this formula.
302
32
        Node branch_cond;
303
16
        if (n.getKind() == ITE)
304
        {
305
          // always care about constraints on the head of the ITE, but only
306
          // care about one of the children depending on how it evaluates
307
          branch_cond = nm->mkNode(
308
              AND,
309
              branch_constraints[0],
310
              nm->mkNode(
311
                  ITE, n[0], branch_constraints[1], branch_constraints[2]));
312
        }
313
        else
314
        {
315
          // in the default case, we care about all conditions
316
16
          branch_cond = nm->mkAnd(constraints);
317
67
          for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
318
          {
319
            // if this child holds with forcing polarity (true child of OR or
320
            // false child of AND), then we only care about its associated
321
            // recursive conditions
322
204
            branch_cond = nm->mkNode(ITE,
323
102
                                     (n.getKind() == OR ? n[i] : n[i].negate()),
324
51
                                     branch_constraints[i],
325
                                     branch_cond);
326
          }
327
        }
328
32
        Trace("fmf-fun-def-debug2")
329
16
            << "Made branching condition " << branch_cond << std::endl;
330
16
        constraints.clear();
331
16
        constraints.push_back(branch_cond);
332
      }
333
    }
334
    else
335
    {
336
      // simplify term
337
732
      std::map<Node, Node> visitedT;
338
366
      getConstraints(n, constraints, visitedT);
339
    }
340
796
    if (!constraints.empty() && isBool && hasPol)
341
    {
342
      // conjoin with current
343
118
      Node cons = nm->mkAnd(constraints);
344
59
      if (pol)
345
      {
346
51
        ret = nm->mkNode(AND, nn, cons);
347
      }
348
      else
349
      {
350
8
        ret = nm->mkNode(OR, nn, cons.negate());
351
      }
352
118
      Trace("fmf-fun-def-debug2")
353
59
          << "Add constraint to obtain " << ret << std::endl;
354
59
      constraints.clear();
355
    }
356
    else
357
    {
358
737
      ret = nn;
359
    }
360
  }
361
868
  if (!constraints.empty())
362
  {
363
136
    Node cons;
364
    // flatten to AND node for the purposes of caching
365
68
    if (constraints.size() > 1)
366
    {
367
      cons = nm->mkNode(AND, constraints);
368
      cons = rewrite(cons);
369
      constraints.clear();
370
      constraints.push_back(cons);
371
    }
372
    else
373
    {
374
68
      cons = constraints[0];
375
    }
376
68
    visited_cons[index][n] = cons;
377
68
    Assert(constraints.size() == 1 && constraints[0] == cons);
378
  }
379
868
  visited[index][n] = ret;
380
868
  return ret;
381
}
382
383
1779
void FunDefFmf::getConstraints(Node n,
384
                               std::vector<Node>& constraints,
385
                               std::map<Node, Node>& visited)
386
{
387
1779
  std::map<Node, Node>::iterator itv = visited.find(n);
388
1779
  if (itv != visited.end())
389
  {
390
    // already visited
391
242
    if (!itv->second.isNull())
392
    {
393
      // add the cached constraint if it does not already occur
394
291
      if (std::find(constraints.begin(), constraints.end(), itv->second)
395
291
          == constraints.end())
396
      {
397
97
        constraints.push_back(itv->second);
398
      }
399
    }
400
242
    return;
401
  }
402
1537
  visited[n] = Node::null();
403
3074
  std::vector<Node> currConstraints;
404
1537
  NodeManager* nm = NodeManager::currentNM();
405
1537
  if (n.getKind() == ITE)
406
  {
407
    // collect constraints for the condition
408
41
    getConstraints(n[0], currConstraints, visited);
409
    // collect constraints for each branch
410
82
    Node cs[2];
411
123
    for (unsigned i = 0; i < 2; i++)
412
    {
413
164
      std::vector<Node> ccons;
414
82
      getConstraints(n[i + 1], ccons, visited);
415
82
      cs[i] = nm->mkAnd(ccons);
416
    }
417
41
    if (!cs[0].isConst() || !cs[1].isConst())
418
    {
419
28
      Node itec = nm->mkNode(ITE, n[0], cs[0], cs[1]);
420
14
      currConstraints.push_back(itec);
421
28
      Trace("fmf-fun-def-debug")
422
14
          << "---> add constraint " << itec << " for " << n << std::endl;
423
    }
424
  }
425
  else
426
  {
427
1496
    if (n.getKind() == APPLY_UF)
428
    {
429
      // check if f is defined, if so, we must enforce domain constraints for
430
      // this f-application
431
344
      Node f = n.getOperator();
432
172
      std::map<Node, TypeNode>::iterator it = d_sorts.find(f);
433
172
      if (it != d_sorts.end())
434
      {
435
        // create existential
436
122
        Node z = nm->mkBoundVar("?z", it->second);
437
122
        Node bvl = nm->mkNode(BOUND_VAR_LIST, z);
438
122
        std::vector<Node> children;
439
154
        for (unsigned j = 0, size = n.getNumChildren(); j < size; j++)
440
        {
441
186
          Node uz = nm->mkNode(APPLY_UF, d_input_arg_inj[f][j], z);
442
93
          children.push_back(uz.eqNode(n[j]));
443
        }
444
122
        Node bd = nm->mkAnd(children);
445
61
        bd = bd.negate();
446
122
        Node ex = nm->mkNode(FORALL, bvl, bd);
447
61
        ex = ex.negate();
448
61
        currConstraints.push_back(ex);
449
122
        Trace("fmf-fun-def-debug")
450
61
            << "---> add constraint " << ex << " for " << n << std::endl;
451
      }
452
    }
453
2689
    for (const Node& cn : n)
454
    {
455
1193
      getConstraints(cn, currConstraints, visited);
456
    }
457
  }
458
  // set the visited cache
459
1537
  if (!currConstraints.empty())
460
  {
461
194
    Node finalc = nm->mkAnd(currConstraints);
462
97
    visited[n] = finalc;
463
    // add to constraints
464
97
    getConstraints(n, constraints, visited);
465
  }
466
}
467
468
}  // namespace passes
469
}  // namespace preprocessing
470
22746
}  // namespace cvc5