GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/ematching/ho_trigger.cpp Lines: 230 257 89.5 %
Date: 2021-11-07 Branches: 464 1128 41.1 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mathias Preiner, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of higher-order trigger class.
14
 */
15
16
#include "theory/quantifiers/ematching/ho_trigger.h"
17
18
#include <stack>
19
20
#include "theory/quantifiers/ho_term_database.h"
21
#include "theory/quantifiers/instantiate.h"
22
#include "theory/quantifiers/quantifiers_inference_manager.h"
23
#include "theory/quantifiers/quantifiers_registry.h"
24
#include "theory/quantifiers/quantifiers_state.h"
25
#include "theory/quantifiers/term_registry.h"
26
#include "theory/quantifiers/term_util.h"
27
#include "theory/uf/theory_uf_rewriter.h"
28
#include "util/hash.h"
29
30
using namespace cvc5::kind;
31
32
namespace cvc5 {
33
namespace theory {
34
namespace quantifiers {
35
namespace inst {
36
37
36
/**
 * Construct a higher-order trigger for quantified formula q.
 *
 * For every higher-order variable that heads an application in ho_apps, this
 * records the variable (d_ho_var_list), its function type (d_ho_var_types),
 * and the bound variable list NodeManager returns for that function type,
 * both as a node (d_ho_var_bvl) and as a vector of its children
 * (d_ho_var_bvs).
 *
 * @param q the quantified formula this trigger belongs to
 * @param nodes the trigger terms, forwarded to the base Trigger
 * @param ho_apps map from each higher-order variable to its application terms
 */
HigherOrderTrigger::HigherOrderTrigger(
    Env& env,
    QuantifiersState& qs,
    QuantifiersInferenceManager& qim,
    QuantifiersRegistry& qr,
    TermRegistry& tr,
    Node q,
    std::vector<Node>& nodes,
    std::map<Node, std::vector<Node> >& ho_apps)
    : Trigger(env, qs, qim, qr, tr, q, nodes), d_ho_var_apps(ho_apps)
{
  NodeManager* nodeManager = NodeManager::currentNM();
  for (auto& varApps : d_ho_var_apps)
  {
    const Node& hoVar = varApps.first;
    const std::vector<Node>& patterns = varApps.second;
    d_ho_var_list.push_back(hoVar);
    TypeNode varType = hoVar.getType();
    Assert(varType.isFunction());
    if (Trace.isOn("ho-quant-trigger"))
    {
      Trace("ho-quant-trigger") << "  have " << patterns.size();
      Trace("ho-quant-trigger") << " patterns with variable operator " << hoVar
                                << ":" << std::endl;
      for (const Node& pattern : patterns)
      {
        Trace("ho-quant-trigger") << "  " << pattern << std::endl;
      }
    }
    // Report each function type only the first time it is registered.
    if (d_ho_var_types.insert(varType).second)
    {
      Trace("ho-quant-trigger") << "  type " << varType
                                << " needs higher-order matching." << std::endl;
    }
    // Bound variables later used as the formal parameters of the lambda
    // abstractions built for this variable.
    Node boundVars = nodeManager->getBoundVarListForFunctionType(varType);
    d_ho_var_bvl[hoVar] = boundVars;
    std::vector<Node>& boundVarVec = d_ho_var_bvs[hoVar];
    boundVarVec.insert(boundVarVec.end(), boundVars.begin(), boundVars.end());
  }
}
80
81
72
// The destructor does no work of its own; member destructors run as usual.
HigherOrderTrigger::~HigherOrderTrigger() = default;
82
/**
 * Single-term convenience overload of collectHoVarApplyTerms.
 *
 * Runs the vector overload on a one-element list and writes the (possibly
 * converted) term back into n.
 *
 * @param q the quantified formula
 * @param n the term to process; replaced by its converted form
 * @param apps receives the higher-order variable applications found in n
 */
void HigherOrderTrigger::collectHoVarApplyTerms(
    Node q, Node& n, std::map<Node, std::vector<Node> >& apps)
{
  std::vector<Node> terms(1, n);
  collectHoVarApplyTerms(q, terms, apps);
  Assert(terms.size() == 1);
  n = terms.front();
}
91
92
17703
/**
 * Convert the terms in ns and collect applications of higher-order variables.
 *
 * Each term is traversed iteratively in post-order, without descending into
 * nested FORALL terms. Parents are rebuilt when a child was converted. Every
 * APPLY_UF or HO_APPLY whose head is an INST_CONSTANT is recorded in
 * apps[head]. An APPLY_UF with such a head is first converted to its HO_APPLY
 * form via uf::TheoryUfRewriter::getHoApplyForApplyUf.
 *
 * @param q the quantified formula (checked against the inst-constant
 *   attribute of each recorded application)
 * @param ns the terms to process; each entry is replaced by its converted form
 * @param apps map from head variable to its (converted) application terms
 */
void HigherOrderTrigger::collectHoVarApplyTerms(
    Node q, std::vector<Node>& ns, std::map<Node, std::vector<Node> >& apps)
{
  // The maps below are keyed by TNode. ns[i] is overwritten with its converted
  // form at the end of each iteration, which (if the caller holds no other
  // reference) could release a term whose subterms are still keys here. Keep
  // all original terms referenced for the whole traversal.
  const std::vector<Node> originals(ns);
  // visited[t] is null while t awaits its post-visit, and holds t's converted
  // form afterwards.
  std::unordered_map<TNode, Node> visited;
  std::unordered_map<TNode, Node>::iterator it;
  // whether the visited node is a child of a HO_APPLY chain
  // NOTE(review): roots are seeded with false and each child receives
  // `curWithinApply && j == 0`, so this flag is never true and the
  // `!withinApply[cur]` check below always passes. Perhaps
  // `cur.getKind() == HO_APPLY && j == 0` was intended; confirm before
  // changing, since that alters which partial applications are collected.
  std::unordered_map<TNode, bool> withinApply;
  std::vector<TNode> visit;
  TNode cur;
  for (size_t i = 0, size = ns.size(); i < size; i++)
  {
    visit.push_back(ns[i]);
    withinApply[ns[i]] = false;
    do
    {
      cur = visit.back();
      visit.pop_back();

      it = visited.find(cur);
      if (it == visited.end())
      {
        // do not look in nested quantifiers
        if (cur.getKind() == FORALL)
        {
          visited[cur] = cur;
        }
        else
        {
          bool curWithinApply = withinApply[cur];
          visited[cur] = Node::null();
          // re-push cur so that it is post-visited after all of its children
          visit.push_back(cur);
          for (size_t j = 0, sizec = cur.getNumChildren(); j < sizec; j++)
          {
            withinApply[cur[j]] = curWithinApply && j == 0;
            visit.push_back(cur[j]);
          }
        }
      }
      else if (it->second.isNull())
      {
        // carry the conversion
        Node ret = cur;
        bool childChanged = false;
        std::vector<Node> children;
        if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
        {
          children.push_back(cur.getOperator());
        }
        for (const Node& nc : cur)
        {
          it = visited.find(nc);
          Assert(it != visited.end());
          Assert(!it->second.isNull());
          childChanged = childChanged || nc != it->second;
          children.push_back(it->second);
        }
        if (childChanged)
        {
          ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
        }
        // now, convert and store the application
        if (!withinApply[cur])
        {
          TNode op;
          if (ret.getKind() == kind::APPLY_UF)
          {
            // could be a fully applied function variable
            op = ret.getOperator();
          }
          else if (ret.getKind() == kind::HO_APPLY)
          {
            // walk down the curried chain to its head
            op = ret;
            while (op.getKind() == kind::HO_APPLY)
            {
              op = op[0];
            }
          }
          if (!op.isNull() && op.getKind() == kind::INST_CONSTANT)
          {
            Assert(TermUtil::getInstConstAttr(ret) == q);
            Trace("ho-quant-trigger-debug")
                << "Ho variable apply term : " << ret << " with head " << op
                << std::endl;
            if (ret.getKind() == kind::APPLY_UF)
            {
              // for consistency, convert to HO_APPLY if fully applied
              ret = uf::TheoryUfRewriter::getHoApplyForApplyUf(ret);
            }
            apps[op].push_back(ret);
          }
        }
        visited[cur] = ret;
      }
    } while (!visit.empty());

    // store the conversion
    Assert(visited.find(ns[i]) != visited.end());
    ns[i] = visited[ns[i]];
  }
}
196
197
108
uint64_t HigherOrderTrigger::addInstantiations()
198
{
199
  // call the base class implementation
200
108
  uint64_t addedFoLemmas = Trigger::addInstantiations();
201
  // also adds predicate lemms to force app completion
202
108
  uint64_t addedHoLemmas = addHoTypeMatchPredicateLemmas();
203
108
  return addedHoLemmas + addedFoLemmas;
204
}
205
206
192
bool HigherOrderTrigger::sendInstantiation(std::vector<Node>& m, InferenceId id)
207
{
208
192
  if (options().quantifiers.hoMatching)
209
  {
210
    // get substitution corresponding to m
211
384
    std::vector<TNode> vars;
212
384
    std::vector<TNode> subs;
213
684
    for (unsigned i = 0, size = d_quant[0].getNumChildren(); i < size; i++)
214
    {
215
492
      subs.push_back(m[i]);
216
492
      vars.push_back(d_qreg.getInstantiationConstant(d_quant, i));
217
    }
218
219
192
    Trace("ho-unif-debug") << "Run higher-order unification..." << std::endl;
220
221
    // get the substituted form of all variable-operator ho application terms
222
384
    std::map<TNode, std::vector<Node> > ho_var_apps_subs;
223
432
    for (std::pair<const Node, std::vector<Node> >& ha : d_ho_var_apps)
224
    {
225
480
      TNode var = ha.first;
226
480
      for (unsigned j = 0, size = ha.second.size(); j < size; j++)
227
      {
228
480
        TNode app = ha.second[j];
229
        Node sapp =
230
480
            app.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
231
240
        ho_var_apps_subs[var].push_back(sapp);
232
480
        Trace("ho-unif-debug") << "  app[" << var << "] : " << app << " -> "
233
240
                               << sapp << std::endl;
234
      }
235
    }
236
237
    // compute argument vectors for each variable
238
192
    d_lchildren.clear();
239
192
    d_arg_to_arg_rep.clear();
240
192
    d_arg_vector.clear();
241
192
    EntailmentCheck* echeck = d_treg.getEntailmentCheck();
242
432
    for (std::pair<const TNode, std::vector<Node> >& ha : ho_var_apps_subs)
243
    {
244
480
      TNode var = ha.first;
245
240
      unsigned vnum = var.getAttribute(InstVarNumAttribute());
246
480
      TNode value = m[vnum];
247
240
      Trace("ho-unif-debug") << "  val[" << var << "] = " << value << std::endl;
248
249
480
      Trace("ho-unif-debug2") << "initialize lambda information..."
250
240
                              << std::endl;
251
      // initialize the lambda children
252
240
      d_lchildren[vnum].push_back(value);
253
      std::map<TNode, std::vector<Node> >::iterator ithb =
254
240
          d_ho_var_bvs.find(var);
255
240
      Assert(ithb != d_ho_var_bvs.end());
256
480
      d_lchildren[vnum].insert(
257
480
          d_lchildren[vnum].end(), ithb->second.begin(), ithb->second.end());
258
259
240
      Trace("ho-unif-debug2") << "compute fixed arguments..." << std::endl;
260
      // compute for each argument if it is only applied to a fixed value modulo
261
      // equality
262
480
      std::map<unsigned, Node> fixed_vals;
263
480
      for (unsigned i = 0; i < ha.second.size(); i++)
264
      {
265
480
        std::vector<TNode> args;
266
        // must substitute the operator we matched with the original
267
        // higher-order variable (var) that matched it. This ensures that the
268
        // argument vector (args) below is of the proper length. This handles,
269
        // for example, matches like:
270
        //   (@ x y) with (@ (@ k1 k2) k3)
271
        // where k3 but not k2 should be an argument of the match.
272
480
        Node hmatch = ha.second[i];
273
240
        Trace("ho-unif-debug2") << "Match is " << hmatch << std::endl;
274
240
        hmatch = hmatch.substitute(value, var);
275
240
        Trace("ho-unif-debug2") << "Pre-subs match is " << hmatch << std::endl;
276
480
        Node f = uf::TheoryUfRewriter::decomposeHoApply(hmatch, args);
277
        // Assert( f==value );
278
616
        for (unsigned k = 0, size = args.size(); k < size; k++)
279
        {
280
          // must now subsitute back, to handle cases like
281
          // (@ x y) matching (@ t (@ t s))
282
          // where the above substitution would produce (@ x (@ x s)),
283
          // but the argument should be (@ t s).
284
376
          args[k] = args[k].substitute(var, value);
285
752
          Node val = args[k];
286
376
          std::map<unsigned, Node>::iterator itf = fixed_vals.find(k);
287
376
          if (itf == fixed_vals.end())
288
          {
289
376
            fixed_vals[k] = val;
290
          }
291
          else if (!itf->second.isNull())
292
          {
293
            if (!d_qstate.areEqual(itf->second, args[k]))
294
            {
295
              if (!echeck->isEntailed(itf->second.eqNode(args[k]), true))
296
              {
297
                fixed_vals[k] = Node::null();
298
              }
299
            }
300
          }
301
        }
302
      }
303
240
      if (Trace.isOn("ho-unif-debug"))
304
      {
305
        for (std::map<unsigned, Node>::iterator itf = fixed_vals.begin();
306
             itf != fixed_vals.end();
307
             ++itf)
308
        {
309
          Trace("ho-unif-debug") << "  arg[" << var << "][" << itf->first
310
                                 << "] : " << itf->second << std::endl;
311
        }
312
      }
313
314
      // now construct argument vectors
315
240
      Trace("ho-unif-debug2") << "compute argument vectors..." << std::endl;
316
480
      std::map<Node, unsigned> arg_to_rep;
317
640
      for (unsigned index = 0, size = ithb->second.size(); index < size;
318
           index++)
319
      {
320
800
        Node bv_at_index = ithb->second[index];
321
400
        std::map<unsigned, Node>::iterator itf = fixed_vals.find(index);
322
400
        Trace("ho-unif-debug") << "  * arg[" << var << "][" << index << "]";
323
400
        if (itf != fixed_vals.end())
324
        {
325
376
          if (!itf->second.isNull())
326
          {
327
752
            Node r = d_qstate.getRepresentative(itf->second);
328
376
            std::map<Node, unsigned>::iterator itfr = arg_to_rep.find(r);
329
376
            if (itfr != arg_to_rep.end())
330
            {
331
52
              d_arg_to_arg_rep[vnum][index] = itfr->second;
332
              // function applied to equivalent values at multiple arguments,
333
              // can permute variables
334
52
              d_arg_vector[vnum][itfr->second].push_back(bv_at_index);
335
104
              Trace("ho-unif-debug") << " = { self } ++ arg[" << var << "]["
336
52
                                     << itfr->second << "]" << std::endl;
337
            }
338
            else
339
            {
340
324
              arg_to_rep[r] = index;
341
              // function applied to single value, can either use variable or
342
              // value at this argument position
343
324
              d_arg_vector[vnum][index].push_back(bv_at_index);
344
324
              d_arg_vector[vnum][index].push_back(itf->second);
345
324
              if (!options().quantifiers.hoMatchingVarArgPriority)
346
              {
347
                std::reverse(d_arg_vector[vnum][index].begin(),
348
                             d_arg_vector[vnum][index].end());
349
              }
350
648
              Trace("ho-unif-debug") << " = { self, " << itf->second << " } "
351
324
                                     << std::endl;
352
            }
353
          }
354
          else
355
          {
356
            // function is applied to disequal values, can only use variable at
357
            // this argument position
358
            d_arg_vector[vnum][index].push_back(bv_at_index);
359
            Trace("ho-unif-debug") << " = { self } (disequal)" << std::endl;
360
          }
361
        }
362
        else
363
        {
364
          // argument is irrelevant to matching, assume identity variable
365
24
          d_arg_vector[vnum][index].push_back(bv_at_index);
366
24
          Trace("ho-unif-debug") << " = { self } (irrelevant)" << std::endl;
367
        }
368
      }
369
240
      Trace("ho-unif-debug2") << "finished." << std::endl;
370
    }
371
372
192
    bool ret = sendInstantiation(m, 0);
373
192
    Trace("ho-unif-debug") << "Finished, success = " << ret << std::endl;
374
192
    return ret;
375
  }
376
  else
377
  {
378
    // do not run higher-order matching
379
    return d_qim.getInstantiate()->addInstantiation(d_quant, m, id);
380
  }
381
}
382
383
// recursion depth limited by number of arguments of higher order variables
384
// occurring as pattern operators (very small)
385
906
bool HigherOrderTrigger::sendInstantiation(std::vector<Node>& m,
386
                                           size_t var_index)
387
{
388
1812
  Trace("ho-unif-debug2") << "send inst " << var_index << " / "
389
906
                          << d_ho_var_list.size() << std::endl;
390
906
  if (var_index == d_ho_var_list.size())
391
  {
392
    // we now have an instantiation to try
393
1204
    return d_qim.getInstantiate()->addInstantiation(
394
602
        d_quant, m, InferenceId::QUANTIFIERS_INST_E_MATCHING_HO);
395
  }
396
  else
397
  {
398
608
    Node var = d_ho_var_list[var_index];
399
304
    unsigned vnum = var.getAttribute(InstVarNumAttribute());
400
304
    Assert(vnum < m.size());
401
608
    Node value = m[vnum];
402
304
    Assert(d_lchildren[vnum][0] == value);
403
304
    Assert(d_ho_var_bvl.find(var) != d_ho_var_bvl.end());
404
405
    // now, recurse on arguments to enumerate equivalent matching lambda
406
    // expressions
407
    bool ret =
408
304
        sendInstantiationArg(m, var_index, vnum, 0, d_ho_var_bvl[var], false);
409
410
    // reset the value
411
304
    m[vnum] = value;
412
413
304
    return ret;
414
  }
415
}
416
417
1272
bool HigherOrderTrigger::sendInstantiationArg(std::vector<Node>& m,
418
                                              unsigned var_index,
419
                                              unsigned vnum,
420
                                              unsigned arg_index,
421
                                              Node lbvl,
422
                                              bool arg_changed)
423
{
424
2544
  Trace("ho-unif-debug2") << "send inst arg " << arg_index << " / "
425
1272
                          << lbvl.getNumChildren() << std::endl;
426
1272
  if (arg_index == lbvl.getNumChildren())
427
  {
428
    // construct the lambda
429
714
    if (arg_changed)
430
    {
431
876
      Trace("ho-unif-debug2")
432
438
          << "  make lambda from children: " << d_lchildren[vnum] << std::endl;
433
      Node body =
434
876
          NodeManager::currentNM()->mkNode(kind::APPLY_UF, d_lchildren[vnum]);
435
438
      Trace("ho-unif-debug2") << "  got " << body << std::endl;
436
876
      Node lam = NodeManager::currentNM()->mkNode(kind::LAMBDA, lbvl, body);
437
438
      m[vnum] = lam;
438
438
      Trace("ho-unif-debug2") << "  try " << vnum << " -> " << lam << std::endl;
439
    }
440
714
    return sendInstantiation(m, var_index + 1);
441
  }
442
  else
443
  {
444
    std::map<unsigned, unsigned>::iterator itr =
445
558
        d_arg_to_arg_rep[vnum].find(arg_index);
446
    unsigned rindex =
447
558
        itr != d_arg_to_arg_rep[vnum].end() ? itr->second : arg_index;
448
    std::map<unsigned, std::vector<Node> >::iterator itv =
449
558
        d_arg_vector[vnum].find(rindex);
450
558
    Assert(itv != d_arg_vector[vnum].end());
451
1116
    Node prev = lbvl[arg_index];
452
558
    bool ret = false;
453
    // try each argument in the vector
454
1196
    for (unsigned i = 0, size = itv->second.size(); i < size; i++)
455
    {
456
968
      bool new_arg_changed = arg_changed || prev != itv->second[i];
457
968
      d_lchildren[vnum][arg_index + 1] = itv->second[i];
458
968
      if (sendInstantiationArg(
459
              m, var_index, vnum, arg_index + 1, lbvl, new_arg_changed))
460
      {
461
330
        ret = true;
462
330
        break;
463
      }
464
    }
465
    // clean up
466
558
    d_lchildren[vnum][arg_index + 1] = prev;
467
558
    return ret;
468
  }
469
}
470
471
108
uint64_t HigherOrderTrigger::addHoTypeMatchPredicateLemmas()
472
{
473
108
  if (d_ho_var_types.empty())
474
  {
475
    return 0;
476
  }
477
108
  Trace("ho-quant-trigger") << "addHoTypeMatchPredicateLemmas..." << std::endl;
478
108
  uint64_t numLemmas = 0;
479
  // this forces expansion of APPLY_UF terms to curried HO_APPLY chains
480
108
  TermDb* tdb = d_treg.getTermDatabase();
481
108
  unsigned size = tdb->getNumOperators();
482
108
  NodeManager* nm = NodeManager::currentNM();
483
768
  for (unsigned j = 0; j < size; j++)
484
  {
485
1320
    Node f = tdb->getOperator(j);
486
660
    if (f.isVar())
487
    {
488
1024
      TypeNode tn = f.getType();
489
512
      if (tn.isFunction())
490
      {
491
1024
        std::vector<TypeNode> argTypes = tn.getArgTypes();
492
512
        Assert(argTypes.size() > 0);
493
1024
        TypeNode range = tn.getRangeType();
494
        // for each function type suffix of the type of f, for example if
495
        // f : (Int -> (Int -> Int))
496
        // we iterate with stn = (Int -> (Int -> Int)) and (Int -> Int)
497
1112
        for (unsigned a = 0, arg_size = argTypes.size(); a < arg_size; a++)
498
        {
499
1200
          std::vector<TypeNode> sargts;
500
600
          sargts.insert(sargts.begin(), argTypes.begin() + a, argTypes.end());
501
600
          Assert(sargts.size() > 0);
502
1200
          TypeNode stn = nm->mkFunctionType(sargts, range);
503
1200
          Trace("ho-quant-trigger-debug")
504
600
              << "For " << f << ", check " << stn << "..." << std::endl;
505
          // if a variable of this type occurs in this trigger
506
600
          if (d_ho_var_types.find(stn) != d_ho_var_types.end())
507
          {
508
536
            Node u = HoTermDb::getHoTypeMatchPredicate(tn);
509
536
            Node au = nm->mkNode(kind::APPLY_UF, u, f);
510
268
            if (d_qim.addPendingLemma(au,
511
                                      InferenceId::QUANTIFIERS_HO_MATCH_PRED))
512
            {
513
              // this forces f to be a first-class member of the quantifier-free
514
              // equality engine, which in turn forces the quantifier-free
515
              // theory solver to expand it to an HO_APPLY chain.
516
116
              Trace("ho-quant")
517
58
                  << "Added ho match predicate lemma : " << au << std::endl;
518
58
              numLemmas++;
519
            }
520
          }
521
        }
522
      }
523
    }
524
  }
525
526
108
  return numLemmas;
527
}
528
529
}  // namespace inst
530
}  // namespace quantifiers
531
}  // namespace theory
532
31137
}  // namespace cvc5