GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/ematching/ho_trigger.cpp Lines: 229 257 89.1 %
Date: 2021-08-16 Branches: 463 1128 41.0 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mathias Preiner, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of higher-order trigger class.
14
 */
15
16
#include <stack>
17
18
#include "theory/quantifiers/ematching/ho_trigger.h"
19
#include "theory/quantifiers/instantiate.h"
20
#include "theory/quantifiers/quantifiers_inference_manager.h"
21
#include "theory/quantifiers/quantifiers_registry.h"
22
#include "theory/quantifiers/quantifiers_state.h"
23
#include "theory/quantifiers/term_database.h"
24
#include "theory/quantifiers/term_registry.h"
25
#include "theory/quantifiers/term_util.h"
26
#include "theory/uf/theory_uf_rewriter.h"
27
#include "util/hash.h"
28
29
using namespace cvc5::kind;
30
31
namespace cvc5 {
32
namespace theory {
33
namespace quantifiers {
34
namespace inst {
35
36
36
HigherOrderTrigger::HigherOrderTrigger(
    QuantifiersState& qs,
    QuantifiersInferenceManager& qim,
    QuantifiersRegistry& qr,
    TermRegistry& tr,
    Node q,
    std::vector<Node>& nodes,
    std::map<Node, std::vector<Node> >& ho_apps)
    : Trigger(qs, qim, qr, tr, q, nodes), d_ho_var_apps(ho_apps)
{
  // For every higher-order variable that is the operator of some application
  // in this trigger: remember the variable, remember its (function) type, and
  // build the bound-variable list used later to construct lambda terms.
  NodeManager* nm = NodeManager::currentNM();
  for (std::pair<const Node, std::vector<Node> >& entry : d_ho_var_apps)
  {
    Node hoVar = entry.first;
    const std::vector<Node>& patterns = entry.second;
    d_ho_var_list.push_back(hoVar);
    TypeNode hoType = hoVar.getType();
    Assert(hoType.isFunction());
    if (Trace.isOn("ho-quant-trigger"))
    {
      Trace("ho-quant-trigger") << "  have " << patterns.size();
      Trace("ho-quant-trigger") << " patterns with variable operator " << hoVar
                                << ":" << std::endl;
      for (const Node& pat : patterns)
      {
        Trace("ho-quant-trigger") << "  " << pat << std::endl;
      }
    }
    const bool typeIsNew = d_ho_var_types.find(hoType) == d_ho_var_types.end();
    if (typeIsNew)
    {
      Trace("ho-quant-trigger") << "  type " << hoType
                                << " needs higher-order matching." << std::endl;
      d_ho_var_types.insert(hoType);
    }
    // keep the bound variable list both as a node and as a vector of its
    // children
    Node bvl = nm->getBoundVarListForFunctionType(hoType);
    d_ho_var_bvl[hoVar] = bvl;
    std::vector<Node>& bvs = d_ho_var_bvs[hoVar];
    for (const Node& bv : bvl)
    {
      bvs.push_back(bv);
    }
  }
}
78
79
72
// Nothing to release explicitly; members clean up after themselves.
HigherOrderTrigger::~HigherOrderTrigger() = default;
80
void HigherOrderTrigger::collectHoVarApplyTerms(
    Node q, Node& n, std::map<Node, std::vector<Node> >& apps)
{
  // Single-term convenience overload: run the vector overload on a singleton
  // list and copy the (possibly converted) term back into n.
  std::vector<Node> single{n};
  collectHoVarApplyTerms(q, single, apps);
  Assert(single.size() == 1);
  n = single.front();
}
89
90
16850
void HigherOrderTrigger::collectHoVarApplyTerms(
    Node q, std::vector<Node>& ns, std::map<Node, std::vector<Node> >& apps)
{
  // Iterative post-order traversal over each term of ns. Each term is rebuilt
  // bottom-up. Applications whose head is an INST_CONSTANT are appended to
  // apps[head]; fully applied APPLY_UF terms are converted to HO_APPLY form
  // first. Nested FORALL terms are not entered. The memo table `visited` is
  // shared across all terms of ns. On return, ns holds the converted terms.
  std::unordered_map<TNode, Node> visited;
  std::unordered_map<TNode, Node>::iterator it;
  // whether the visited node is a child of a HO_APPLY chain
  // NOTE(review): every root is seeded with false, and a child is only marked
  // true when its parent already is. As written, this flag is therefore never
  // true and the `!withinApply[cur]` test below always passes. Confirm whether
  // `cur.getKind() == HO_APPLY && j == 0` was intended before changing it.
  std::unordered_map<TNode, bool> withinApply;
  std::vector<TNode> visit;
  TNode cur;
  // Converted terms are written back to ns only after every term has been
  // processed. The keys of `visited` are non-owning TNodes into the original
  // terms. Overwriting ns[i] eagerly could drop the last reference to an
  // original term while later iterations still look up those keys.
  std::vector<Node> converted;
  converted.reserve(ns.size());
  for (const Node& root : ns)
  {
    visit.push_back(root);
    withinApply[root] = false;
    do
    {
      cur = visit.back();
      visit.pop_back();

      it = visited.find(cur);
      if (it == visited.end())
      {
        // do not look in nested quantifiers
        if (cur.getKind() == FORALL)
        {
          visited[cur] = cur;
        }
        else
        {
          // pre-visit: mark as in progress, then schedule the children
          bool curWithinApply = withinApply[cur];
          visited[cur] = Node::null();
          visit.push_back(cur);
          for (unsigned j = 0, sizec = cur.getNumChildren(); j < sizec; j++)
          {
            withinApply[cur[j]] = curWithinApply && j == 0;
            visit.push_back(cur[j]);
          }
        }
      }
      else if (it->second.isNull())
      {
        // post-visit: all children are converted, so carry the conversion
        Node ret = cur;
        bool childChanged = false;
        std::vector<Node> children;
        if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
        {
          children.push_back(cur.getOperator());
        }
        for (const Node& nc : cur)
        {
          it = visited.find(nc);
          Assert(it != visited.end());
          Assert(!it->second.isNull());
          childChanged = childChanged || nc != it->second;
          children.push_back(it->second);
        }
        if (childChanged)
        {
          ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
        }
        // now, convert and store the application
        if (!withinApply[cur])
        {
          TNode op;
          if (ret.getKind() == kind::APPLY_UF)
          {
            // could be a fully applied function variable
            op = ret.getOperator();
          }
          else if (ret.getKind() == kind::HO_APPLY)
          {
            // the head of a HO_APPLY chain is reached by following child 0
            op = ret;
            while (op.getKind() == kind::HO_APPLY)
            {
              op = op[0];
            }
          }
          if (!op.isNull())
          {
            if (op.getKind() == kind::INST_CONSTANT)
            {
              Assert(TermUtil::getInstConstAttr(ret) == q);
              Trace("ho-quant-trigger-debug")
                  << "Ho variable apply term : " << ret << " with head " << op
                  << std::endl;
              if (ret.getKind() == kind::APPLY_UF)
              {
                // for consistency, convert to HO_APPLY if fully applied
                ret = uf::TheoryUfRewriter::getHoApplyForApplyUf(ret);
              }
              apps[op].push_back(ret);
            }
          }
        }
        visited[cur] = ret;
      }
    } while (!visit.empty());

    // store the conversion
    Assert(visited.find(root) != visited.end());
    converted.push_back(visited[root]);
  }
  ns = std::move(converted);
}
194
195
108
uint64_t HigherOrderTrigger::addInstantiations()
196
{
197
  // call the base class implementation
198
108
  uint64_t addedFoLemmas = Trigger::addInstantiations();
199
  // also adds predicate lemms to force app completion
200
108
  uint64_t addedHoLemmas = addHoTypeMatchPredicateLemmas();
201
108
  return addedHoLemmas + addedFoLemmas;
202
}
203
204
192
bool HigherOrderTrigger::sendInstantiation(std::vector<Node>& m, InferenceId id)
205
{
206
192
  if (options::hoMatching())
207
  {
208
    // get substitution corresponding to m
209
384
    std::vector<TNode> vars;
210
384
    std::vector<TNode> subs;
211
684
    for (unsigned i = 0, size = d_quant[0].getNumChildren(); i < size; i++)
212
    {
213
492
      subs.push_back(m[i]);
214
492
      vars.push_back(d_qreg.getInstantiationConstant(d_quant, i));
215
    }
216
217
192
    Trace("ho-unif-debug") << "Run higher-order unification..." << std::endl;
218
219
    // get the substituted form of all variable-operator ho application terms
220
384
    std::map<TNode, std::vector<Node> > ho_var_apps_subs;
221
432
    for (std::pair<const Node, std::vector<Node> >& ha : d_ho_var_apps)
222
    {
223
480
      TNode var = ha.first;
224
480
      for (unsigned j = 0, size = ha.second.size(); j < size; j++)
225
      {
226
480
        TNode app = ha.second[j];
227
        Node sapp =
228
480
            app.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
229
240
        ho_var_apps_subs[var].push_back(sapp);
230
480
        Trace("ho-unif-debug") << "  app[" << var << "] : " << app << " -> "
231
240
                               << sapp << std::endl;
232
      }
233
    }
234
235
    // compute argument vectors for each variable
236
192
    d_lchildren.clear();
237
192
    d_arg_to_arg_rep.clear();
238
192
    d_arg_vector.clear();
239
432
    for (std::pair<const TNode, std::vector<Node> >& ha : ho_var_apps_subs)
240
    {
241
480
      TNode var = ha.first;
242
240
      unsigned vnum = var.getAttribute(InstVarNumAttribute());
243
480
      TNode value = m[vnum];
244
240
      Trace("ho-unif-debug") << "  val[" << var << "] = " << value << std::endl;
245
246
480
      Trace("ho-unif-debug2") << "initialize lambda information..."
247
240
                              << std::endl;
248
      // initialize the lambda children
249
240
      d_lchildren[vnum].push_back(value);
250
      std::map<TNode, std::vector<Node> >::iterator ithb =
251
240
          d_ho_var_bvs.find(var);
252
240
      Assert(ithb != d_ho_var_bvs.end());
253
480
      d_lchildren[vnum].insert(
254
480
          d_lchildren[vnum].end(), ithb->second.begin(), ithb->second.end());
255
256
240
      Trace("ho-unif-debug2") << "compute fixed arguments..." << std::endl;
257
      // compute for each argument if it is only applied to a fixed value modulo
258
      // equality
259
480
      std::map<unsigned, Node> fixed_vals;
260
480
      for (unsigned i = 0; i < ha.second.size(); i++)
261
      {
262
480
        std::vector<TNode> args;
263
        // must substitute the operator we matched with the original
264
        // higher-order variable (var) that matched it. This ensures that the
265
        // argument vector (args) below is of the proper length. This handles,
266
        // for example, matches like:
267
        //   (@ x y) with (@ (@ k1 k2) k3)
268
        // where k3 but not k2 should be an argument of the match.
269
480
        Node hmatch = ha.second[i];
270
240
        Trace("ho-unif-debug2") << "Match is " << hmatch << std::endl;
271
240
        hmatch = hmatch.substitute(value, var);
272
240
        Trace("ho-unif-debug2") << "Pre-subs match is " << hmatch << std::endl;
273
480
        Node f = uf::TheoryUfRewriter::decomposeHoApply(hmatch, args);
274
        // Assert( f==value );
275
616
        for (unsigned k = 0, size = args.size(); k < size; k++)
276
        {
277
          // must now subsitute back, to handle cases like
278
          // (@ x y) matching (@ t (@ t s))
279
          // where the above substitution would produce (@ x (@ x s)),
280
          // but the argument should be (@ t s).
281
376
          args[k] = args[k].substitute(var, value);
282
752
          Node val = args[k];
283
376
          std::map<unsigned, Node>::iterator itf = fixed_vals.find(k);
284
376
          if (itf == fixed_vals.end())
285
          {
286
376
            fixed_vals[k] = val;
287
          }
288
          else if (!itf->second.isNull())
289
          {
290
            if (!d_qstate.areEqual(itf->second, args[k]))
291
            {
292
              if (!d_treg.getTermDatabase()->isEntailed(
293
                      itf->second.eqNode(args[k]), true))
294
              {
295
                fixed_vals[k] = Node::null();
296
              }
297
            }
298
          }
299
        }
300
      }
301
240
      if (Trace.isOn("ho-unif-debug"))
302
      {
303
        for (std::map<unsigned, Node>::iterator itf = fixed_vals.begin();
304
             itf != fixed_vals.end();
305
             ++itf)
306
        {
307
          Trace("ho-unif-debug") << "  arg[" << var << "][" << itf->first
308
                                 << "] : " << itf->second << std::endl;
309
        }
310
      }
311
312
      // now construct argument vectors
313
240
      Trace("ho-unif-debug2") << "compute argument vectors..." << std::endl;
314
480
      std::map<Node, unsigned> arg_to_rep;
315
640
      for (unsigned index = 0, size = ithb->second.size(); index < size;
316
           index++)
317
      {
318
800
        Node bv_at_index = ithb->second[index];
319
400
        std::map<unsigned, Node>::iterator itf = fixed_vals.find(index);
320
400
        Trace("ho-unif-debug") << "  * arg[" << var << "][" << index << "]";
321
400
        if (itf != fixed_vals.end())
322
        {
323
376
          if (!itf->second.isNull())
324
          {
325
752
            Node r = d_qstate.getRepresentative(itf->second);
326
376
            std::map<Node, unsigned>::iterator itfr = arg_to_rep.find(r);
327
376
            if (itfr != arg_to_rep.end())
328
            {
329
52
              d_arg_to_arg_rep[vnum][index] = itfr->second;
330
              // function applied to equivalent values at multiple arguments,
331
              // can permute variables
332
52
              d_arg_vector[vnum][itfr->second].push_back(bv_at_index);
333
104
              Trace("ho-unif-debug") << " = { self } ++ arg[" << var << "]["
334
52
                                     << itfr->second << "]" << std::endl;
335
            }
336
            else
337
            {
338
324
              arg_to_rep[r] = index;
339
              // function applied to single value, can either use variable or
340
              // value at this argument position
341
324
              d_arg_vector[vnum][index].push_back(bv_at_index);
342
324
              d_arg_vector[vnum][index].push_back(itf->second);
343
324
              if (!options::hoMatchingVarArgPriority())
344
              {
345
                std::reverse(d_arg_vector[vnum][index].begin(),
346
                             d_arg_vector[vnum][index].end());
347
              }
348
648
              Trace("ho-unif-debug") << " = { self, " << itf->second << " } "
349
324
                                     << std::endl;
350
            }
351
          }
352
          else
353
          {
354
            // function is applied to disequal values, can only use variable at
355
            // this argument position
356
            d_arg_vector[vnum][index].push_back(bv_at_index);
357
            Trace("ho-unif-debug") << " = { self } (disequal)" << std::endl;
358
          }
359
        }
360
        else
361
        {
362
          // argument is irrelevant to matching, assume identity variable
363
24
          d_arg_vector[vnum][index].push_back(bv_at_index);
364
24
          Trace("ho-unif-debug") << " = { self } (irrelevant)" << std::endl;
365
        }
366
      }
367
240
      Trace("ho-unif-debug2") << "finished." << std::endl;
368
    }
369
370
192
    bool ret = sendInstantiation(m, 0);
371
192
    Trace("ho-unif-debug") << "Finished, success = " << ret << std::endl;
372
192
    return ret;
373
  }
374
  else
375
  {
376
    // do not run higher-order matching
377
    return d_qim.getInstantiate()->addInstantiation(d_quant, m, id);
378
  }
379
}
380
381
// recursion depth limited by number of arguments of higher order variables
382
// occurring as pattern operators (very small)
383
906
bool HigherOrderTrigger::sendInstantiation(std::vector<Node>& m,
384
                                           size_t var_index)
385
{
386
1812
  Trace("ho-unif-debug2") << "send inst " << var_index << " / "
387
906
                          << d_ho_var_list.size() << std::endl;
388
906
  if (var_index == d_ho_var_list.size())
389
  {
390
    // we now have an instantiation to try
391
1204
    return d_qim.getInstantiate()->addInstantiation(
392
602
        d_quant, m, InferenceId::QUANTIFIERS_INST_E_MATCHING_HO);
393
  }
394
  else
395
  {
396
608
    Node var = d_ho_var_list[var_index];
397
304
    unsigned vnum = var.getAttribute(InstVarNumAttribute());
398
304
    Assert(vnum < m.size());
399
608
    Node value = m[vnum];
400
304
    Assert(d_lchildren[vnum][0] == value);
401
304
    Assert(d_ho_var_bvl.find(var) != d_ho_var_bvl.end());
402
403
    // now, recurse on arguments to enumerate equivalent matching lambda
404
    // expressions
405
    bool ret =
406
304
        sendInstantiationArg(m, var_index, vnum, 0, d_ho_var_bvl[var], false);
407
408
    // reset the value
409
304
    m[vnum] = value;
410
411
304
    return ret;
412
  }
413
}
414
415
1272
bool HigherOrderTrigger::sendInstantiationArg(std::vector<Node>& m,
416
                                              unsigned var_index,
417
                                              unsigned vnum,
418
                                              unsigned arg_index,
419
                                              Node lbvl,
420
                                              bool arg_changed)
421
{
422
2544
  Trace("ho-unif-debug2") << "send inst arg " << arg_index << " / "
423
1272
                          << lbvl.getNumChildren() << std::endl;
424
1272
  if (arg_index == lbvl.getNumChildren())
425
  {
426
    // construct the lambda
427
714
    if (arg_changed)
428
    {
429
876
      Trace("ho-unif-debug2")
430
438
          << "  make lambda from children: " << d_lchildren[vnum] << std::endl;
431
      Node body =
432
876
          NodeManager::currentNM()->mkNode(kind::APPLY_UF, d_lchildren[vnum]);
433
438
      Trace("ho-unif-debug2") << "  got " << body << std::endl;
434
876
      Node lam = NodeManager::currentNM()->mkNode(kind::LAMBDA, lbvl, body);
435
438
      m[vnum] = lam;
436
438
      Trace("ho-unif-debug2") << "  try " << vnum << " -> " << lam << std::endl;
437
    }
438
714
    return sendInstantiation(m, var_index + 1);
439
  }
440
  else
441
  {
442
    std::map<unsigned, unsigned>::iterator itr =
443
558
        d_arg_to_arg_rep[vnum].find(arg_index);
444
    unsigned rindex =
445
558
        itr != d_arg_to_arg_rep[vnum].end() ? itr->second : arg_index;
446
    std::map<unsigned, std::vector<Node> >::iterator itv =
447
558
        d_arg_vector[vnum].find(rindex);
448
558
    Assert(itv != d_arg_vector[vnum].end());
449
1116
    Node prev = lbvl[arg_index];
450
558
    bool ret = false;
451
    // try each argument in the vector
452
1196
    for (unsigned i = 0, size = itv->second.size(); i < size; i++)
453
    {
454
968
      bool new_arg_changed = arg_changed || prev != itv->second[i];
455
968
      d_lchildren[vnum][arg_index + 1] = itv->second[i];
456
968
      if (sendInstantiationArg(
457
              m, var_index, vnum, arg_index + 1, lbvl, new_arg_changed))
458
      {
459
330
        ret = true;
460
330
        break;
461
      }
462
    }
463
    // clean up
464
558
    d_lchildren[vnum][arg_index + 1] = prev;
465
558
    return ret;
466
  }
467
}
468
469
108
uint64_t HigherOrderTrigger::addHoTypeMatchPredicateLemmas()
470
{
471
108
  if (d_ho_var_types.empty())
472
  {
473
    return 0;
474
  }
475
108
  Trace("ho-quant-trigger") << "addHoTypeMatchPredicateLemmas..." << std::endl;
476
108
  uint64_t numLemmas = 0;
477
  // this forces expansion of APPLY_UF terms to curried HO_APPLY chains
478
108
  TermDb* tdb = d_treg.getTermDatabase();
479
108
  unsigned size = tdb->getNumOperators();
480
108
  NodeManager* nm = NodeManager::currentNM();
481
768
  for (unsigned j = 0; j < size; j++)
482
  {
483
1320
    Node f = tdb->getOperator(j);
484
660
    if (f.isVar())
485
    {
486
1024
      TypeNode tn = f.getType();
487
512
      if (tn.isFunction())
488
      {
489
1024
        std::vector<TypeNode> argTypes = tn.getArgTypes();
490
512
        Assert(argTypes.size() > 0);
491
1024
        TypeNode range = tn.getRangeType();
492
        // for each function type suffix of the type of f, for example if
493
        // f : (Int -> (Int -> Int))
494
        // we iterate with stn = (Int -> (Int -> Int)) and (Int -> Int)
495
1112
        for (unsigned a = 0, arg_size = argTypes.size(); a < arg_size; a++)
496
        {
497
1200
          std::vector<TypeNode> sargts;
498
600
          sargts.insert(sargts.begin(), argTypes.begin() + a, argTypes.end());
499
600
          Assert(sargts.size() > 0);
500
1200
          TypeNode stn = nm->mkFunctionType(sargts, range);
501
1200
          Trace("ho-quant-trigger-debug")
502
600
              << "For " << f << ", check " << stn << "..." << std::endl;
503
          // if a variable of this type occurs in this trigger
504
600
          if (d_ho_var_types.find(stn) != d_ho_var_types.end())
505
          {
506
536
            Node u = tdb->getHoTypeMatchPredicate(tn);
507
536
            Node au = nm->mkNode(kind::APPLY_UF, u, f);
508
268
            if (d_qim.addPendingLemma(au,
509
                                      InferenceId::QUANTIFIERS_HO_MATCH_PRED))
510
            {
511
              // this forces f to be a first-class member of the quantifier-free
512
              // equality engine, which in turn forces the quantifier-free
513
              // theory solver to expand it to an HO_APPLY chain.
514
116
              Trace("ho-quant")
515
58
                  << "Added ho match predicate lemma : " << au << std::endl;
516
58
              numLemmas++;
517
            }
518
          }
519
        }
520
      }
521
    }
522
  }
523
524
108
  return numLemmas;
525
}
526
527
}  // namespace inst
528
}  // namespace quantifiers
529
}  // namespace theory
530
29340
}  // namespace cvc5