GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/ho_elim.cpp Lines: 323 330 97.9 %
Date: 2021-08-17 Branches: 700 1518 46.1 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * The HoElim preprocessing pass.
14
 *
15
 * Eliminates higher-order constraints.
16
 */
17
18
#include "preprocessing/passes/ho_elim.h"
19
20
#include <sstream>
21
22
#include "expr/node_algorithm.h"
23
#include "expr/skolem_manager.h"
24
#include "options/quantifiers_options.h"
25
#include "preprocessing/assertion_pipeline.h"
26
#include "theory/rewriter.h"
27
#include "theory/uf/theory_uf_rewriter.h"
28
29
using namespace cvc5::kind;
30
31
namespace cvc5 {
32
namespace preprocessing {
33
namespace passes {
34
35
9850
/**
 * Construct the pass and register it under the name "ho-elim".
 *
 * @param preprocContext the preprocessing context handed to the base class
 */
HoElim::HoElim(PreprocessingPassContext* preprocContext)
    : PreprocessingPass(preprocContext, "ho-elim")
{
}
37
38
4542
/**
 * Lambda-lift every LAMBDA term reachable from n without entering a lambda.
 *
 * Each LAMBDA is replaced by a fresh dummy skolem (prefix "ll") that is
 * partially applied, via HO_APPLY, to the lambda's free variables. The
 * skolem's definition is recorded in newLambda. That definition is a lambda
 * whose parameters are fresh bound variables for those free variables,
 * followed by the original bound variables. Lifted bodies are not traversed
 * here; the caller processes the definitions collected in newLambda.
 *
 * The traversal is iterative and post-order. Results are cached in
 * d_visited, where a null value marks a node whose children have been
 * scheduled but which has not been rebuilt yet.
 *
 * @param n the node to process
 * @param newLambda receives (lifted skolem, lambda definition) pairs
 * @return n with its lambdas replaced by lifted applications
 */
Node HoElim::eliminateLambdaComplete(Node n, std::map<Node, Node>& newLambda)
{
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();
  std::vector<Node> pending{n};
  while (!pending.empty())
  {
    // pending.back() is released by pop_back, but the node stays alive: it
    // is either n itself or a child of a node already keyed in d_visited.
    TNode cur = pending.back();
    pending.pop_back();
    std::unordered_map<Node, Node>::iterator entry = d_visited.find(cur);

    if (entry != d_visited.end())
    {
      if (entry->second.isNull())
      {
        // Second visit: every child is finished, rebuild cur if needed.
        std::vector<Node> rebuilt;
        if (cur.getMetaKind() == metakind::PARAMETERIZED)
        {
          rebuilt.push_back(cur.getOperator());
        }
        bool anyChanged = false;
        for (const Node& child : cur)
        {
          std::unordered_map<Node, Node>::iterator done = d_visited.find(child);
          Assert(done != d_visited.end());
          Assert(!done->second.isNull());
          anyChanged = anyChanged || child != done->second;
          rebuilt.push_back(done->second);
        }
        d_visited[cur] =
            anyChanged ? nm->mkNode(cur.getKind(), rebuilt) : Node(cur);
      }
      continue;
    }

    if (cur.getKind() != LAMBDA)
    {
      // First visit: mark in progress, revisit cur once its children are done.
      d_visited[cur] = Node::null();
      pending.push_back(cur);
      for (const Node& child : cur)
      {
        pending.push_back(child);
      }
      continue;
    }

    // cur is a lambda: lift it to a fresh top-level symbol.
    Trace("ho-elim-ll") << "Lambda lift: " << cur << std::endl;
    std::unordered_set<Node> fvs;
    expr::getFreeVariables(cur, fvs);
    std::vector<Node> freeVars;     // free variables of cur
    std::vector<Node> freshVars;    // bound variables replacing them
    std::vector<Node> liftedParams; // freshVars followed by cur's bound vars
    std::vector<TypeNode> liftedArgTypes;
    for (const Node& fv : fvs)
    {
      TypeNode fvType = fv.getType();
      liftedArgTypes.push_back(fvType);
      Node fresh = nm->mkBoundVar(fvType);
      freeVars.push_back(fv);
      freshVars.push_back(fresh);
      liftedParams.push_back(fresh);
    }
    for (const Node& bv : cur[0])
    {
      liftedArgTypes.push_back(bv.getType());
      liftedParams.push_back(bv);
    }

    Node definition = cur;
    if (!fvs.empty())
    {
      Trace("ho-elim-ll")
          << "Has " << fvs.size() << " free variables" << std::endl;
      Node body = cur[1].substitute(freeVars.begin(),
                                    freeVars.end(),
                                    freshVars.begin(),
                                    freshVars.end());
      definition =
          nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, liftedParams), body);
      Trace("ho-elim-ll")
          << "...new lambda definition: " << definition << std::endl;
    }

    TypeNode liftedType =
        nm->mkFunctionType(liftedArgTypes, cur.getType().getRangeType());
    Node lifted = sm->mkDummySkolem("ll", liftedType);
    Trace("ho-elim-ll") << "...introduce: " << lifted << " of type "
                        << liftedType << std::endl;
    newLambda[lifted] = definition;
    Assert(lifted.getType() == definition.getType());

    // Re-apply the lifted symbol to the captured free variables so the
    // replacement has the same type as cur.
    for (const Node& fv : freeVars)
    {
      lifted = nm->mkNode(HO_APPLY, lifted, fv);
    }
    if (!freeVars.empty())
    {
      Trace("ho-elim-ll") << "...partial application: " << lifted << std::endl;
    }
    d_visited[cur] = lifted;
    Trace("ho-elim-ll") << "...return types : " << lifted.getType() << " "
                        << cur.getType() << std::endl;
    Assert(lifted.getType() == cur.getType());
  }
  Assert(d_visited.find(n) != d_visited.end());
  Assert(!d_visited.find(n)->second.isNull());
  return d_visited[n];
}
152
153
4378
/**
 * Eliminate higher-order constraints from n.
 *
 * n is expected to be lambda-free already; see eliminateLambdaComplete.
 *
 * The traversal is iterative and post-order. Results are cached in
 * d_visited, where a null value marks a node whose children are still being
 * processed. Every function type encountered is recorded in d_funTypes so
 * that applyInternal can later emit extensionality and store axioms for it.
 *
 * When options::hoElim() is set:
 *  - function-typed variables are replaced by fresh bound variables or "k"
 *    skolems of the uninterpreted sort returned by getUSort;
 *  - APPLY_UF terms are first converted to HO_APPLY chains;
 *  - every HO_APPLY is encoded as an APPLY_UF of the per-sort "ho" symbol
 *    returned by getHoApplyUf.
 *
 * @param n the node to convert
 * @return the converted node
 */
Node HoElim::eliminateHo(Node n)
{
  Trace("ho-elim-assert") << "Ho-elim assertion: " << n << std::endl;
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();
  std::unordered_map<Node, Node>::iterator it;
  // APPLY_UF terms mapped to the HO_APPLY chains they were converted to
  std::map<Node, Node> preReplace;
  std::map<Node, Node>::iterator itr;
  std::vector<TNode> visit;
  TNode cur;
  visit.push_back(n);
  do
  {
    cur = visit.back();
    visit.pop_back();
    it = d_visited.find(cur);
    Trace("ho-elim-visit") << "Process: " << cur << std::endl;

    if (it == d_visited.end())
    {
      TypeNode tn = cur.getType();
      // lambdas are already eliminated by now
      Assert(cur.getKind() != LAMBDA);
      if (tn.isFunction())
      {
        d_funTypes.insert(tn);
      }
      if (cur.isVar())
      {
        Node ret = cur;
        if (options::hoElim())
        {
          if (tn.isFunction())
          {
            TypeNode ut = getUSort(tn);
            if (cur.getKind() == BOUND_VARIABLE)
            {
              ret = nm->mkBoundVar(ut);
            }
            else
            {
              ret = sm->mkDummySkolem("k", ut);
            }
            // must get the ho apply to ensure extensionality is applied
            Node hoa = getHoApplyUf(tn);
            Trace("ho-elim-visit") << "Hoa is " << hoa << std::endl;
          }
        }
        d_visited[cur] = ret;
      }
      else
      {
        d_visited[cur] = Node::null();
        if (cur.getKind() == APPLY_UF && options::hoElim())
        {
          // convert apply uf with variable arguments eagerly to ho apply
          // chains, so they are processed uniformly. The original term is
          // revisited after the chain and takes the chain's result.
          visit.push_back(cur);
          Node newCur = theory::uf::TheoryUfRewriter::getHoApplyForApplyUf(cur);
          preReplace[cur] = newCur;
          // newCur remains referenced by preReplace (and by d_visited below)
          // after this scope ends, so the TNode cur stays valid.
          cur = newCur;
          d_visited[cur] = Node::null();
        }
        visit.push_back(cur);
        for (const Node& cn : cur)
        {
          visit.push_back(cn);
        }
      }
    }
    else if (it->second.isNull())
    {
      // second visit: all children have been converted
      Node ret = cur;
      itr = preReplace.find(cur);
      if (itr != preReplace.end())
      {
        // an APPLY_UF that was pre-converted: reuse its HO_APPLY chain result
        Trace("ho-elim-visit")
            << "return (pre-repl): " << d_visited[itr->second] << std::endl;
        d_visited[cur] = d_visited[itr->second];
      }
      else
      {
        bool childChanged = false;
        std::vector<Node> children;
        std::vector<TypeNode> childrent;
        bool typeChanged = false;
        for (const Node& cn : ret)
        {
          it = d_visited.find(cn);
          Assert(it != d_visited.end());
          Assert(!it->second.isNull());
          childChanged = childChanged || cn != it->second;
          children.push_back(it->second);
          TypeNode ct = it->second.getType();
          childrent.push_back(ct);
          typeChanged = typeChanged || ct != cn.getType();
        }
        if (ret.getMetaKind() == metakind::PARAMETERIZED)
        {
          // child of an argument changed type, must change type
          Node op = ret.getOperator();
          Node retOp = op;
          Trace("ho-elim-visit")
              << "Process op " << op << ", typeChanged = " << typeChanged
              << std::endl;
          if (typeChanged)
          {
            // one replacement operator per original operator, reused across
            // calls via d_visited_op
            std::unordered_map<TNode, Node>::iterator ito =
                d_visited_op.find(op);
            if (ito == d_visited_op.end())
            {
              Assert(!childrent.empty());
              TypeNode newFType = nm->mkFunctionType(childrent, cur.getType());
              retOp = sm->mkDummySkolem("rf", newFType);
              d_visited_op[op] = retOp;
            }
            else
            {
              retOp = ito->second;
            }
          }
          children.insert(children.begin(), retOp);
        }
        // process ho apply
        if (ret.getKind() == HO_APPLY && options::hoElim())
        {
          // (HO_APPLY f a) becomes (APPLY_UF ho f a) for the "ho" symbol of
          // f's uninterpreted sort
          TypeNode tnr = ret.getType();
          tnr = getUSort(tnr);
          Node hoa =
              getHoApplyUf(children[0].getType(), children[1].getType(), tnr);
          std::vector<Node> hchildren;
          hchildren.push_back(hoa);
          hchildren.push_back(children[0]);
          hchildren.push_back(children[1]);
          ret = nm->mkNode(APPLY_UF, hchildren);
        }
        else if (childChanged)
        {
          ret = nm->mkNode(ret.getKind(), children);
        }
        Trace("ho-elim-visit") << "return (pre-repl): " << ret << std::endl;
        d_visited[cur] = ret;
      }
    }
  } while (!visit.empty());
  Assert(d_visited.find(n) != d_visited.end());
  Assert(!d_visited.find(n)->second.isNull());
  Trace("ho-elim-assert") << "...got : " << d_visited[n] << std::endl;
  return d_visited[n];
}
304
305
183
/**
 * Apply the pass to the assertions in assertionsToPreprocess.
 *
 * Step [1] lambda-lifts every assertion (eliminateLambdaComplete) and
 * produces a defining axiom for every lifted lambda. This repeats until the
 * axioms themselves contain no lambdas. The axioms are then conjoined to the
 * first assertion.
 *
 * Step [2] eliminates higher-order constraints from every assertion
 * (eliminateHo). It then conjoins axioms for each function type collected in
 * d_funTypes:
 *  - with options::hoElim(), an extensionality axiom and, if
 *    options::hoElimStoreAx() is set, a "store" axiom over the encoded sort;
 *  - otherwise, only the store axiom over HO_APPLY, if
 *    options::hoElimStoreAx() is set.
 *
 * @param assertionsToPreprocess the assertions, modified in place
 * @return always PreprocessingPassResult::NO_CONFLICT
 */
PreprocessingPassResult HoElim::applyInternal(
    AssertionPipeline* assertionsToPreprocess)
{
  // step [1]: apply lambda lifting to eliminate all lambdas
  NodeManager* nm = NodeManager::currentNM();
  std::vector<Node> axioms;
  std::map<Node, Node> newLambda;
  for (unsigned i = 0, size = assertionsToPreprocess->size(); i < size; ++i)
  {
    Node prev = (*assertionsToPreprocess)[i];
    Node res = eliminateLambdaComplete(prev, newLambda);
    if (res != prev)
    {
      res = theory::Rewriter::rewrite(res);
      Assert(!expr::hasFreeVar(res));
      assertionsToPreprocess->replace(i, res);
    }
  }
  // do lambda lifting on new lambda definitions
  // this will do fixed point to eliminate lambdas within lambda lifting axioms.
  while (!newLambda.empty())
  {
    // Take the pending definitions without copying them. newLambda is left
    // empty and collects lambdas lifted out of the axioms built below.
    std::map<Node, Node> lproc;
    lproc.swap(newLambda);
    for (const std::pair<const Node, Node>& l : lproc)
    {
      Node lambda = l.second;
      std::vector<Node> vars;
      std::vector<Node> nvars;
      for (const Node& v : lambda[0])
      {
        Node bv = nm->mkBoundVar(v.getType());
        vars.push_back(v);
        nvars.push_back(bv);
      }

      Node bd = lambda[1].substitute(
          vars.begin(), vars.end(), nvars.begin(), nvars.end());
      Node bvl = nm->mkNode(BOUND_VAR_LIST, nvars);

      // axiom: forall nvars. (lifted nvars) = body[vars := nvars]
      nvars.insert(nvars.begin(), l.first);
      Node curr = nm->mkNode(APPLY_UF, nvars);

      Node llfax = nm->mkNode(FORALL, bvl, curr.eqNode(bd));
      Trace("ho-elim-ax") << "Lambda lifting axiom (pre-elim) " << llfax
                          << " for " << lambda << std::endl;
      Assert(!expr::hasFreeVar(llfax));
      Node llfaxe = eliminateLambdaComplete(llfax, newLambda);
      Trace("ho-elim-ax") << "Lambda lifting axiom " << llfaxe << " for "
                          << lambda << std::endl;
      axioms.push_back(llfaxe);
    }
  }

  // step [2] uses d_visited with different semantics, so start afresh
  d_visited.clear();
  // add lambda lifting axioms as a conjunction to the first assertion
  if (!axioms.empty())
  {
    Node conj = nm->mkAnd(axioms);
    conj = theory::Rewriter::rewrite(conj);
    Assert(!expr::hasFreeVar(conj));
    assertionsToPreprocess->conjoin(0, conj);
  }
  axioms.clear();

  // step [2]: eliminate all higher-order constraints
  for (unsigned i = 0, size = assertionsToPreprocess->size(); i < size; ++i)
  {
    Node prev = (*assertionsToPreprocess)[i];
    Node res = eliminateHo(prev);
    if (res != prev)
    {
      res = theory::Rewriter::rewrite(res);
      Assert(!expr::hasFreeVar(res));
      assertionsToPreprocess->replace(i, res);
    }
  }
  // extensionality: process all function types
  for (const TypeNode& ftn : d_funTypes)
  {
    if (options::hoElim())
    {
      Node h = getHoApplyUf(ftn);
      Trace("ho-elim-ax") << "Make extensionality for " << h << std::endl;
      TypeNode ft = h.getType();
      TypeNode uf = getUSort(ft[0]);
      TypeNode ut = getUSort(ft[1]);
      // extensionality
      Node x = nm->mkBoundVar("x", uf);
      Node y = nm->mkBoundVar("y", uf);
      Node z = nm->mkBoundVar("z", ut);
      Node eq =
          nm->mkNode(APPLY_UF, h, x, z).eqNode(nm->mkNode(APPLY_UF, h, y, z));
      Node antec = nm->mkNode(FORALL, nm->mkNode(BOUND_VAR_LIST, z), eq);
      Node conc = x.eqNode(y);
      Node ax = nm->mkNode(FORALL,
                           nm->mkNode(BOUND_VAR_LIST, x, y),
                           nm->mkNode(OR, antec.negate(), conc));
      axioms.push_back(ax);
      Trace("ho-elim-ax") << "...ext axiom : " << ax << std::endl;
      // Make the "store" axiom, which asserts for every function, there
      // exists another function that acts like the "store" operator for
      // arrays, e.g. it is the same function with one I/O pair updated.
      // Without this axiom, the translation is model unsound.
      if (options::hoElimStoreAx())
      {
        Node u = nm->mkBoundVar("u", uf);
        Node v = nm->mkBoundVar("v", uf);
        Node i = nm->mkBoundVar("i", ut);
        Node ii = nm->mkBoundVar("ii", ut);
        Node huii = nm->mkNode(APPLY_UF, h, u, ii);
        Node e = nm->mkBoundVar("e", huii.getType());
        Node store = nm->mkNode(
            FORALL,
            nm->mkNode(BOUND_VAR_LIST, u, e, i),
            nm->mkNode(EXISTS,
                       nm->mkNode(BOUND_VAR_LIST, v),
                       nm->mkNode(FORALL,
                                  nm->mkNode(BOUND_VAR_LIST, ii),
                                  nm->mkNode(APPLY_UF, h, v, ii)
                                      .eqNode(nm->mkNode(
                                          ITE, ii.eqNode(i), e, huii)))));
        axioms.push_back(store);
        Trace("ho-elim-ax") << "...store axiom : " << store << std::endl;
      }
    }
    else if (options::hoElimStoreAx())
    {
      Node u = nm->mkBoundVar("u", ftn);
      Node v = nm->mkBoundVar("v", ftn);
      std::vector<TypeNode> argTypes = ftn.getArgTypes();
      Node i = nm->mkBoundVar("i", argTypes[0]);
      Node ii = nm->mkBoundVar("ii", argTypes[0]);
      Node huii = nm->mkNode(HO_APPLY, u, ii);
      Node e = nm->mkBoundVar("e", huii.getType());
      Node store = nm->mkNode(
          FORALL,
          nm->mkNode(BOUND_VAR_LIST, u, e, i),
          nm->mkNode(
              EXISTS,
              nm->mkNode(BOUND_VAR_LIST, v),
              nm->mkNode(FORALL,
                         nm->mkNode(BOUND_VAR_LIST, ii),
                         nm->mkNode(HO_APPLY, v, ii)
                             .eqNode(nm->mkNode(ITE, ii.eqNode(i), e, huii)))));
      axioms.push_back(store);
      Trace("ho-elim-ax") << "...store (ho_apply) axiom : " << store
                          << std::endl;
    }
  }
  // add new axioms as a conjunction to the first assertion
  if (!axioms.empty())
  {
    Node conj = nm->mkAnd(axioms);
    conj = theory::Rewriter::rewrite(conj);
    Assert(!expr::hasFreeVar(conj));
    assertionsToPreprocess->conjoin(0, conj);
  }

  // Release the traversal cache. It holds references to every node of the
  // original assertions. Its step [2] mappings must also not be reused by
  // step [1] of a later invocation, which does not clear it beforehand.
  d_visited.clear();
  return PreprocessingPassResult::NO_CONFLICT;
}
466
467
2343
/**
 * Return the "ho" application symbol for the function type tn.
 *
 * Applying a function of type tn to its first argument yields either the
 * range type (unary functions) or the function type over the remaining
 * arguments. All three sorts (the function, its first argument, and that
 * result) are mapped through getUSort before the symbol is looked up.
 *
 * @param tn a function type
 * @return the symbol from the three-argument getHoApplyUf overload
 */
Node HoElim::getHoApplyUf(TypeNode tn)
{
  TypeNode funSort = getUSort(tn);
  std::vector<TypeNode> argTypes = tn.getArgTypes();
  TypeNode firstArgSort = getUSort(argTypes[0]);

  // result of consuming exactly one argument (curried view of tn)
  TypeNode remainder = tn.getRangeType();
  if (argTypes.size() > 1)
  {
    std::vector<TypeNode> restArgs(argTypes.begin() + 1, argTypes.end());
    remainder = NodeManager::currentNM()->mkFunctionType(restArgs, remainder);
  }
  TypeNode remainderSort = getUSort(remainder);

  return getHoApplyUf(funSort, firstArgSort, remainderSort);
}
485
486
10139
/**
 * Return the symbol encoding higher-order application for sort tnf.
 *
 * On the first request for tnf, a dummy skolem with prefix "ho" and type
 * (tnf, tna) -> tnr is created and cached in d_hoApplyUf. The cache is keyed
 * on tnf only, so later calls return the cached symbol whatever tna/tnr
 * they pass. Callers are expected to use consistent sorts for a given tnf.
 *
 * @param tnf the (uninterpreted) sort of the applied function
 * @param tna the sort of the argument
 * @param tnr the sort of the application's result
 * @return the cached or newly created "ho" symbol
 */
Node HoElim::getHoApplyUf(TypeNode tnf, TypeNode tna, TypeNode tnr)
{
  auto cached = d_hoApplyUf.find(tnf);
  if (cached != d_hoApplyUf.end())
  {
    return cached->second;
  }

  NodeManager* nm = NodeManager::currentNM();
  std::vector<TypeNode> domain{tnf, tna};
  TypeNode hoType = nm->mkFunctionType(domain, tnr);
  Node hoSymbol = nm->getSkolemManager()->mkDummySkolem("ho", hoType);
  d_hoApplyUf.emplace(tnf, hoSymbol);
  return hoSymbol;
}
504
505
17883
/**
 * Map a type to the uninterpreted sort that stands in for it.
 *
 * Non-function types are returned unchanged. For a function type, any
 * function-typed arguments are first replaced by their own uninterpreted
 * sorts (recursively), and the result is the sort of the flattened type. A
 * function type without function-typed arguments gets a fresh sort named
 * "u_<type>". Results for function types are memoized in d_ftypeMap.
 *
 * @param tn the type to map
 * @return tn itself, or the uninterpreted sort representing it
 */
TypeNode HoElim::getUSort(TypeNode tn)
{
  if (!tn.isFunction())
  {
    return tn;
  }
  auto cached = d_ftypeMap.find(tn);
  if (cached != d_ftypeMap.end())
  {
    return cached->second;
  }

  // flatten function arguments
  std::vector<TypeNode> argTypes = tn.getArgTypes();
  TypeNode rangeType = tn.getRangeType();
  bool hasFunctionArg = false;
  for (TypeNode& argType : argTypes)
  {
    if (argType.isFunction())
    {
      argType = getUSort(argType);
      hasFunctionArg = true;
    }
  }

  TypeNode result;
  if (!hasFunctionArg)
  {
    std::stringstream sortName;
    sortName << "u_" << tn;
    result = NodeManager::currentNM()->mkSort(sortName.str());
  }
  else
  {
    TypeNode flattened =
        NodeManager::currentNM()->mkFunctionType(argTypes, rangeType);
    result = getUSort(flattened);
  }
  d_ftypeMap[tn] = result;
  return result;
}
544
545
}  // namespace passes
546
}  // namespace preprocessing
547
29337
}  // namespace cvc5