GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/ematching/ho_trigger.cpp Lines: 229 257 89.1 %
Date: 2021-09-16 Branches: 463 1128 41.0 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mathias Preiner, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of higher-order trigger class.
14
 */
15
16
#include "theory/quantifiers/ematching/ho_trigger.h"
17
18
#include <stack>
19
20
#include "theory/quantifiers/ho_term_database.h"
21
#include "theory/quantifiers/instantiate.h"
22
#include "theory/quantifiers/quantifiers_inference_manager.h"
23
#include "theory/quantifiers/quantifiers_registry.h"
24
#include "theory/quantifiers/quantifiers_state.h"
25
#include "theory/quantifiers/term_registry.h"
26
#include "theory/quantifiers/term_util.h"
27
#include "theory/uf/theory_uf_rewriter.h"
28
#include "util/hash.h"
29
30
using namespace cvc5::kind;
31
32
namespace cvc5 {
33
namespace theory {
34
namespace quantifiers {
35
namespace inst {
36
37
36
HigherOrderTrigger::HigherOrderTrigger(
    QuantifiersState& qs,
    QuantifiersInferenceManager& qim,
    QuantifiersRegistry& qr,
    TermRegistry& tr,
    Node q,
    std::vector<Node>& nodes,
    std::map<Node, std::vector<Node> >& ho_apps)
    : Trigger(qs, qim, qr, tr, q, nodes), d_ho_var_apps(ho_apps)
{
  // For every higher-order variable that heads an application in this
  // trigger: remember the variable, remember its (function) type, and build
  // the bound variable list for that type (stored both as a list node and as
  // a vector of its children).
  NodeManager* nm = NodeManager::currentNM();
  for (auto& entry : d_ho_var_apps)
  {
    const Node& hoVar = entry.first;
    const std::vector<Node>& patterns = entry.second;
    d_ho_var_list.push_back(hoVar);
    TypeNode ftype = hoVar.getType();
    Assert(ftype.isFunction());
    if (Trace.isOn("ho-quant-trigger"))
    {
      Trace("ho-quant-trigger") << "  have " << patterns.size();
      Trace("ho-quant-trigger") << " patterns with variable operator " << hoVar
                                << ":" << std::endl;
      for (const Node& pat : patterns)
      {
        Trace("ho-quant-trigger") << "  " << pat << std::endl;
      }
    }
    if (d_ho_var_types.insert(ftype).second)
    {
      // first variable of this type in the trigger
      Trace("ho-quant-trigger") << "  type " << ftype
                                << " needs higher-order matching." << std::endl;
    }
    // make the bound variable lists
    Node bvl = nm->getBoundVarListForFunctionType(ftype);
    d_ho_var_bvl[hoVar] = bvl;
    for (const Node& bv : bvl)
    {
      d_ho_var_bvs[hoVar].push_back(bv);
    }
  }
}
79
80
72
// Nothing to release explicitly; all members clean up after themselves.
HigherOrderTrigger::~HigherOrderTrigger() = default;
81
void HigherOrderTrigger::collectHoVarApplyTerms(
    Node q, Node& n, std::map<Node, std::vector<Node> >& apps)
{
  // Single-term convenience overload: run the vector overload on a batch of
  // one, then hand the converted term back through the reference.
  std::vector<Node> batch{n};
  collectHoVarApplyTerms(q, batch, apps);
  Assert(batch.size() == 1);
  n = batch.front();
}
90
91
17035
void HigherOrderTrigger::collectHoVarApplyTerms(
    Node q, std::vector<Node>& ns, std::map<Node, std::vector<Node> >& apps)
{
  // Iterative post-order traversal of each term in ns (nested FORALL terms are
  // not entered). On the way down a node is marked as pending (null entry in
  // visited) and re-pushed so it is revisited after its children. On the way
  // up it is rebuilt from its converted children. Applications whose head is
  // an instantiation constant are recorded in apps[head], with APPLY_UF terms
  // converted to HO_APPLY form, and the converted term is cached in visited.
  // Finally each ns[i] is replaced by its converted form.

  // Converted forms of the terms in ns. They are written back only after all
  // terms are processed. visited and withinApply are keyed by TNode (which
  // does not hold a reference), so the original terms in ns must keep their
  // subterms alive for the whole traversal.
  std::vector<Node> converted;
  converted.reserve(ns.size());
  std::unordered_map<TNode, Node> visited;
  std::unordered_map<TNode, Node>::iterator it;
  // whether the visited node is a child of a HO_APPLY chain
  // NOTE(review): roots are set to false and every child is assigned
  // `curWithinApply && j == 0`, so every entry is false and the
  // `!withinApply[cur]` test below always succeeds. If the intent was
  // `cur.getKind() == HO_APPLY && j == 0`, be aware that enabling it would
  // also skip the APPLY_UF -> HO_APPLY conversion of nested applications
  // (that conversion only happens inside that branch); confirm before
  // changing.
  std::unordered_map<TNode, bool> withinApply;
  std::vector<TNode> visit;
  TNode cur;
  for (unsigned i = 0, size = ns.size(); i < size; i++)
  {
    visit.push_back(ns[i]);
    withinApply[ns[i]] = false;
    do
    {
      cur = visit.back();
      visit.pop_back();

      it = visited.find(cur);
      if (it == visited.end())
      {
        // do not look in nested quantifiers
        if (cur.getKind() == FORALL)
        {
          visited[cur] = cur;
        }
        else
        {
          bool curWithinApply = withinApply[cur];
          // pending: revisit cur after all of its children
          visited[cur] = Node::null();
          visit.push_back(cur);
          for (unsigned j = 0, sizec = cur.getNumChildren(); j < sizec; j++)
          {
            withinApply[cur[j]] = curWithinApply && j == 0;
            visit.push_back(cur[j]);
          }
        }
      }
      else if (it->second.isNull())
      {
        // carry the conversion
        Node ret = cur;
        bool childChanged = false;
        std::vector<Node> children;
        if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
        {
          children.push_back(cur.getOperator());
        }
        for (const Node& nc : cur)
        {
          it = visited.find(nc);
          Assert(it != visited.end());
          Assert(!it->second.isNull());
          childChanged = childChanged || nc != it->second;
          children.push_back(it->second);
        }
        if (childChanged)
        {
          ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
        }
        // now, convert and store the application
        if (!withinApply[cur])
        {
          TNode op;
          if (ret.getKind() == kind::APPLY_UF)
          {
            // could be a fully applied function variable
            op = ret.getOperator();
          }
          else if (ret.getKind() == kind::HO_APPLY)
          {
            // the head of a HO_APPLY chain is its innermost first child
            op = ret;
            while (op.getKind() == kind::HO_APPLY)
            {
              op = op[0];
            }
          }
          if (!op.isNull())
          {
            if (op.getKind() == kind::INST_CONSTANT)
            {
              Assert(TermUtil::getInstConstAttr(ret) == q);
              Trace("ho-quant-trigger-debug")
                  << "Ho variable apply term : " << ret << " with head " << op
                  << std::endl;
              if (ret.getKind() == kind::APPLY_UF)
              {
                // for consistency, convert to HO_APPLY if fully applied
                ret = uf::TheoryUfRewriter::getHoApplyForApplyUf(ret);
              }
              apps[op].push_back(ret);
            }
          }
        }
        visited[cur] = ret;
      }
    } while (!visit.empty());

    // record the conversion (written back to ns after the loop)
    Assert(visited.find(ns[i]) != visited.end());
    converted.push_back(visited[ns[i]]);
  }
  // store the conversions
  for (size_t i = 0, size = converted.size(); i < size; i++)
  {
    ns[i] = converted[i];
  }
}
195
196
108
uint64_t HigherOrderTrigger::addInstantiations()
197
{
198
  // call the base class implementation
199
108
  uint64_t addedFoLemmas = Trigger::addInstantiations();
200
  // also adds predicate lemms to force app completion
201
108
  uint64_t addedHoLemmas = addHoTypeMatchPredicateLemmas();
202
108
  return addedHoLemmas + addedFoLemmas;
203
}
204
205
192
bool HigherOrderTrigger::sendInstantiation(std::vector<Node>& m, InferenceId id)
206
{
207
192
  if (options::hoMatching())
208
  {
209
    // get substitution corresponding to m
210
384
    std::vector<TNode> vars;
211
384
    std::vector<TNode> subs;
212
684
    for (unsigned i = 0, size = d_quant[0].getNumChildren(); i < size; i++)
213
    {
214
492
      subs.push_back(m[i]);
215
492
      vars.push_back(d_qreg.getInstantiationConstant(d_quant, i));
216
    }
217
218
192
    Trace("ho-unif-debug") << "Run higher-order unification..." << std::endl;
219
220
    // get the substituted form of all variable-operator ho application terms
221
384
    std::map<TNode, std::vector<Node> > ho_var_apps_subs;
222
432
    for (std::pair<const Node, std::vector<Node> >& ha : d_ho_var_apps)
223
    {
224
480
      TNode var = ha.first;
225
480
      for (unsigned j = 0, size = ha.second.size(); j < size; j++)
226
      {
227
480
        TNode app = ha.second[j];
228
        Node sapp =
229
480
            app.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
230
240
        ho_var_apps_subs[var].push_back(sapp);
231
480
        Trace("ho-unif-debug") << "  app[" << var << "] : " << app << " -> "
232
240
                               << sapp << std::endl;
233
      }
234
    }
235
236
    // compute argument vectors for each variable
237
192
    d_lchildren.clear();
238
192
    d_arg_to_arg_rep.clear();
239
192
    d_arg_vector.clear();
240
432
    for (std::pair<const TNode, std::vector<Node> >& ha : ho_var_apps_subs)
241
    {
242
480
      TNode var = ha.first;
243
240
      unsigned vnum = var.getAttribute(InstVarNumAttribute());
244
480
      TNode value = m[vnum];
245
240
      Trace("ho-unif-debug") << "  val[" << var << "] = " << value << std::endl;
246
247
480
      Trace("ho-unif-debug2") << "initialize lambda information..."
248
240
                              << std::endl;
249
      // initialize the lambda children
250
240
      d_lchildren[vnum].push_back(value);
251
      std::map<TNode, std::vector<Node> >::iterator ithb =
252
240
          d_ho_var_bvs.find(var);
253
240
      Assert(ithb != d_ho_var_bvs.end());
254
480
      d_lchildren[vnum].insert(
255
480
          d_lchildren[vnum].end(), ithb->second.begin(), ithb->second.end());
256
257
240
      Trace("ho-unif-debug2") << "compute fixed arguments..." << std::endl;
258
      // compute for each argument if it is only applied to a fixed value modulo
259
      // equality
260
480
      std::map<unsigned, Node> fixed_vals;
261
480
      for (unsigned i = 0; i < ha.second.size(); i++)
262
      {
263
480
        std::vector<TNode> args;
264
        // must substitute the operator we matched with the original
265
        // higher-order variable (var) that matched it. This ensures that the
266
        // argument vector (args) below is of the proper length. This handles,
267
        // for example, matches like:
268
        //   (@ x y) with (@ (@ k1 k2) k3)
269
        // where k3 but not k2 should be an argument of the match.
270
480
        Node hmatch = ha.second[i];
271
240
        Trace("ho-unif-debug2") << "Match is " << hmatch << std::endl;
272
240
        hmatch = hmatch.substitute(value, var);
273
240
        Trace("ho-unif-debug2") << "Pre-subs match is " << hmatch << std::endl;
274
480
        Node f = uf::TheoryUfRewriter::decomposeHoApply(hmatch, args);
275
        // Assert( f==value );
276
616
        for (unsigned k = 0, size = args.size(); k < size; k++)
277
        {
278
          // must now subsitute back, to handle cases like
279
          // (@ x y) matching (@ t (@ t s))
280
          // where the above substitution would produce (@ x (@ x s)),
281
          // but the argument should be (@ t s).
282
376
          args[k] = args[k].substitute(var, value);
283
752
          Node val = args[k];
284
376
          std::map<unsigned, Node>::iterator itf = fixed_vals.find(k);
285
376
          if (itf == fixed_vals.end())
286
          {
287
376
            fixed_vals[k] = val;
288
          }
289
          else if (!itf->second.isNull())
290
          {
291
            if (!d_qstate.areEqual(itf->second, args[k]))
292
            {
293
              if (!d_treg.getTermDatabase()->isEntailed(
294
                      itf->second.eqNode(args[k]), true))
295
              {
296
                fixed_vals[k] = Node::null();
297
              }
298
            }
299
          }
300
        }
301
      }
302
240
      if (Trace.isOn("ho-unif-debug"))
303
      {
304
        for (std::map<unsigned, Node>::iterator itf = fixed_vals.begin();
305
             itf != fixed_vals.end();
306
             ++itf)
307
        {
308
          Trace("ho-unif-debug") << "  arg[" << var << "][" << itf->first
309
                                 << "] : " << itf->second << std::endl;
310
        }
311
      }
312
313
      // now construct argument vectors
314
240
      Trace("ho-unif-debug2") << "compute argument vectors..." << std::endl;
315
480
      std::map<Node, unsigned> arg_to_rep;
316
640
      for (unsigned index = 0, size = ithb->second.size(); index < size;
317
           index++)
318
      {
319
800
        Node bv_at_index = ithb->second[index];
320
400
        std::map<unsigned, Node>::iterator itf = fixed_vals.find(index);
321
400
        Trace("ho-unif-debug") << "  * arg[" << var << "][" << index << "]";
322
400
        if (itf != fixed_vals.end())
323
        {
324
376
          if (!itf->second.isNull())
325
          {
326
752
            Node r = d_qstate.getRepresentative(itf->second);
327
376
            std::map<Node, unsigned>::iterator itfr = arg_to_rep.find(r);
328
376
            if (itfr != arg_to_rep.end())
329
            {
330
52
              d_arg_to_arg_rep[vnum][index] = itfr->second;
331
              // function applied to equivalent values at multiple arguments,
332
              // can permute variables
333
52
              d_arg_vector[vnum][itfr->second].push_back(bv_at_index);
334
104
              Trace("ho-unif-debug") << " = { self } ++ arg[" << var << "]["
335
52
                                     << itfr->second << "]" << std::endl;
336
            }
337
            else
338
            {
339
324
              arg_to_rep[r] = index;
340
              // function applied to single value, can either use variable or
341
              // value at this argument position
342
324
              d_arg_vector[vnum][index].push_back(bv_at_index);
343
324
              d_arg_vector[vnum][index].push_back(itf->second);
344
324
              if (!options::hoMatchingVarArgPriority())
345
              {
346
                std::reverse(d_arg_vector[vnum][index].begin(),
347
                             d_arg_vector[vnum][index].end());
348
              }
349
648
              Trace("ho-unif-debug") << " = { self, " << itf->second << " } "
350
324
                                     << std::endl;
351
            }
352
          }
353
          else
354
          {
355
            // function is applied to disequal values, can only use variable at
356
            // this argument position
357
            d_arg_vector[vnum][index].push_back(bv_at_index);
358
            Trace("ho-unif-debug") << " = { self } (disequal)" << std::endl;
359
          }
360
        }
361
        else
362
        {
363
          // argument is irrelevant to matching, assume identity variable
364
24
          d_arg_vector[vnum][index].push_back(bv_at_index);
365
24
          Trace("ho-unif-debug") << " = { self } (irrelevant)" << std::endl;
366
        }
367
      }
368
240
      Trace("ho-unif-debug2") << "finished." << std::endl;
369
    }
370
371
192
    bool ret = sendInstantiation(m, 0);
372
192
    Trace("ho-unif-debug") << "Finished, success = " << ret << std::endl;
373
192
    return ret;
374
  }
375
  else
376
  {
377
    // do not run higher-order matching
378
    return d_qim.getInstantiate()->addInstantiation(d_quant, m, id);
379
  }
380
}
381
382
// recursion depth limited by number of arguments of higher order variables
383
// occurring as pattern operators (very small)
384
906
bool HigherOrderTrigger::sendInstantiation(std::vector<Node>& m,
385
                                           size_t var_index)
386
{
387
1812
  Trace("ho-unif-debug2") << "send inst " << var_index << " / "
388
906
                          << d_ho_var_list.size() << std::endl;
389
906
  if (var_index == d_ho_var_list.size())
390
  {
391
    // we now have an instantiation to try
392
1204
    return d_qim.getInstantiate()->addInstantiation(
393
602
        d_quant, m, InferenceId::QUANTIFIERS_INST_E_MATCHING_HO);
394
  }
395
  else
396
  {
397
608
    Node var = d_ho_var_list[var_index];
398
304
    unsigned vnum = var.getAttribute(InstVarNumAttribute());
399
304
    Assert(vnum < m.size());
400
608
    Node value = m[vnum];
401
304
    Assert(d_lchildren[vnum][0] == value);
402
304
    Assert(d_ho_var_bvl.find(var) != d_ho_var_bvl.end());
403
404
    // now, recurse on arguments to enumerate equivalent matching lambda
405
    // expressions
406
    bool ret =
407
304
        sendInstantiationArg(m, var_index, vnum, 0, d_ho_var_bvl[var], false);
408
409
    // reset the value
410
304
    m[vnum] = value;
411
412
304
    return ret;
413
  }
414
}
415
416
1272
bool HigherOrderTrigger::sendInstantiationArg(std::vector<Node>& m,
417
                                              unsigned var_index,
418
                                              unsigned vnum,
419
                                              unsigned arg_index,
420
                                              Node lbvl,
421
                                              bool arg_changed)
422
{
423
2544
  Trace("ho-unif-debug2") << "send inst arg " << arg_index << " / "
424
1272
                          << lbvl.getNumChildren() << std::endl;
425
1272
  if (arg_index == lbvl.getNumChildren())
426
  {
427
    // construct the lambda
428
714
    if (arg_changed)
429
    {
430
876
      Trace("ho-unif-debug2")
431
438
          << "  make lambda from children: " << d_lchildren[vnum] << std::endl;
432
      Node body =
433
876
          NodeManager::currentNM()->mkNode(kind::APPLY_UF, d_lchildren[vnum]);
434
438
      Trace("ho-unif-debug2") << "  got " << body << std::endl;
435
876
      Node lam = NodeManager::currentNM()->mkNode(kind::LAMBDA, lbvl, body);
436
438
      m[vnum] = lam;
437
438
      Trace("ho-unif-debug2") << "  try " << vnum << " -> " << lam << std::endl;
438
    }
439
714
    return sendInstantiation(m, var_index + 1);
440
  }
441
  else
442
  {
443
    std::map<unsigned, unsigned>::iterator itr =
444
558
        d_arg_to_arg_rep[vnum].find(arg_index);
445
    unsigned rindex =
446
558
        itr != d_arg_to_arg_rep[vnum].end() ? itr->second : arg_index;
447
    std::map<unsigned, std::vector<Node> >::iterator itv =
448
558
        d_arg_vector[vnum].find(rindex);
449
558
    Assert(itv != d_arg_vector[vnum].end());
450
1116
    Node prev = lbvl[arg_index];
451
558
    bool ret = false;
452
    // try each argument in the vector
453
1196
    for (unsigned i = 0, size = itv->second.size(); i < size; i++)
454
    {
455
968
      bool new_arg_changed = arg_changed || prev != itv->second[i];
456
968
      d_lchildren[vnum][arg_index + 1] = itv->second[i];
457
968
      if (sendInstantiationArg(
458
              m, var_index, vnum, arg_index + 1, lbvl, new_arg_changed))
459
      {
460
330
        ret = true;
461
330
        break;
462
      }
463
    }
464
    // clean up
465
558
    d_lchildren[vnum][arg_index + 1] = prev;
466
558
    return ret;
467
  }
468
}
469
470
108
uint64_t HigherOrderTrigger::addHoTypeMatchPredicateLemmas()
471
{
472
108
  if (d_ho_var_types.empty())
473
  {
474
    return 0;
475
  }
476
108
  Trace("ho-quant-trigger") << "addHoTypeMatchPredicateLemmas..." << std::endl;
477
108
  uint64_t numLemmas = 0;
478
  // this forces expansion of APPLY_UF terms to curried HO_APPLY chains
479
108
  TermDb* tdb = d_treg.getTermDatabase();
480
108
  unsigned size = tdb->getNumOperators();
481
108
  NodeManager* nm = NodeManager::currentNM();
482
768
  for (unsigned j = 0; j < size; j++)
483
  {
484
1320
    Node f = tdb->getOperator(j);
485
660
    if (f.isVar())
486
    {
487
1024
      TypeNode tn = f.getType();
488
512
      if (tn.isFunction())
489
      {
490
1024
        std::vector<TypeNode> argTypes = tn.getArgTypes();
491
512
        Assert(argTypes.size() > 0);
492
1024
        TypeNode range = tn.getRangeType();
493
        // for each function type suffix of the type of f, for example if
494
        // f : (Int -> (Int -> Int))
495
        // we iterate with stn = (Int -> (Int -> Int)) and (Int -> Int)
496
1112
        for (unsigned a = 0, arg_size = argTypes.size(); a < arg_size; a++)
497
        {
498
1200
          std::vector<TypeNode> sargts;
499
600
          sargts.insert(sargts.begin(), argTypes.begin() + a, argTypes.end());
500
600
          Assert(sargts.size() > 0);
501
1200
          TypeNode stn = nm->mkFunctionType(sargts, range);
502
1200
          Trace("ho-quant-trigger-debug")
503
600
              << "For " << f << ", check " << stn << "..." << std::endl;
504
          // if a variable of this type occurs in this trigger
505
600
          if (d_ho_var_types.find(stn) != d_ho_var_types.end())
506
          {
507
536
            Node u = HoTermDb::getHoTypeMatchPredicate(tn);
508
536
            Node au = nm->mkNode(kind::APPLY_UF, u, f);
509
268
            if (d_qim.addPendingLemma(au,
510
                                      InferenceId::QUANTIFIERS_HO_MATCH_PRED))
511
            {
512
              // this forces f to be a first-class member of the quantifier-free
513
              // equality engine, which in turn forces the quantifier-free
514
              // theory solver to expand it to an HO_APPLY chain.
515
116
              Trace("ho-quant")
516
58
                  << "Added ho match predicate lemma : " << au << std::endl;
517
58
              numLemmas++;
518
            }
519
          }
520
        }
521
      }
522
    }
523
  }
524
525
108
  return numLemmas;
526
}
527
528
}  // namespace inst
529
}  // namespace quantifiers
530
}  // namespace theory
531
29577
}  // namespace cvc5