GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/ho_elim.cpp Lines: 320 327 97.9 %
Date: 2021-03-22 Branches: 709 1540 46.0 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file ho_elim.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Andres Noetzli
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief The HoElim preprocessing pass
13
 **
14
 ** Eliminates higher-order constraints.
15
 **/
16
17
#include "preprocessing/passes/ho_elim.h"
18
19
#include "expr/node_algorithm.h"
20
#include "options/quantifiers_options.h"
21
#include "preprocessing/assertion_pipeline.h"
22
#include "theory/rewriter.h"
23
#include "theory/uf/theory_uf_rewriter.h"
24
25
using namespace CVC4::kind;
26
27
namespace CVC4 {
28
namespace preprocessing {
29
namespace passes {
30
31
8995
/** Constructs the higher-order elimination pass, registered as "ho-elim". */
HoElim::HoElim(PreprocessingPassContext* preprocContext)
    : PreprocessingPass(preprocContext, "ho-elim")
{
}
33
34
4727
/**
 * Lambda-lifts every LAMBDA reachable in n without passing through another
 * LAMBDA.
 *
 * Each such lambda is replaced by a fresh skolem "ll" whose argument types
 * are the lambda's free variables followed by its own bound variables. When
 * the lambda has free variables, the skolem is partially applied (via
 * HO_APPLY) to those free variables, so that the replacement has the same
 * type as the lambda. The definition of the lifted function (the lambda
 * itself, or a new lambda that also abstracts the free variables) is stored
 * in newLambda, keyed by the skolem.
 *
 * The bodies of lifted lambdas are not traversed here. Lambdas nested inside
 * them remain in the newLambda definitions.
 *
 * Results are cached in d_visited. A null entry marks a term whose children
 * are still being processed.
 *
 * @param n the term to process
 * @param newLambda (output) map from each introduced skolem to its definition
 * @return n with the lambdas described above replaced
 */
Node HoElim::eliminateLambdaComplete(Node n, std::map<Node, Node>& newLambda)
{
  NodeManager* nm = NodeManager::currentNM();
  std::vector<Node> visit;
  TNode cur;
  visit.push_back(n);
  do
  {
    cur = visit.back();
    visit.pop_back();
    // `auto` yields exactly d_visited's iterator type. Previously this function
    // spelled it with TNodeHashFunction while eliminateHo spelled it with
    // NodeHashFunction; at most one of those can match the member, so the
    // other only compiled because both iterator types happened to coincide.
    auto it = d_visited.find(cur);

    if (it == d_visited.end())
    {
      if (cur.getKind() == LAMBDA)
      {
        Trace("ho-elim-ll") << "Lambda lift: " << cur << std::endl;
        // must also get free variables in lambda
        std::vector<Node> lvars;
        std::vector<TypeNode> ftypes;
        std::unordered_set<Node, NodeHashFunction> fvs;
        expr::getFreeVariables(cur, fvs);
        std::vector<Node> nvars;
        std::vector<Node> vars;
        Node sbd = cur[1];
        if (!fvs.empty())
        {
          Trace("ho-elim-ll")
              << "Has " << fvs.size() << " free variables" << std::endl;
          // The free variables become the leading arguments of the lifted
          // function; in its body they are replaced by fresh bound variables.
          for (const Node& v : fvs)
          {
            TypeNode vt = v.getType();
            ftypes.push_back(vt);
            Node vs = nm->mkBoundVar(vt);
            vars.push_back(v);
            nvars.push_back(vs);
            lvars.push_back(vs);
          }
          sbd = sbd.substitute(
              vars.begin(), vars.end(), nvars.begin(), nvars.end());
        }
        // The lambda's own bound variables follow the free variables.
        for (const Node& bv : cur[0])
        {
          TypeNode bvt = bv.getType();
          ftypes.push_back(bvt);
          lvars.push_back(bv);
        }
        Node nlambda = cur;
        if (!fvs.empty())
        {
          nlambda = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, lvars), sbd);
          Trace("ho-elim-ll")
              << "...new lambda definition: " << nlambda << std::endl;
        }
        TypeNode rangeType = cur.getType().getRangeType();
        TypeNode nft = nm->mkFunctionType(ftypes, rangeType);
        Node nf = nm->mkSkolem("ll", nft);
        Trace("ho-elim-ll")
            << "...introduce: " << nf << " of type " << nft << std::endl;
        newLambda[nf] = nlambda;
        Assert(nf.getType() == nlambda.getType());
        // Partially apply the lifted function to the free variables (in the
        // same order used for its leading arguments) so that the replacement
        // has the type of the original lambda.
        if (!vars.empty())
        {
          for (const Node& v : vars)
          {
            nf = nm->mkNode(HO_APPLY, nf, v);
          }
          Trace("ho-elim-ll") << "...partial application: " << nf << std::endl;
        }
        d_visited[cur] = nf;
        Trace("ho-elim-ll") << "...return types : " << nf.getType() << " "
                            << cur.getType() << std::endl;
        Assert(nf.getType() == cur.getType());
      }
      else
      {
        // pre-visit: mark as in progress and revisit after the children
        d_visited[cur] = Node::null();
        visit.push_back(cur);
        for (const Node& cn : cur)
        {
          visit.push_back(cn);
        }
      }
    }
    else if (it->second.isNull())
    {
      // post-visit: rebuild cur from its already-processed children
      Node ret = cur;
      bool childChanged = false;
      std::vector<Node> children;
      if (cur.getMetaKind() == metakind::PARAMETERIZED)
      {
        children.push_back(cur.getOperator());
      }
      for (const Node& cn : cur)
      {
        auto itc = d_visited.find(cn);
        Assert(itc != d_visited.end());
        Assert(!itc->second.isNull());
        childChanged = childChanged || cn != itc->second;
        children.push_back(itc->second);
      }
      if (childChanged)
      {
        ret = nm->mkNode(cur.getKind(), children);
      }
      d_visited[cur] = ret;
    }
  } while (!visit.empty());
  Assert(d_visited.find(n) != d_visited.end());
  Assert(!d_visited.find(n)->second.isNull());
  return d_visited[n];
}
147
148
4544
/**
 * Eliminates higher-order constructs from n (step [2] of this pass).
 *
 * The term is traversed iteratively in post-order with d_visited as cache.
 * A null entry marks a term whose children are still being processed.
 * - Every function type seen is recorded in d_funTypes. applyInternal later
 *   emits extensionality/store axioms for these types.
 * - If options::hoElim() holds:
 *   - function-typed variables are replaced by a fresh bound variable or
 *     skolem "k" of sort getUSort(T);
 *   - APPLY_UF terms are first rewritten to HO_APPLY chains;
 *   - every HO_APPLY is replaced by an application of the uninterpreted
 *     function returned by getHoApplyUf.
 * - For parameterized terms whose children changed type, the operator is
 *   replaced by a fresh skolem "rf" of the re-typed function type. This
 *   skolem is cached per operator in d_visited_op.
 *
 * @param n the term to process; lambdas must already have been eliminated
 * @return the higher-order-free version of n
 */
Node HoElim::eliminateHo(Node n)
{
  Trace("ho-elim-assert") << "Ho-elim assertion: " << n << std::endl;
  NodeManager* nm = NodeManager::currentNM();
  // APPLY_UF terms mapped to the HO_APPLY chains that replaced them
  std::map<Node, Node> preReplace;
  std::vector<TNode> visit;
  TNode cur;
  visit.push_back(n);
  do
  {
    cur = visit.back();
    visit.pop_back();
    // `auto` yields exactly d_visited's iterator type (see the note in
    // eliminateLambdaComplete about the two hash functors previously used).
    auto it = d_visited.find(cur);
    Trace("ho-elim-visit") << "Process: " << cur << std::endl;

    if (it == d_visited.end())
    {
      TypeNode tn = cur.getType();
      // lambdas are already eliminated by now
      Assert(cur.getKind() != LAMBDA);
      if (tn.isFunction())
      {
        d_funTypes.insert(tn);
      }
      if (cur.isVar())
      {
        Node ret = cur;
        if (options::hoElim() && tn.isFunction())
        {
          TypeNode ut = getUSort(tn);
          if (cur.getKind() == BOUND_VARIABLE)
          {
            ret = nm->mkBoundVar(ut);
          }
          else
          {
            ret = nm->mkSkolem("k", ut);
          }
          // must get the ho apply to ensure extensionality is applied
          Node hoa = getHoApplyUf(tn);
          Trace("ho-elim-visit") << "Hoa is " << hoa << std::endl;
        }
        d_visited[cur] = ret;
      }
      else
      {
        d_visited[cur] = Node::null();
        if (cur.getKind() == APPLY_UF && options::hoElim())
        {
          // convert apply uf with variable arguments eagerly to ho apply
          // chains, so they are processed uniformly. The original term is
          // revisited afterwards and takes the result of its replacement.
          visit.push_back(cur);
          Node newCur = theory::uf::TheoryUfRewriter::getHoApplyForApplyUf(cur);
          // preReplace keeps newCur alive after this scope, since cur and
          // the visit stack only hold TNodes.
          preReplace[cur] = newCur;
          cur = newCur;
          d_visited[cur] = Node::null();
        }
        visit.push_back(cur);
        for (const Node& cn : cur)
        {
          visit.push_back(cn);
        }
      }
    }
    else if (it->second.isNull())
    {
      auto itr = preReplace.find(cur);
      if (itr != preReplace.end())
      {
        Trace("ho-elim-visit")
            << "return (pre-repl): " << d_visited[itr->second] << std::endl;
        d_visited[cur] = d_visited[itr->second];
      }
      else
      {
        Node ret = cur;
        bool childChanged = false;
        std::vector<Node> children;
        std::vector<TypeNode> childrent;
        bool typeChanged = false;
        for (const Node& cn : ret)
        {
          auto itc = d_visited.find(cn);
          Assert(itc != d_visited.end());
          Assert(!itc->second.isNull());
          childChanged = childChanged || cn != itc->second;
          children.push_back(itc->second);
          TypeNode ct = itc->second.getType();
          childrent.push_back(ct);
          typeChanged = typeChanged || ct != cn.getType();
        }
        if (ret.getMetaKind() == metakind::PARAMETERIZED)
        {
          // child of an argument changed type, must change type
          Node op = ret.getOperator();
          Node retOp = op;
          Trace("ho-elim-visit")
              << "Process op " << op << ", typeChanged = " << typeChanged
              << std::endl;
          if (typeChanged)
          {
            auto ito = d_visited_op.find(op);
            if (ito == d_visited_op.end())
            {
              Assert(!childrent.empty());
              TypeNode newFType = nm->mkFunctionType(childrent, cur.getType());
              retOp = nm->mkSkolem("rf", newFType);
              d_visited_op[op] = retOp;
            }
            else
            {
              retOp = ito->second;
            }
          }
          children.insert(children.begin(), retOp);
        }
        // process ho apply
        if (ret.getKind() == HO_APPLY && options::hoElim())
        {
          TypeNode tnr = ret.getType();
          tnr = getUSort(tnr);
          Node hoa =
              getHoApplyUf(children[0].getType(), children[1].getType(), tnr);
          std::vector<Node> hchildren;
          hchildren.push_back(hoa);
          hchildren.push_back(children[0]);
          hchildren.push_back(children[1]);
          ret = nm->mkNode(APPLY_UF, hchildren);
        }
        else if (childChanged)
        {
          ret = nm->mkNode(ret.getKind(), children);
        }
        Trace("ho-elim-visit") << "return (pre-repl): " << ret << std::endl;
        d_visited[cur] = ret;
      }
    }
  } while (!visit.empty());
  Assert(d_visited.find(n) != d_visited.end());
  Assert(!d_visited.find(n)->second.isNull());
  Trace("ho-elim-assert") << "...got : " << d_visited[n] << std::endl;
  return d_visited[n];
}
298
299
212
PreprocessingPassResult HoElim::applyInternal(
300
    AssertionPipeline* assertionsToPreprocess)
301
{
302
  // step [1]: apply lambda lifting to eliminate all lambdas
303
212
  NodeManager* nm = NodeManager::currentNM();
304
424
  std::vector<Node> axioms;
305
424
  std::map<Node, Node> newLambda;
306
4756
  for (unsigned i = 0, size = assertionsToPreprocess->size(); i < size; ++i)
307
  {
308
9088
    Node prev = (*assertionsToPreprocess)[i];
309
9088
    Node res = eliminateLambdaComplete(prev, newLambda);
310
4544
    if (res != prev)
311
    {
312
307
      res = theory::Rewriter::rewrite(res);
313
307
      Assert(!expr::hasFreeVar(res));
314
307
      assertionsToPreprocess->replace(i, res);
315
    }
316
  }
317
  // do lambda lifting on new lambda definitions
318
  // this will do fixed point to eliminate lambdas within lambda lifting axioms.
319
274
  while (!newLambda.empty())
320
  {
321
62
    std::map<Node, Node> lproc = newLambda;
322
31
    newLambda.clear();
323
214
    for (const std::pair<const Node, Node>& l : lproc)
324
    {
325
366
      Node lambda = l.second;
326
366
      std::vector<Node> vars;
327
366
      std::vector<Node> nvars;
328
611
      for (const Node& v : lambda[0])
329
      {
330
856
        Node bv = nm->mkBoundVar(v.getType());
331
428
        vars.push_back(v);
332
428
        nvars.push_back(bv);
333
      }
334
335
      Node bd = lambda[1].substitute(
336
366
          vars.begin(), vars.end(), nvars.begin(), nvars.end());
337
366
      Node bvl = nm->mkNode(BOUND_VAR_LIST, nvars);
338
339
183
      nvars.insert(nvars.begin(), l.first);
340
366
      Node curr = nm->mkNode(APPLY_UF, nvars);
341
342
366
      Node llfax = nm->mkNode(FORALL, bvl, curr.eqNode(bd));
343
366
      Trace("ho-elim-ax") << "Lambda lifting axiom (pre-elim) " << llfax
344
183
                          << " for " << lambda << std::endl;
345
183
      Assert(!expr::hasFreeVar(llfax));
346
366
      Node llfaxe = eliminateLambdaComplete(llfax, newLambda);
347
366
      Trace("ho-elim-ax") << "Lambda lifting axiom " << llfaxe << " for "
348
183
                          << lambda << std::endl;
349
183
      axioms.push_back(llfaxe);
350
    }
351
  }
352
353
212
  d_visited.clear();
354
  // add lambda lifting axioms as a conjunction to the first assertion
355
212
  if (!axioms.empty())
356
  {
357
58
    Node conj = nm->mkAnd(axioms);
358
29
    conj = theory::Rewriter::rewrite(conj);
359
29
    Assert(!expr::hasFreeVar(conj));
360
29
    assertionsToPreprocess->conjoin(0, conj);
361
  }
362
212
  axioms.clear();
363
364
  // step [2]: eliminate all higher-order constraints
365
4756
  for (unsigned i = 0, size = assertionsToPreprocess->size(); i < size; ++i)
366
  {
367
9088
    Node prev = (*assertionsToPreprocess)[i];
368
9088
    Node res = eliminateHo(prev);
369
4544
    if (res != prev)
370
    {
371
1472
      res = theory::Rewriter::rewrite(res);
372
1472
      Assert(!expr::hasFreeVar(res));
373
1472
      assertionsToPreprocess->replace(i, res);
374
    }
375
  }
376
  // extensionality: process all function types
377
914
  for (const TypeNode& ftn : d_funTypes)
378
  {
379
702
    if (options::hoElim())
380
    {
381
542
      Node h = getHoApplyUf(ftn);
382
271
      Trace("ho-elim-ax") << "Make extensionality for " << h << std::endl;
383
542
      TypeNode ft = h.getType();
384
542
      TypeNode uf = getUSort(ft[0]);
385
542
      TypeNode ut = getUSort(ft[1]);
386
      // extensionality
387
542
      Node x = nm->mkBoundVar("x", uf);
388
542
      Node y = nm->mkBoundVar("y", uf);
389
542
      Node z = nm->mkBoundVar("z", ut);
390
      Node eq =
391
542
          nm->mkNode(APPLY_UF, h, x, z).eqNode(nm->mkNode(APPLY_UF, h, y, z));
392
542
      Node antec = nm->mkNode(FORALL, nm->mkNode(BOUND_VAR_LIST, z), eq);
393
542
      Node conc = x.eqNode(y);
394
      Node ax = nm->mkNode(FORALL,
395
542
                           nm->mkNode(BOUND_VAR_LIST, x, y),
396
1084
                           nm->mkNode(OR, antec.negate(), conc));
397
271
      axioms.push_back(ax);
398
271
      Trace("ho-elim-ax") << "...ext axiom : " << ax << std::endl;
399
      // Make the "store" axiom, which asserts for every function, there
400
      // exists another function that acts like the "store" operator for
401
      // arrays, e.g. it is the same function with one I/O pair updated.
402
      // Without this axiom, the translation is model unsound.
403
1406
      if (options::hoElimStoreAx())
404
      {
405
542
        Node u = nm->mkBoundVar("u", uf);
406
542
        Node v = nm->mkBoundVar("v", uf);
407
542
        Node i = nm->mkBoundVar("i", ut);
408
542
        Node ii = nm->mkBoundVar("ii", ut);
409
542
        Node huii = nm->mkNode(APPLY_UF, h, u, ii);
410
542
        Node e = nm->mkBoundVar("e", huii.getType());
411
        Node store = nm->mkNode(
412
            FORALL,
413
542
            nm->mkNode(BOUND_VAR_LIST, u, e, i),
414
1084
            nm->mkNode(EXISTS,
415
542
                       nm->mkNode(BOUND_VAR_LIST, v),
416
1084
                       nm->mkNode(FORALL,
417
542
                                  nm->mkNode(BOUND_VAR_LIST, ii),
418
542
                                  nm->mkNode(APPLY_UF, h, v, ii)
419
1084
                                      .eqNode(nm->mkNode(
420
1626
                                          ITE, ii.eqNode(i), e, huii)))));
421
271
        axioms.push_back(store);
422
271
        Trace("ho-elim-ax") << "...store axiom : " << store << std::endl;
423
      }
424
    }
425
431
    else if (options::hoElimStoreAx())
426
    {
427
2
      Node u = nm->mkBoundVar("u", ftn);
428
2
      Node v = nm->mkBoundVar("v", ftn);
429
2
      std::vector<TypeNode> argTypes = ftn.getArgTypes();
430
2
      Node i = nm->mkBoundVar("i", argTypes[0]);
431
2
      Node ii = nm->mkBoundVar("ii", argTypes[0]);
432
2
      Node huii = nm->mkNode(HO_APPLY, u, ii);
433
2
      Node e = nm->mkBoundVar("e", huii.getType());
434
      Node store = nm->mkNode(
435
          FORALL,
436
2
          nm->mkNode(BOUND_VAR_LIST, u, e, i),
437
4
          nm->mkNode(
438
              EXISTS,
439
2
              nm->mkNode(BOUND_VAR_LIST, v),
440
4
              nm->mkNode(FORALL,
441
2
                         nm->mkNode(BOUND_VAR_LIST, ii),
442
2
                         nm->mkNode(HO_APPLY, v, ii)
443
6
                             .eqNode(nm->mkNode(ITE, ii.eqNode(i), e, huii)))));
444
1
      axioms.push_back(store);
445
2
      Trace("ho-elim-ax") << "...store (ho_apply) axiom : " << store
446
1
                          << std::endl;
447
    }
448
  }
449
  // add new axioms as a conjunction to the first assertion
450
212
  if (!axioms.empty())
451
  {
452
26
    Node conj = nm->mkAnd(axioms);
453
13
    conj = theory::Rewriter::rewrite(conj);
454
13
    Assert(!expr::hasFreeVar(conj));
455
13
    assertionsToPreprocess->conjoin(0, conj);
456
  }
457
458
424
  return PreprocessingPassResult::NO_CONFLICT;
459
}
460
461
2364
/**
 * Returns the uninterpreted function that encodes HO_APPLY for the function
 * type tn.
 *
 * Its argument sorts are getUSort(tn) and getUSort of tn's first argument
 * type. Its result sort is getUSort of the type that remains once that first
 * argument is supplied: the range type when tn has one argument, otherwise
 * the function type over the remaining arguments.
 */
Node HoElim::getHoApplyUf(TypeNode tn)
{
  // Sorts are computed in a fixed order because getUSort may create them.
  TypeNode funSort = getUSort(tn);
  std::vector<TypeNode> argTypes = tn.getArgTypes();
  TypeNode firstArgSort = getUSort(argTypes[0]);

  TypeNode remaining = tn.getRangeType();
  if (argTypes.size() > 1)
  {
    std::vector<TypeNode> restArgs(argTypes.begin() + 1, argTypes.end());
    remaining = NodeManager::currentNM()->mkFunctionType(restArgs, remaining);
  }
  TypeNode resultSort = getUSort(remaining);

  return getHoApplyUf(funSort, firstArgSort, resultSort);
}
479
480
10416
/**
 * Returns the uninterpreted function "ho" of type (tnf, tna) -> tnr used to
 * encode HO_APPLY, creating it on first use.
 *
 * The cache d_hoApplyUf is keyed by tnf only. Once an entry for tnf exists,
 * tna and tnr are not consulted.
 */
Node HoElim::getHoApplyUf(TypeNode tnf, TypeNode tna, TypeNode tnr)
{
  auto cached = d_hoApplyUf.find(tnf);
  if (cached != d_hoApplyUf.end())
  {
    return cached->second;
  }
  NodeManager* nm = NodeManager::currentNM();
  std::vector<TypeNode> domain{tnf, tna};
  TypeNode hoType = nm->mkFunctionType(domain, tnr);
  Node hoFun = nm->mkSkolem("ho", hoType);
  d_hoApplyUf[tnf] = hoFun;
  return hoFun;
}
497
498
18220
/**
 * Returns the sort that represents type tn after higher-order elimination.
 *
 * Non-function types are returned unchanged. For a function type, any
 * function-typed argument types are first replaced (recursively) by their
 * sorts. If any were replaced, tn shares the sort of the resulting flattened
 * function type. Otherwise a fresh sort named "u_<tn>" is created. Results
 * are cached in d_ftypeMap.
 */
TypeNode HoElim::getUSort(TypeNode tn)
{
  if (!tn.isFunction())
  {
    return tn;
  }
  auto cached = d_ftypeMap.find(tn);
  if (cached != d_ftypeMap.end())
  {
    return cached->second;
  }

  // flatten function arguments
  std::vector<TypeNode> argTypes = tn.getArgTypes();
  TypeNode rangeType = tn.getRangeType();
  bool flattened = false;
  for (TypeNode& argType : argTypes)
  {
    if (argType.isFunction())
    {
      argType = getUSort(argType);
      flattened = true;
    }
  }

  TypeNode result;
  if (!flattened)
  {
    std::stringstream sortName;
    sortName << "u_" << tn;
    result = NodeManager::currentNM()->mkSort(sortName.str());
  }
  else
  {
    TypeNode flatType =
        NodeManager::currentNM()->mkFunctionType(argTypes, rangeType);
    result = getUSort(flatType);
  }
  d_ftypeMap[tn] = result;
  return result;
}
537
538
}  // namespace passes
539
}  // namespace preprocessing
540
55890
}  // namespace CVC4