GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/ematching/ho_trigger.cpp Lines: 229 257 89.1 %
Date: 2021-09-29 Branches: 463 1128 41.0 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mathias Preiner, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of higher-order trigger class.
14
 */
15
16
#include "theory/quantifiers/ematching/ho_trigger.h"
17
18
#include <stack>
19
20
#include "theory/quantifiers/ho_term_database.h"
21
#include "theory/quantifiers/instantiate.h"
22
#include "theory/quantifiers/quantifiers_inference_manager.h"
23
#include "theory/quantifiers/quantifiers_registry.h"
24
#include "theory/quantifiers/quantifiers_state.h"
25
#include "theory/quantifiers/term_registry.h"
26
#include "theory/quantifiers/term_util.h"
27
#include "theory/uf/theory_uf_rewriter.h"
28
#include "util/hash.h"
29
30
using namespace cvc5::kind;
31
32
namespace cvc5 {
33
namespace theory {
34
namespace quantifiers {
35
namespace inst {
36
37
8
/**
 * Construct a higher-order trigger.
 *
 * The arguments env, qs, qim, qr, tr, q and nodes are forwarded unchanged to
 * the Trigger base-class constructor. ho_apps initializes d_ho_var_apps; its
 * keys are the variables that occur as operators of pattern applications and
 * its values are those applications.
 *
 * For every key of d_ho_var_apps this constructor:
 *  - appends the variable to d_ho_var_list,
 *  - records its (function) type in d_ho_var_types,
 *  - stores the bound variable list for that type in d_ho_var_bvl, and
 *  - stores the individual bound variables of that list in d_ho_var_bvs.
 */
HigherOrderTrigger::HigherOrderTrigger(
    Env& env,
    QuantifiersState& qs,
    QuantifiersInferenceManager& qim,
    QuantifiersRegistry& qr,
    TermRegistry& tr,
    Node q,
    std::vector<Node>& nodes,
    std::map<Node, std::vector<Node> >& ho_apps)
    : Trigger(env, qs, qim, qr, tr, q, nodes), d_ho_var_apps(ho_apps)
{
  NodeManager* nm = NodeManager::currentNM();
  // walk over every variable that heads a higher-order application
  for (std::pair<const Node, std::vector<Node> >& entry : d_ho_var_apps)
  {
    Node hoVar = entry.first;
    d_ho_var_list.push_back(hoVar);
    TypeNode hoType = hoVar.getType();
    Assert(hoType.isFunction());
    if (Trace.isOn("ho-quant-trigger"))
    {
      const std::vector<Node>& patterns = entry.second;
      Trace("ho-quant-trigger") << "  have " << patterns.size();
      Trace("ho-quant-trigger") << " patterns with variable operator " << hoVar
                                << ":" << std::endl;
      for (const Node& pat : patterns)
      {
        Trace("ho-quant-trigger") << "  " << pat << std::endl;
      }
    }
    // remember the type the first time it is encountered
    const bool typeKnown =
        d_ho_var_types.find(hoType) != d_ho_var_types.end();
    if (!typeKnown)
    {
      Trace("ho-quant-trigger") << "  type " << hoType
                                << " needs higher-order matching." << std::endl;
      d_ho_var_types.insert(hoType);
    }
    // bound variable list for the variable's type, plus its flattened form
    d_ho_var_bvl[hoVar] = nm->getBoundVarListForFunctionType(hoType);
    for (const Node& bv : d_ho_var_bvl[hoVar])
    {
      d_ho_var_bvs[hoVar].push_back(bv);
    }
  }
}
80
81
16
/** Destructor; performs no work beyond destroying the members. */
HigherOrderTrigger::~HigherOrderTrigger() = default;
82
/**
 * Single-node convenience overload: wraps n in a one-element vector, runs the
 * vector overload of collectHoVarApplyTerms on it, and writes the (possibly
 * converted) node back into n.
 */
void HigherOrderTrigger::collectHoVarApplyTerms(
    Node q, Node& n, std::map<Node, std::vector<Node> >& apps)
{
  std::vector<Node> singleton(1, n);
  collectHoVarApplyTerms(q, singleton, apps);
  // the vector overload rewrites entries in place and never resizes
  Assert(singleton.size() == 1);
  n = singleton[0];
}
91
92
6172
/**
 * Traverse every node of ns (without descending into FORALL subterms) and
 * collect the applications whose head is an INST_CONSTANT.
 *
 * For each such application, the term is appended to apps[head]. A collected
 * APPLY_UF term is first converted with
 * uf::TheoryUfRewriter::getHoApplyForApplyUf, and that converted form also
 * replaces the original term in the rebuilt result. Each entry of ns is
 * overwritten in place with its rebuilt form.
 *
 * @param q quantified formula; only used in an assertion on the collected
 *          terms' instantiation-constant attribute
 * @param ns nodes to process; updated in place
 * @param apps output map from head variable to collected applications
 */
void HigherOrderTrigger::collectHoVarApplyTerms(
    Node q, std::vector<Node>& ns, std::map<Node, std::vector<Node> >& apps)
{
  // node -> rebuilt node; a null value marks "children scheduled, not rebuilt"
  std::unordered_map<TNode, Node> converted;
  // Flag consulted before a term is recorded as an application.
  // NOTE(review): roots are seeded with false and each child receives
  // `parent flag && j == 0`, so within this function the flag appears never to
  // become true. Confirm whether HO_APPLY parents were meant to set it.
  std::unordered_map<TNode, bool> insideApplyChain;
  std::vector<TNode> worklist;
  TNode cur;
  for (Node& root : ns)
  {
    worklist.push_back(root);
    insideApplyChain[root] = false;
    while (!worklist.empty())
    {
      cur = worklist.back();
      worklist.pop_back();

      std::unordered_map<TNode, Node>::iterator itCur = converted.find(cur);
      if (itCur == converted.end())
      {
        // first visit
        if (cur.getKind() == FORALL)
        {
          // nested quantifiers are kept as they are and not traversed
          converted[cur] = cur;
          continue;
        }
        const bool parentFlag = insideApplyChain[cur];
        converted[cur] = Node::null();
        // revisit cur once all of its children have been handled
        worklist.push_back(cur);
        const unsigned nchild = cur.getNumChildren();
        for (unsigned j = 0; j < nchild; ++j)
        {
          insideApplyChain[cur[j]] = parentFlag && j == 0;
          worklist.push_back(cur[j]);
        }
        continue;
      }
      if (!itCur->second.isNull())
      {
        // already rebuilt
        continue;
      }

      // second visit: rebuild cur from its converted children
      Node ret = cur;
      std::vector<Node> children;
      if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
      {
        children.push_back(cur.getOperator());
      }
      bool childChanged = false;
      for (const Node& child : cur)
      {
        std::unordered_map<TNode, Node>::iterator itChild =
            converted.find(child);
        Assert(itChild != converted.end());
        Assert(!itChild->second.isNull());
        if (!childChanged && child != itChild->second)
        {
          childChanged = true;
        }
        children.push_back(itChild->second);
      }
      if (childChanged)
      {
        ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
      }

      // find the head of the application, if ret is one, and record it
      if (!insideApplyChain[cur])
      {
        TNode op;
        if (ret.getKind() == kind::APPLY_UF)
        {
          // possibly a fully applied function variable
          op = ret.getOperator();
        }
        else if (ret.getKind() == kind::HO_APPLY)
        {
          // walk down the first children to the head of the chain
          op = ret;
          while (op.getKind() == kind::HO_APPLY)
          {
            op = op[0];
          }
        }
        if (!op.isNull() && op.getKind() == kind::INST_CONSTANT)
        {
          Assert(TermUtil::getInstConstAttr(ret) == q);
          Trace("ho-quant-trigger-debug")
              << "Ho variable apply term : " << ret << " with head " << op
              << std::endl;
          if (ret.getKind() == kind::APPLY_UF)
          {
            // kept from the original code; not read again in this block
            Node prev = ret;
            // store fully applied terms in HO_APPLY form as well
            ret = uf::TheoryUfRewriter::getHoApplyForApplyUf(ret);
          }
          apps[op].push_back(ret);
        }
      }
      converted[cur] = ret;
    }

    // write the rebuilt root back to the caller's vector
    Assert(converted.find(root) != converted.end());
    root = converted[root];
  }
}
196
197
24
uint64_t HigherOrderTrigger::addInstantiations()
198
{
199
  // call the base class implementation
200
24
  uint64_t addedFoLemmas = Trigger::addInstantiations();
201
  // also adds predicate lemms to force app completion
202
24
  uint64_t addedHoLemmas = addHoTypeMatchPredicateLemmas();
203
24
  return addedHoLemmas + addedFoLemmas;
204
}
205
206
41
bool HigherOrderTrigger::sendInstantiation(std::vector<Node>& m, InferenceId id)
207
{
208
41
  if (options::hoMatching())
209
  {
210
    // get substitution corresponding to m
211
82
    std::vector<TNode> vars;
212
82
    std::vector<TNode> subs;
213
150
    for (unsigned i = 0, size = d_quant[0].getNumChildren(); i < size; i++)
214
    {
215
109
      subs.push_back(m[i]);
216
109
      vars.push_back(d_qreg.getInstantiationConstant(d_quant, i));
217
    }
218
219
41
    Trace("ho-unif-debug") << "Run higher-order unification..." << std::endl;
220
221
    // get the substituted form of all variable-operator ho application terms
222
82
    std::map<TNode, std::vector<Node> > ho_var_apps_subs;
223
87
    for (std::pair<const Node, std::vector<Node> >& ha : d_ho_var_apps)
224
    {
225
92
      TNode var = ha.first;
226
92
      for (unsigned j = 0, size = ha.second.size(); j < size; j++)
227
      {
228
92
        TNode app = ha.second[j];
229
        Node sapp =
230
92
            app.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
231
46
        ho_var_apps_subs[var].push_back(sapp);
232
92
        Trace("ho-unif-debug") << "  app[" << var << "] : " << app << " -> "
233
46
                               << sapp << std::endl;
234
      }
235
    }
236
237
    // compute argument vectors for each variable
238
41
    d_lchildren.clear();
239
41
    d_arg_to_arg_rep.clear();
240
41
    d_arg_vector.clear();
241
87
    for (std::pair<const TNode, std::vector<Node> >& ha : ho_var_apps_subs)
242
    {
243
92
      TNode var = ha.first;
244
46
      unsigned vnum = var.getAttribute(InstVarNumAttribute());
245
92
      TNode value = m[vnum];
246
46
      Trace("ho-unif-debug") << "  val[" << var << "] = " << value << std::endl;
247
248
92
      Trace("ho-unif-debug2") << "initialize lambda information..."
249
46
                              << std::endl;
250
      // initialize the lambda children
251
46
      d_lchildren[vnum].push_back(value);
252
      std::map<TNode, std::vector<Node> >::iterator ithb =
253
46
          d_ho_var_bvs.find(var);
254
46
      Assert(ithb != d_ho_var_bvs.end());
255
92
      d_lchildren[vnum].insert(
256
92
          d_lchildren[vnum].end(), ithb->second.begin(), ithb->second.end());
257
258
46
      Trace("ho-unif-debug2") << "compute fixed arguments..." << std::endl;
259
      // compute for each argument if it is only applied to a fixed value modulo
260
      // equality
261
92
      std::map<unsigned, Node> fixed_vals;
262
92
      for (unsigned i = 0; i < ha.second.size(); i++)
263
      {
264
92
        std::vector<TNode> args;
265
        // must substitute the operator we matched with the original
266
        // higher-order variable (var) that matched it. This ensures that the
267
        // argument vector (args) below is of the proper length. This handles,
268
        // for example, matches like:
269
        //   (@ x y) with (@ (@ k1 k2) k3)
270
        // where k3 but not k2 should be an argument of the match.
271
92
        Node hmatch = ha.second[i];
272
46
        Trace("ho-unif-debug2") << "Match is " << hmatch << std::endl;
273
46
        hmatch = hmatch.substitute(value, var);
274
46
        Trace("ho-unif-debug2") << "Pre-subs match is " << hmatch << std::endl;
275
92
        Node f = uf::TheoryUfRewriter::decomposeHoApply(hmatch, args);
276
        // Assert( f==value );
277
120
        for (unsigned k = 0, size = args.size(); k < size; k++)
278
        {
279
          // must now subsitute back, to handle cases like
280
          // (@ x y) matching (@ t (@ t s))
281
          // where the above substitution would produce (@ x (@ x s)),
282
          // but the argument should be (@ t s).
283
74
          args[k] = args[k].substitute(var, value);
284
148
          Node val = args[k];
285
74
          std::map<unsigned, Node>::iterator itf = fixed_vals.find(k);
286
74
          if (itf == fixed_vals.end())
287
          {
288
74
            fixed_vals[k] = val;
289
          }
290
          else if (!itf->second.isNull())
291
          {
292
            if (!d_qstate.areEqual(itf->second, args[k]))
293
            {
294
              if (!d_treg.getTermDatabase()->isEntailed(
295
                      itf->second.eqNode(args[k]), true))
296
              {
297
                fixed_vals[k] = Node::null();
298
              }
299
            }
300
          }
301
        }
302
      }
303
46
      if (Trace.isOn("ho-unif-debug"))
304
      {
305
        for (std::map<unsigned, Node>::iterator itf = fixed_vals.begin();
306
             itf != fixed_vals.end();
307
             ++itf)
308
        {
309
          Trace("ho-unif-debug") << "  arg[" << var << "][" << itf->first
310
                                 << "] : " << itf->second << std::endl;
311
        }
312
      }
313
314
      // now construct argument vectors
315
46
      Trace("ho-unif-debug2") << "compute argument vectors..." << std::endl;
316
92
      std::map<Node, unsigned> arg_to_rep;
317
125
      for (unsigned index = 0, size = ithb->second.size(); index < size;
318
           index++)
319
      {
320
158
        Node bv_at_index = ithb->second[index];
321
79
        std::map<unsigned, Node>::iterator itf = fixed_vals.find(index);
322
79
        Trace("ho-unif-debug") << "  * arg[" << var << "][" << index << "]";
323
79
        if (itf != fixed_vals.end())
324
        {
325
74
          if (!itf->second.isNull())
326
          {
327
148
            Node r = d_qstate.getRepresentative(itf->second);
328
74
            std::map<Node, unsigned>::iterator itfr = arg_to_rep.find(r);
329
74
            if (itfr != arg_to_rep.end())
330
            {
331
13
              d_arg_to_arg_rep[vnum][index] = itfr->second;
332
              // function applied to equivalent values at multiple arguments,
333
              // can permute variables
334
13
              d_arg_vector[vnum][itfr->second].push_back(bv_at_index);
335
26
              Trace("ho-unif-debug") << " = { self } ++ arg[" << var << "]["
336
13
                                     << itfr->second << "]" << std::endl;
337
            }
338
            else
339
            {
340
61
              arg_to_rep[r] = index;
341
              // function applied to single value, can either use variable or
342
              // value at this argument position
343
61
              d_arg_vector[vnum][index].push_back(bv_at_index);
344
61
              d_arg_vector[vnum][index].push_back(itf->second);
345
61
              if (!options::hoMatchingVarArgPriority())
346
              {
347
                std::reverse(d_arg_vector[vnum][index].begin(),
348
                             d_arg_vector[vnum][index].end());
349
              }
350
122
              Trace("ho-unif-debug") << " = { self, " << itf->second << " } "
351
61
                                     << std::endl;
352
            }
353
          }
354
          else
355
          {
356
            // function is applied to disequal values, can only use variable at
357
            // this argument position
358
            d_arg_vector[vnum][index].push_back(bv_at_index);
359
            Trace("ho-unif-debug") << " = { self } (disequal)" << std::endl;
360
          }
361
        }
362
        else
363
        {
364
          // argument is irrelevant to matching, assume identity variable
365
5
          d_arg_vector[vnum][index].push_back(bv_at_index);
366
5
          Trace("ho-unif-debug") << " = { self } (irrelevant)" << std::endl;
367
        }
368
      }
369
46
      Trace("ho-unif-debug2") << "finished." << std::endl;
370
    }
371
372
41
    bool ret = sendInstantiation(m, 0);
373
41
    Trace("ho-unif-debug") << "Finished, success = " << ret << std::endl;
374
41
    return ret;
375
  }
376
  else
377
  {
378
    // do not run higher-order matching
379
    return d_qim.getInstantiate()->addInstantiation(d_quant, m, id);
380
  }
381
}
382
383
// recursion depth limited by number of arguments of higher order variables
384
// occurring as pattern operators (very small)
385
162
bool HigherOrderTrigger::sendInstantiation(std::vector<Node>& m,
386
                                           size_t var_index)
387
{
388
324
  Trace("ho-unif-debug2") << "send inst " << var_index << " / "
389
162
                          << d_ho_var_list.size() << std::endl;
390
162
  if (var_index == d_ho_var_list.size())
391
  {
392
    // we now have an instantiation to try
393
226
    return d_qim.getInstantiate()->addInstantiation(
394
113
        d_quant, m, InferenceId::QUANTIFIERS_INST_E_MATCHING_HO);
395
  }
396
  else
397
  {
398
98
    Node var = d_ho_var_list[var_index];
399
49
    unsigned vnum = var.getAttribute(InstVarNumAttribute());
400
49
    Assert(vnum < m.size());
401
98
    Node value = m[vnum];
402
49
    Assert(d_lchildren[vnum][0] == value);
403
49
    Assert(d_ho_var_bvl.find(var) != d_ho_var_bvl.end());
404
405
    // now, recurse on arguments to enumerate equivalent matching lambda
406
    // expressions
407
    bool ret =
408
49
        sendInstantiationArg(m, var_index, vnum, 0, d_ho_var_bvl[var], false);
409
410
    // reset the value
411
49
    m[vnum] = value;
412
413
49
    return ret;
414
  }
415
}
416
417
222
bool HigherOrderTrigger::sendInstantiationArg(std::vector<Node>& m,
418
                                              unsigned var_index,
419
                                              unsigned vnum,
420
                                              unsigned arg_index,
421
                                              Node lbvl,
422
                                              bool arg_changed)
423
{
424
444
  Trace("ho-unif-debug2") << "send inst arg " << arg_index << " / "
425
222
                          << lbvl.getNumChildren() << std::endl;
426
222
  if (arg_index == lbvl.getNumChildren())
427
  {
428
    // construct the lambda
429
121
    if (arg_changed)
430
    {
431
158
      Trace("ho-unif-debug2")
432
79
          << "  make lambda from children: " << d_lchildren[vnum] << std::endl;
433
      Node body =
434
158
          NodeManager::currentNM()->mkNode(kind::APPLY_UF, d_lchildren[vnum]);
435
79
      Trace("ho-unif-debug2") << "  got " << body << std::endl;
436
158
      Node lam = NodeManager::currentNM()->mkNode(kind::LAMBDA, lbvl, body);
437
79
      m[vnum] = lam;
438
79
      Trace("ho-unif-debug2") << "  try " << vnum << " -> " << lam << std::endl;
439
    }
440
121
    return sendInstantiation(m, var_index + 1);
441
  }
442
  else
443
  {
444
    std::map<unsigned, unsigned>::iterator itr =
445
101
        d_arg_to_arg_rep[vnum].find(arg_index);
446
    unsigned rindex =
447
101
        itr != d_arg_to_arg_rep[vnum].end() ? itr->second : arg_index;
448
    std::map<unsigned, std::vector<Node> >::iterator itv =
449
101
        d_arg_vector[vnum].find(rindex);
450
101
    Assert(itv != d_arg_vector[vnum].end());
451
202
    Node prev = lbvl[arg_index];
452
101
    bool ret = false;
453
    // try each argument in the vector
454
208
    for (unsigned i = 0, size = itv->second.size(); i < size; i++)
455
    {
456
173
      bool new_arg_changed = arg_changed || prev != itv->second[i];
457
173
      d_lchildren[vnum][arg_index + 1] = itv->second[i];
458
173
      if (sendInstantiationArg(
459
              m, var_index, vnum, arg_index + 1, lbvl, new_arg_changed))
460
      {
461
66
        ret = true;
462
66
        break;
463
      }
464
    }
465
    // clean up
466
101
    d_lchildren[vnum][arg_index + 1] = prev;
467
101
    return ret;
468
  }
469
}
470
471
24
uint64_t HigherOrderTrigger::addHoTypeMatchPredicateLemmas()
472
{
473
24
  if (d_ho_var_types.empty())
474
  {
475
    return 0;
476
  }
477
24
  Trace("ho-quant-trigger") << "addHoTypeMatchPredicateLemmas..." << std::endl;
478
24
  uint64_t numLemmas = 0;
479
  // this forces expansion of APPLY_UF terms to curried HO_APPLY chains
480
24
  TermDb* tdb = d_treg.getTermDatabase();
481
24
  unsigned size = tdb->getNumOperators();
482
24
  NodeManager* nm = NodeManager::currentNM();
483
156
  for (unsigned j = 0; j < size; j++)
484
  {
485
264
    Node f = tdb->getOperator(j);
486
132
    if (f.isVar())
487
    {
488
202
      TypeNode tn = f.getType();
489
101
      if (tn.isFunction())
490
      {
491
202
        std::vector<TypeNode> argTypes = tn.getArgTypes();
492
101
        Assert(argTypes.size() > 0);
493
202
        TypeNode range = tn.getRangeType();
494
        // for each function type suffix of the type of f, for example if
495
        // f : (Int -> (Int -> Int))
496
        // we iterate with stn = (Int -> (Int -> Int)) and (Int -> Int)
497
221
        for (unsigned a = 0, arg_size = argTypes.size(); a < arg_size; a++)
498
        {
499
240
          std::vector<TypeNode> sargts;
500
120
          sargts.insert(sargts.begin(), argTypes.begin() + a, argTypes.end());
501
120
          Assert(sargts.size() > 0);
502
240
          TypeNode stn = nm->mkFunctionType(sargts, range);
503
240
          Trace("ho-quant-trigger-debug")
504
120
              << "For " << f << ", check " << stn << "..." << std::endl;
505
          // if a variable of this type occurs in this trigger
506
120
          if (d_ho_var_types.find(stn) != d_ho_var_types.end())
507
          {
508
86
            Node u = HoTermDb::getHoTypeMatchPredicate(tn);
509
86
            Node au = nm->mkNode(kind::APPLY_UF, u, f);
510
43
            if (d_qim.addPendingLemma(au,
511
                                      InferenceId::QUANTIFIERS_HO_MATCH_PRED))
512
            {
513
              // this forces f to be a first-class member of the quantifier-free
514
              // equality engine, which in turn forces the quantifier-free
515
              // theory solver to expand it to an HO_APPLY chain.
516
26
              Trace("ho-quant")
517
13
                  << "Added ho match predicate lemma : " << au << std::endl;
518
13
              numLemmas++;
519
            }
520
          }
521
        }
522
      }
523
    }
524
  }
525
526
24
  return numLemmas;
527
}
528
529
}  // namespace inst
530
}  // namespace quantifiers
531
}  // namespace theory
532
22746
}  // namespace cvc5