GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/ho_elim.cpp Lines: 323 330 97.9 %
Date: 2021-11-07 Branches: 700 1518 46.1 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * The HoElim preprocessing pass.
14
 *
15
 * Eliminates higher-order constraints.
16
 */
17
18
#include "preprocessing/passes/ho_elim.h"
19
20
#include <sstream>
21
22
#include "expr/node_algorithm.h"
23
#include "expr/skolem_manager.h"
24
#include "options/quantifiers_options.h"
25
#include "preprocessing/assertion_pipeline.h"
26
#include "theory/rewriter.h"
27
#include "theory/uf/theory_uf_rewriter.h"
28
29
using namespace cvc5::kind;
30
31
namespace cvc5 {
32
namespace preprocessing {
33
namespace passes {
34
35
15340
HoElim::HoElim(PreprocessingPassContext* preprocContext)
36
15340
    : PreprocessingPass(preprocContext, "ho-elim"){};
37
38
5082
Node HoElim::eliminateLambdaComplete(Node n, std::map<Node, Node>& newLambda)
39
{
40
5082
  NodeManager* nm = NodeManager::currentNM();
41
5082
  SkolemManager* sm = nm->getSkolemManager();
42
5082
  std::unordered_map<Node, Node>::iterator it;
43
10164
  std::vector<Node> visit;
44
10164
  TNode cur;
45
5082
  visit.push_back(n);
46
111120
  do
47
  {
48
116202
    cur = visit.back();
49
116202
    visit.pop_back();
50
116202
    it = d_visited.find(cur);
51
52
116202
    if (it == d_visited.end())
53
    {
54
45485
      if (cur.getKind() == LAMBDA)
55
      {
56
180
        Trace("ho-elim-ll") << "Lambda lift: " << cur << std::endl;
57
        // must also get free variables in lambda
58
360
        std::vector<Node> lvars;
59
360
        std::vector<TypeNode> ftypes;
60
360
        std::unordered_set<Node> fvs;
61
180
        expr::getFreeVariables(cur, fvs);
62
360
        std::vector<Node> nvars;
63
360
        std::vector<Node> vars;
64
360
        Node sbd = cur[1];
65
180
        if (!fvs.empty())
66
        {
67
256
          Trace("ho-elim-ll")
68
128
              << "Has " << fvs.size() << " free variables" << std::endl;
69
357
          for (const Node& v : fvs)
70
          {
71
458
            TypeNode vt = v.getType();
72
229
            ftypes.push_back(vt);
73
458
            Node vs = nm->mkBoundVar(vt);
74
229
            vars.push_back(v);
75
229
            nvars.push_back(vs);
76
229
            lvars.push_back(vs);
77
          }
78
128
          sbd = sbd.substitute(
79
              vars.begin(), vars.end(), nvars.begin(), nvars.end());
80
        }
81
391
        for (const Node& bv : cur[0])
82
        {
83
422
          TypeNode bvt = bv.getType();
84
211
          ftypes.push_back(bvt);
85
211
          lvars.push_back(bv);
86
        }
87
360
        Node nlambda = cur;
88
180
        if (!fvs.empty())
89
        {
90
128
          nlambda = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, lvars), sbd);
91
256
          Trace("ho-elim-ll")
92
128
              << "...new lambda definition: " << nlambda << std::endl;
93
        }
94
360
        TypeNode rangeType = cur.getType().getRangeType();
95
360
        TypeNode nft = nm->mkFunctionType(ftypes, rangeType);
96
360
        Node nf = sm->mkDummySkolem("ll", nft);
97
360
        Trace("ho-elim-ll")
98
180
            << "...introduce: " << nf << " of type " << nft << std::endl;
99
180
        newLambda[nf] = nlambda;
100
180
        Assert(nf.getType() == nlambda.getType());
101
180
        if (!vars.empty())
102
        {
103
357
          for (const Node& v : vars)
104
          {
105
229
            nf = nm->mkNode(HO_APPLY, nf, v);
106
          }
107
128
          Trace("ho-elim-ll") << "...partial application: " << nf << std::endl;
108
        }
109
180
        d_visited[cur] = nf;
110
360
        Trace("ho-elim-ll") << "...return types : " << nf.getType() << " "
111
180
                            << cur.getType() << std::endl;
112
180
        Assert(nf.getType() == cur.getType());
113
      }
114
      else
115
      {
116
45305
        d_visited[cur] = Node::null();
117
45305
        visit.push_back(cur);
118
111120
        for (const Node& cn : cur)
119
        {
120
65815
          visit.push_back(cn);
121
        }
122
      }
123
    }
124
70717
    else if (it->second.isNull())
125
    {
126
90610
      Node ret = cur;
127
45305
      bool childChanged = false;
128
90610
      std::vector<Node> children;
129
45305
      if (cur.getMetaKind() == metakind::PARAMETERIZED)
130
      {
131
8376
        children.push_back(cur.getOperator());
132
      }
133
111120
      for (const Node& cn : cur)
134
      {
135
65815
        it = d_visited.find(cn);
136
65815
        Assert(it != d_visited.end());
137
65815
        Assert(!it->second.isNull());
138
65815
        childChanged = childChanged || cn != it->second;
139
65815
        children.push_back(it->second);
140
      }
141
45305
      if (childChanged)
142
      {
143
1489
        ret = nm->mkNode(cur.getKind(), children);
144
      }
145
45305
      d_visited[cur] = ret;
146
    }
147
116202
  } while (!visit.empty());
148
5082
  Assert(d_visited.find(n) != d_visited.end());
149
5082
  Assert(!d_visited.find(n)->second.isNull());
150
10164
  return d_visited[n];
151
}
152
153
4902
Node HoElim::eliminateHo(Node n)
154
{
155
4902
  Trace("ho-elim-assert") << "Ho-elim assertion: " << n << std::endl;
156
4902
  NodeManager* nm = NodeManager::currentNM();
157
4902
  SkolemManager* sm = nm->getSkolemManager();
158
4902
  std::unordered_map<Node, Node>::iterator it;
159
9804
  std::map<Node, Node> preReplace;
160
4902
  std::map<Node, Node>::iterator itr;
161
9804
  std::vector<TNode> visit;
162
9804
  TNode cur;
163
4902
  visit.push_back(n);
164
105795
  do
165
  {
166
110697
    cur = visit.back();
167
110697
    visit.pop_back();
168
110697
    it = d_visited.find(cur);
169
110697
    Trace("ho-elim-visit") << "Process: " << cur << std::endl;
170
171
110697
    if (it == d_visited.end())
172
    {
173
93526
      TypeNode tn = cur.getType();
174
      // lambdas are already eliminated by now
175
46763
      Assert(cur.getKind() != LAMBDA);
176
46763
      if (tn.isFunction())
177
      {
178
7961
        d_funTypes.insert(tn);
179
      }
180
46763
      if (cur.isVar())
181
      {
182
20214
        Node ret = cur;
183
10107
        if (options().quantifiers.hoElim)
184
        {
185
3541
          if (tn.isFunction())
186
          {
187
4178
            TypeNode ut = getUSort(tn);
188
2089
            if (cur.getKind() == BOUND_VARIABLE)
189
            {
190
1721
              ret = nm->mkBoundVar(ut);
191
            }
192
            else
193
            {
194
368
              ret = sm->mkDummySkolem("k", ut);
195
            }
196
            // must get the ho apply to ensure extensionality is applied
197
4178
            Node hoa = getHoApplyUf(tn);
198
2089
            Trace("ho-elim-visit") << "Hoa is " << hoa << std::endl;
199
          }
200
        }
201
10107
        d_visited[cur] = ret;
202
      }
203
      else
204
      {
205
36656
        d_visited[cur] = Node::null();
206
36656
        if (cur.getKind() == APPLY_UF && options().quantifiers.hoElim)
207
        {
208
1916
          Node op = cur.getOperator();
209
          // convert apply uf with variable arguments eagerly to ho apply
210
          // chains, so they are processed uniformly.
211
958
          visit.push_back(cur);
212
1916
          Node newCur = theory::uf::TheoryUfRewriter::getHoApplyForApplyUf(cur);
213
958
          preReplace[cur] = newCur;
214
958
          cur = newCur;
215
958
          d_visited[cur] = Node::null();
216
        }
217
36656
        visit.push_back(cur);
218
104837
        for (const Node& cn : cur)
219
        {
220
68181
          visit.push_back(cn);
221
        }
222
      }
223
    }
224
63934
    else if (it->second.isNull())
225
    {
226
75228
      Node ret = cur;
227
37614
      itr = preReplace.find(cur);
228
37614
      if (itr != preReplace.end())
229
      {
230
1916
        Trace("ho-elim-visit")
231
958
            << "return (pre-repl): " << d_visited[itr->second] << std::endl;
232
958
        d_visited[cur] = d_visited[itr->second];
233
      }
234
      else
235
      {
236
36656
        bool childChanged = false;
237
73312
        std::vector<Node> children;
238
73312
        std::vector<TypeNode> childrent;
239
36656
        bool typeChanged = false;
240
104837
        for (const Node& cn : ret)
241
        {
242
68181
          it = d_visited.find(cn);
243
68181
          Assert(it != d_visited.end());
244
68181
          Assert(!it->second.isNull());
245
68181
          childChanged = childChanged || cn != it->second;
246
68181
          children.push_back(it->second);
247
136362
          TypeNode ct = it->second.getType();
248
68181
          childrent.push_back(ct);
249
68181
          typeChanged = typeChanged || ct != cn.getType();
250
        }
251
36656
        if (ret.getMetaKind() == metakind::PARAMETERIZED)
252
        {
253
          // child of an argument changed type, must change type
254
14836
          Node op = ret.getOperator();
255
14836
          Node retOp = op;
256
14836
          Trace("ho-elim-visit")
257
7418
              << "Process op " << op << ", typeChanged = " << typeChanged
258
7418
              << std::endl;
259
7418
          if (typeChanged)
260
          {
261
            std::unordered_map<TNode, Node>::iterator ito =
262
                d_visited_op.find(op);
263
            if (ito == d_visited_op.end())
264
            {
265
              Assert(!childrent.empty());
266
              TypeNode newFType = nm->mkFunctionType(childrent, cur.getType());
267
              retOp = sm->mkDummySkolem("rf", newFType);
268
              d_visited_op[op] = retOp;
269
            }
270
            else
271
            {
272
              retOp = ito->second;
273
            }
274
          }
275
7418
          children.insert(children.begin(), retOp);
276
        }
277
        // process ho apply
278
36656
        if (ret.getKind() == HO_APPLY && options().quantifiers.hoElim)
279
        {
280
15638
          TypeNode tnr = ret.getType();
281
7819
          tnr = getUSort(tnr);
282
          Node hoa =
283
15638
              getHoApplyUf(children[0].getType(), children[1].getType(), tnr);
284
15638
          std::vector<Node> hchildren;
285
7819
          hchildren.push_back(hoa);
286
7819
          hchildren.push_back(children[0]);
287
7819
          hchildren.push_back(children[1]);
288
7819
          ret = nm->mkNode(APPLY_UF, hchildren);
289
        }
290
28837
        else if (childChanged)
291
        {
292
5253
          ret = nm->mkNode(ret.getKind(), children);
293
        }
294
36656
        Trace("ho-elim-visit") << "return (pre-repl): " << ret << std::endl;
295
36656
        d_visited[cur] = ret;
296
      }
297
    }
298
110697
  } while (!visit.empty());
299
4902
  Assert(d_visited.find(n) != d_visited.end());
300
4902
  Assert(!d_visited.find(n)->second.isNull());
301
4902
  Trace("ho-elim-assert") << "...got : " << d_visited[n] << std::endl;
302
9804
  return d_visited[n];
303
}
304
305
232
PreprocessingPassResult HoElim::applyInternal(
306
    AssertionPipeline* assertionsToPreprocess)
307
{
308
  // step [1]: apply lambda lifting to eliminate all lambdas
309
232
  NodeManager* nm = NodeManager::currentNM();
310
464
  std::vector<Node> axioms;
311
464
  std::map<Node, Node> newLambda;
312
5134
  for (unsigned i = 0, size = assertionsToPreprocess->size(); i < size; ++i)
313
  {
314
9804
    Node prev = (*assertionsToPreprocess)[i];
315
9804
    Node res = eliminateLambdaComplete(prev, newLambda);
316
4902
    if (res != prev)
317
    {
318
316
      res = rewrite(res);
319
316
      Assert(!expr::hasFreeVar(res));
320
316
      assertionsToPreprocess->replace(i, res);
321
    }
322
  }
323
  // do lambda lifting on new lambda definitions
324
  // this will do fixed point to eliminate lambdas within lambda lifting axioms.
325
292
  while (!newLambda.empty())
326
  {
327
60
    std::map<Node, Node> lproc = newLambda;
328
30
    newLambda.clear();
329
210
    for (const std::pair<const Node, Node>& l : lproc)
330
    {
331
360
      Node lambda = l.second;
332
360
      std::vector<Node> vars;
333
360
      std::vector<Node> nvars;
334
620
      for (const Node& v : lambda[0])
335
      {
336
880
        Node bv = nm->mkBoundVar(v.getType());
337
440
        vars.push_back(v);
338
440
        nvars.push_back(bv);
339
      }
340
341
      Node bd = lambda[1].substitute(
342
360
          vars.begin(), vars.end(), nvars.begin(), nvars.end());
343
360
      Node bvl = nm->mkNode(BOUND_VAR_LIST, nvars);
344
345
180
      nvars.insert(nvars.begin(), l.first);
346
360
      Node curr = nm->mkNode(APPLY_UF, nvars);
347
348
360
      Node llfax = nm->mkNode(FORALL, bvl, curr.eqNode(bd));
349
360
      Trace("ho-elim-ax") << "Lambda lifting axiom (pre-elim) " << llfax
350
180
                          << " for " << lambda << std::endl;
351
180
      Assert(!expr::hasFreeVar(llfax));
352
360
      Node llfaxe = eliminateLambdaComplete(llfax, newLambda);
353
360
      Trace("ho-elim-ax") << "Lambda lifting axiom " << llfaxe << " for "
354
180
                          << lambda << std::endl;
355
180
      axioms.push_back(llfaxe);
356
    }
357
  }
358
359
232
  d_visited.clear();
360
  // add lambda lifting axioms as a conjunction to the first assertion
361
232
  if (!axioms.empty())
362
  {
363
52
    Node conj = nm->mkAnd(axioms);
364
26
    conj = rewrite(conj);
365
26
    Assert(!expr::hasFreeVar(conj));
366
26
    assertionsToPreprocess->conjoin(0, conj);
367
  }
368
232
  axioms.clear();
369
370
  // step [2]: eliminate all higher-order constraints
371
5134
  for (unsigned i = 0, size = assertionsToPreprocess->size(); i < size; ++i)
372
  {
373
9804
    Node prev = (*assertionsToPreprocess)[i];
374
9804
    Node res = eliminateHo(prev);
375
4902
    if (res != prev)
376
    {
377
1434
      res = rewrite(res);
378
1434
      Assert(!expr::hasFreeVar(res));
379
1434
      assertionsToPreprocess->replace(i, res);
380
    }
381
  }
382
  // extensionality: process all function types
383
942
  for (const TypeNode& ftn : d_funTypes)
384
  {
385
710
    if (options().quantifiers.hoElim)
386
    {
387
536
      Node h = getHoApplyUf(ftn);
388
268
      Trace("ho-elim-ax") << "Make extensionality for " << h << std::endl;
389
536
      TypeNode ft = h.getType();
390
536
      TypeNode uf = getUSort(ft[0]);
391
536
      TypeNode ut = getUSort(ft[1]);
392
      // extensionality
393
536
      Node x = nm->mkBoundVar("x", uf);
394
536
      Node y = nm->mkBoundVar("y", uf);
395
536
      Node z = nm->mkBoundVar("z", ut);
396
      Node eq =
397
536
          nm->mkNode(APPLY_UF, h, x, z).eqNode(nm->mkNode(APPLY_UF, h, y, z));
398
536
      Node antec = nm->mkNode(FORALL, nm->mkNode(BOUND_VAR_LIST, z), eq);
399
536
      Node conc = x.eqNode(y);
400
      Node ax = nm->mkNode(FORALL,
401
536
                           nm->mkNode(BOUND_VAR_LIST, x, y),
402
1072
                           nm->mkNode(OR, antec.negate(), conc));
403
268
      axioms.push_back(ax);
404
268
      Trace("ho-elim-ax") << "...ext axiom : " << ax << std::endl;
405
      // Make the "store" axiom, which asserts for every function, there
406
      // exists another function that acts like the "store" operator for
407
      // arrays, e.g. it is the same function with one I/O pair updated.
408
      // Without this axiom, the translation is model unsound.
409
268
      if (options().quantifiers.hoElimStoreAx)
410
      {
411
536
        Node u = nm->mkBoundVar("u", uf);
412
536
        Node v = nm->mkBoundVar("v", uf);
413
536
        Node i = nm->mkBoundVar("i", ut);
414
536
        Node ii = nm->mkBoundVar("ii", ut);
415
536
        Node huii = nm->mkNode(APPLY_UF, h, u, ii);
416
536
        Node e = nm->mkBoundVar("e", huii.getType());
417
        Node store = nm->mkNode(
418
            FORALL,
419
536
            nm->mkNode(BOUND_VAR_LIST, u, e, i),
420
1072
            nm->mkNode(EXISTS,
421
536
                       nm->mkNode(BOUND_VAR_LIST, v),
422
1072
                       nm->mkNode(FORALL,
423
536
                                  nm->mkNode(BOUND_VAR_LIST, ii),
424
536
                                  nm->mkNode(APPLY_UF, h, v, ii)
425
1072
                                      .eqNode(nm->mkNode(
426
1608
                                          ITE, ii.eqNode(i), e, huii)))));
427
268
        axioms.push_back(store);
428
268
        Trace("ho-elim-ax") << "...store axiom : " << store << std::endl;
429
      }
430
    }
431
442
    else if (options().quantifiers.hoElimStoreAx)
432
    {
433
2
      Node u = nm->mkBoundVar("u", ftn);
434
2
      Node v = nm->mkBoundVar("v", ftn);
435
2
      std::vector<TypeNode> argTypes = ftn.getArgTypes();
436
2
      Node i = nm->mkBoundVar("i", argTypes[0]);
437
2
      Node ii = nm->mkBoundVar("ii", argTypes[0]);
438
2
      Node huii = nm->mkNode(HO_APPLY, u, ii);
439
2
      Node e = nm->mkBoundVar("e", huii.getType());
440
      Node store = nm->mkNode(
441
          FORALL,
442
2
          nm->mkNode(BOUND_VAR_LIST, u, e, i),
443
4
          nm->mkNode(
444
              EXISTS,
445
2
              nm->mkNode(BOUND_VAR_LIST, v),
446
4
              nm->mkNode(FORALL,
447
2
                         nm->mkNode(BOUND_VAR_LIST, ii),
448
2
                         nm->mkNode(HO_APPLY, v, ii)
449
6
                             .eqNode(nm->mkNode(ITE, ii.eqNode(i), e, huii)))));
450
1
      axioms.push_back(store);
451
2
      Trace("ho-elim-ax") << "...store (ho_apply) axiom : " << store
452
1
                          << std::endl;
453
    }
454
  }
455
  // add new axioms as a conjunction to the first assertion
456
232
  if (!axioms.empty())
457
  {
458
22
    Node conj = nm->mkAnd(axioms);
459
11
    conj = rewrite(conj);
460
11
    Assert(!expr::hasFreeVar(conj));
461
11
    assertionsToPreprocess->conjoin(0, conj);
462
  }
463
464
464
  return PreprocessingPassResult::NO_CONFLICT;
465
}
466
467
2357
Node HoElim::getHoApplyUf(TypeNode tn)
468
{
469
4714
  TypeNode tnu = getUSort(tn);
470
4714
  TypeNode rangeType = tn.getRangeType();
471
4714
  std::vector<TypeNode> argTypes = tn.getArgTypes();
472
4714
  TypeNode tna = getUSort(argTypes[0]);
473
474
4714
  TypeNode tr = rangeType;
475
2357
  if (argTypes.size() > 1)
476
  {
477
1410
    std::vector<TypeNode> remArgTypes;
478
705
    remArgTypes.insert(remArgTypes.end(), argTypes.begin() + 1, argTypes.end());
479
705
    tr = NodeManager::currentNM()->mkFunctionType(remArgTypes, tr);
480
  }
481
4714
  TypeNode tnr = getUSort(tr);
482
483
4714
  return getHoApplyUf(tnu, tna, tnr);
484
}
485
486
10176
Node HoElim::getHoApplyUf(TypeNode tnf, TypeNode tna, TypeNode tnr)
487
{
488
10176
  std::map<TypeNode, Node>::iterator it = d_hoApplyUf.find(tnf);
489
10176
  if (it == d_hoApplyUf.end())
490
  {
491
268
    NodeManager* nm = NodeManager::currentNM();
492
268
    SkolemManager* sm = nm->getSkolemManager();
493
494
536
    std::vector<TypeNode> hoTypeArgs;
495
268
    hoTypeArgs.push_back(tnf);
496
268
    hoTypeArgs.push_back(tna);
497
536
    TypeNode tnh = nm->mkFunctionType(hoTypeArgs, tnr);
498
536
    Node k = sm->mkDummySkolem("ho", tnh);
499
268
    d_hoApplyUf[tnf] = k;
500
268
    return k;
501
  }
502
9908
  return it->second;
503
}
504
505
17970
TypeNode HoElim::getUSort(TypeNode tn)
506
{
507
17970
  if (!tn.isFunction())
508
  {
509
7204
    return tn;
510
  }
511
10766
  std::map<TypeNode, TypeNode>::iterator it = d_ftypeMap.find(tn);
512
10766
  if (it == d_ftypeMap.end())
513
  {
514
    // flatten function arguments
515
890
    std::vector<TypeNode> argTypes = tn.getArgTypes();
516
890
    TypeNode rangeType = tn.getRangeType();
517
445
    bool typeChanged = false;
518
1435
    for (unsigned i = 0; i < argTypes.size(); i++)
519
    {
520
990
      if (argTypes[i].isFunction())
521
      {
522
278
        argTypes[i] = getUSort(argTypes[i]);
523
278
        typeChanged = true;
524
      }
525
    }
526
890
    TypeNode s;
527
445
    if (typeChanged)
528
    {
529
      TypeNode ntn =
530
354
          NodeManager::currentNM()->mkFunctionType(argTypes, rangeType);
531
177
      s = getUSort(ntn);
532
    }
533
    else
534
    {
535
536
      std::stringstream ss;
536
268
      ss << "u_" << tn;
537
268
      s = NodeManager::currentNM()->mkSort(ss.str());
538
    }
539
445
    d_ftypeMap[tn] = s;
540
445
    return s;
541
  }
542
10321
  return it->second;
543
}
544
545
}  // namespace passes
546
}  // namespace preprocessing
547
31137
}  // namespace cvc5