GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/sygus/transition_inference.cpp Lines: 278 296 93.9 %
Date: 2021-03-23 Branches: 521 1118 46.6 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file transition_inference.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Andres Noetzli
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Implmentation of utility for inferring whether a synthesis conjecture
13
 ** encodes a transition system.
14
 **
15
 **/
16
#include "theory/quantifiers/sygus/transition_inference.h"
17
18
#include "expr/node_algorithm.h"
19
#include "theory/arith/arith_msum.h"
20
#include "theory/quantifiers/term_util.h"
21
#include "theory/rewriter.h"
22
23
using namespace CVC4::kind;
24
25
namespace CVC4 {
26
namespace theory {
27
namespace quantifiers {
28
29
534
bool DetTrace::DetTraceTrie::add(Node loc, const std::vector<Node>& val)
30
{
31
534
  DetTraceTrie* curr = this;
32
1588
  for (const Node& v : val)
33
  {
34
1054
    curr = &(curr->d_children[v]);
35
  }
36
534
  if (curr->d_children.empty())
37
  {
38
534
    curr->d_children[loc].clear();
39
534
    return true;
40
  }
41
  return false;
42
}
43
44
20
Node DetTrace::DetTraceTrie::constructFormula(const std::vector<Node>& vars,
45
                                              unsigned index)
46
{
47
20
  NodeManager* nm = NodeManager::currentNM();
48
20
  if (index == vars.size())
49
  {
50
    return nm->mkConst(true);
51
  }
52
40
  std::vector<Node> disj;
53
64
  for (std::pair<const Node, DetTraceTrie>& p : d_children)
54
  {
55
88
    Node eq = vars[index].eqNode(p.first);
56
44
    if (index < vars.size() - 1)
57
    {
58
30
      Node conc = p.second.constructFormula(vars, index + 1);
59
15
      disj.push_back(nm->mkNode(AND, eq, conc));
60
    }
61
    else
62
    {
63
29
      disj.push_back(eq);
64
    }
65
  }
66
20
  Assert(!disj.empty());
67
20
  return disj.size() == 1 ? disj[0] : nm->mkNode(OR, disj);
68
}
69
70
534
bool DetTrace::increment(Node loc, std::vector<Node>& vals)
71
{
72
534
  if (d_trie.add(loc, vals))
73
  {
74
1588
    for (unsigned i = 0, vsize = vals.size(); i < vsize; i++)
75
    {
76
1054
      d_curr[i] = vals[i];
77
    }
78
534
    return true;
79
  }
80
  return false;
81
}
82
83
5
Node DetTrace::constructFormula(const std::vector<Node>& vars)
84
{
85
5
  return d_trie.constructFormula(vars);
86
}
87
88
539
void DetTrace::print(const char* c) const
89
{
90
1601
  for (const Node& n : d_curr)
91
  {
92
1062
    Trace(c) << n << " ";
93
  }
94
539
}
95
96
62
Node TransitionInference::getFunction() const { return d_func; }
97
98
63
void TransitionInference::getVariables(std::vector<Node>& vars) const
99
{
100
63
  vars.insert(vars.end(), d_vars.begin(), d_vars.end());
101
63
}
102
103
63
Node TransitionInference::getPreCondition() const { return d_pre.d_this; }
104
592
Node TransitionInference::getPostCondition() const { return d_post.d_this; }
105
545
Node TransitionInference::getTransitionRelation() const
106
{
107
545
  return d_trans.d_this;
108
}
109
110
152
void TransitionInference::getConstantSubstitution(
111
    const std::vector<Node>& vars,
112
    const std::vector<Node>& disjuncts,
113
    std::vector<Node>& const_var,
114
    std::vector<Node>& const_subs,
115
    bool reqPol)
116
{
117
845
  for (const Node& d : disjuncts)
118
  {
119
1386
    Node sn;
120
693
    if (!const_var.empty())
121
    {
122
375
      sn = d.substitute(const_var.begin(),
123
                        const_var.end(),
124
                        const_subs.begin(),
125
                        const_subs.end());
126
375
      sn = Rewriter::rewrite(sn);
127
    }
128
    else
129
    {
130
318
      sn = d;
131
    }
132
693
    bool slit_pol = sn.getKind() != NOT;
133
1386
    Node slit = sn.getKind() == NOT ? sn[0] : sn;
134
693
    if (slit.getKind() == EQUAL && slit_pol == reqPol)
135
    {
136
      // check if it is a variable equality
137
498
      TNode v;
138
498
      Node s;
139
377
      for (unsigned r = 0; r < 2; r++)
140
      {
141
346
        if (std::find(vars.begin(), vars.end(), slit[r]) != vars.end())
142
        {
143
218
          if (!expr::hasSubterm(slit[1 - r], slit[r]))
144
          {
145
218
            v = slit[r];
146
218
            s = slit[1 - r];
147
218
            break;
148
          }
149
        }
150
      }
151
249
      if (v.isNull())
152
      {
153
        // solve for var
154
62
        std::map<Node, Node> msum;
155
31
        if (ArithMSum::getMonomialSumLit(slit, msum))
156
        {
157
124
          for (std::pair<const Node, Node>& m : msum)
158
          {
159
93
            if (std::find(vars.begin(), vars.end(), m.first) != vars.end())
160
            {
161
62
              Node veq_c;
162
62
              Node val;
163
31
              int ires = ArithMSum::isolate(m.first, msum, veq_c, val, EQUAL);
164
62
              if (ires != 0 && veq_c.isNull()
165
62
                  && !expr::hasSubterm(val, m.first))
166
              {
167
31
                v = m.first;
168
31
                s = val;
169
              }
170
            }
171
          }
172
        }
173
      }
174
249
      if (!v.isNull())
175
      {
176
498
        TNode ts = s;
177
1174
        for (unsigned k = 0, csize = const_subs.size(); k < csize; k++)
178
        {
179
925
          const_subs[k] = Rewriter::rewrite(const_subs[k].substitute(v, ts));
180
        }
181
498
        Trace("cegqi-inv-debug2")
182
249
            << "...substitution : " << v << " -> " << s << std::endl;
183
249
        const_var.push_back(v);
184
249
        const_subs.push_back(s);
185
      }
186
    }
187
  }
188
152
}
189
190
72
void TransitionInference::process(Node n, Node f)
191
{
192
  // set the function
193
72
  d_func = f;
194
72
  process(n);
195
72
}
196
197
72
void TransitionInference::process(Node n)
198
{
199
72
  NodeManager* nm = NodeManager::currentNM();
200
72
  d_complete = true;
201
72
  d_trivial = true;
202
144
  std::vector<Node> n_check;
203
72
  if (n.getKind() == AND)
204
  {
205
194
    for (const Node& nc : n)
206
    {
207
144
      n_check.push_back(nc);
208
    }
209
  }
210
  else
211
  {
212
22
    n_check.push_back(n);
213
  }
214
238
  for (const Node& nn : n_check)
215
  {
216
318
    std::map<bool, std::map<Node, bool> > visited;
217
318
    std::map<bool, Node> terms;
218
318
    std::vector<Node> disjuncts;
219
332
    Trace("cegqi-inv") << "TransitionInference : Process disjunct : " << nn
220
166
                       << std::endl;
221
175
    if (!processDisjunct(nn, terms, disjuncts, visited, true))
222
    {
223
9
      d_complete = false;
224
9
      continue;
225
    }
226
157
    if (terms.empty())
227
    {
228
5
      continue;
229
    }
230
304
    Node curr;
231
    // The component that this disjunct contributes to, where
232
    // 1 : pre-condition, -1 : post-condition, 0 : transition relation
233
    int comp_num;
234
152
    std::map<bool, Node>::iterator itt = terms.find(false);
235
152
    if (itt != terms.end())
236
    {
237
103
      curr = itt->second;
238
103
      if (terms.find(true) != terms.end())
239
      {
240
40
        comp_num = 0;
241
      }
242
      else
243
      {
244
63
        comp_num = -1;
245
      }
246
    }
247
    else
248
    {
249
49
      curr = terms[true];
250
49
      comp_num = 1;
251
    }
252
152
    Trace("cegqi-inv-debug2") << "  normalize based on " << curr << std::endl;
253
304
    std::vector<Node> vars;
254
304
    std::vector<Node> svars;
255
152
    getNormalizedSubstitution(curr, d_vars, vars, svars, disjuncts);
256
841
    for (unsigned j = 0, dsize = disjuncts.size(); j < dsize; j++)
257
    {
258
689
      Trace("cegqi-inv-debug2") << "  apply " << disjuncts[j] << std::endl;
259
689
      disjuncts[j] = Rewriter::rewrite(disjuncts[j].substitute(
260
          vars.begin(), vars.end(), svars.begin(), svars.end()));
261
689
      Trace("cegqi-inv-debug2") << "  ..." << disjuncts[j] << std::endl;
262
    }
263
304
    std::vector<Node> const_var;
264
304
    std::vector<Node> const_subs;
265
152
    if (comp_num == 0)
266
    {
267
      // transition
268
40
      Assert(terms.find(true) != terms.end());
269
80
      Node next = terms[true];
270
40
      next = Rewriter::rewrite(next.substitute(
271
          vars.begin(), vars.end(), svars.begin(), svars.end()));
272
80
      Trace("cegqi-inv-debug")
273
40
          << "transition next predicate : " << next << std::endl;
274
      // make the primed variables if we have not already
275
40
      if (d_prime_vars.empty())
276
      {
277
302
        for (unsigned j = 0, nchild = next.getNumChildren(); j < nchild; j++)
278
        {
279
          Node v = nm->mkSkolem(
280
524
              "ir", next[j].getType(), "template inference rev argument");
281
262
          d_prime_vars.push_back(v);
282
        }
283
      }
284
      // normalize the other direction
285
40
      Trace("cegqi-inv-debug2") << "  normalize based on " << next << std::endl;
286
80
      std::vector<Node> rvars;
287
80
      std::vector<Node> rsvars;
288
40
      getNormalizedSubstitution(next, d_prime_vars, rvars, rsvars, disjuncts);
289
40
      Assert(rvars.size() == rsvars.size());
290
263
      for (unsigned j = 0, dsize = disjuncts.size(); j < dsize; j++)
291
      {
292
223
        Trace("cegqi-inv-debug2") << "  apply " << disjuncts[j] << std::endl;
293
223
        disjuncts[j] = Rewriter::rewrite(disjuncts[j].substitute(
294
            rvars.begin(), rvars.end(), rsvars.begin(), rsvars.end()));
295
223
        Trace("cegqi-inv-debug2") << "  ..." << disjuncts[j] << std::endl;
296
      }
297
40
      getConstantSubstitution(
298
          d_prime_vars, disjuncts, const_var, const_subs, false);
299
    }
300
    else
301
    {
302
112
      getConstantSubstitution(d_vars, disjuncts, const_var, const_subs, false);
303
    }
304
304
    Node res;
305
152
    if (disjuncts.empty())
306
    {
307
5
      res = nm->mkConst(false);
308
    }
309
147
    else if (disjuncts.size() == 1)
310
    {
311
38
      res = disjuncts[0];
312
    }
313
    else
314
    {
315
109
      res = nm->mkNode(OR, disjuncts);
316
    }
317
152
    if (expr::hasBoundVar(res))
318
    {
319
      Trace("cegqi-inv-debug2") << "...failed, free variable." << std::endl;
320
      d_complete = false;
321
      continue;
322
    }
323
304
    Trace("cegqi-inv") << "*** inferred "
324
407
                       << (comp_num == 1 ? "pre"
325
255
                                         : (comp_num == -1 ? "post" : "trans"))
326
152
                       << "-condition : " << res << std::endl;
327
152
    Component& c =
328
        (comp_num == 1 ? d_pre : (comp_num == -1 ? d_post : d_trans));
329
152
    c.d_conjuncts.push_back(res);
330
152
    if (!const_var.empty())
331
    {
332
81
      bool has_const_eq = const_var.size() == d_vars.size();
333
162
      Trace("cegqi-inv") << "    with constant substitution, complete = "
334
81
                         << has_const_eq << " : " << std::endl;
335
330
      for (unsigned i = 0, csize = const_var.size(); i < csize; i++)
336
      {
337
498
        Trace("cegqi-inv") << "      " << const_var[i] << " -> "
338
249
                           << const_subs[i] << std::endl;
339
249
        if (has_const_eq)
340
        {
341
71
          c.d_const_eq[res][const_var[i]] = const_subs[i];
342
        }
343
      }
344
162
      Trace("cegqi-inv") << "...size = " << const_var.size()
345
81
                         << ", #vars = " << d_vars.size() << std::endl;
346
    }
347
  }
348
349
  // finalize the components
350
288
  for (int i = -1; i <= 1; i++)
351
  {
352
216
    Component& c = (i == 1 ? d_pre : (i == -1 ? d_post : d_trans));
353
432
    Node ret;
354
216
    if (c.d_conjuncts.empty())
355
    {
356
70
      ret = nm->mkConst(true);
357
    }
358
146
    else if (c.d_conjuncts.size() == 1)
359
    {
360
142
      ret = c.d_conjuncts[0];
361
    }
362
    else
363
    {
364
4
      ret = nm->mkNode(AND, c.d_conjuncts);
365
    }
366
216
    if (i == 0 || i == 1)
367
    {
368
      // pre-condition and transition are negated
369
144
      ret = TermUtil::simpleNegate(ret);
370
    }
371
216
    c.d_this = ret;
372
  }
373
72
}
374
192
void TransitionInference::getNormalizedSubstitution(
375
    Node curr,
376
    const std::vector<Node>& pvars,
377
    std::vector<Node>& vars,
378
    std::vector<Node>& subs,
379
    std::vector<Node>& disjuncts)
380
{
381
1406
  for (unsigned j = 0, nchild = curr.getNumChildren(); j < nchild; j++)
382
  {
383
1214
    if (curr[j].getKind() == BOUND_VARIABLE)
384
    {
385
      // if the argument is a bound variable, add to the renaming
386
1148
      vars.push_back(curr[j]);
387
1148
      subs.push_back(pvars[j]);
388
    }
389
    else
390
    {
391
      // otherwise, treat as a constraint on the variable
392
      // For example, this transforms e.g. a precondition clause
393
      // I( 0, 1 ) to x1 != 0 OR x2 != 1 OR I( x1, x2 ).
394
132
      Node eq = curr[j].eqNode(pvars[j]);
395
66
      disjuncts.push_back(eq.negate());
396
    }
397
  }
398
192
}
399
400
5989
bool TransitionInference::processDisjunct(
401
    Node n,
402
    std::map<bool, Node>& terms,
403
    std::vector<Node>& disjuncts,
404
    std::map<bool, std::map<Node, bool> >& visited,
405
    bool topLevel)
406
{
407
5989
  if (visited[topLevel].find(n) != visited[topLevel].end())
408
  {
409
2376
    return true;
410
  }
411
3613
  visited[topLevel][n] = true;
412
3613
  bool childTopLevel = n.getKind() == OR && topLevel;
413
  // if another part mentions UF or a free variable, then fail
414
3613
  bool lit_pol = n.getKind() != NOT;
415
7226
  Node lit = n.getKind() == NOT ? n[0] : n;
416
  // is it an application of the function-to-synthesize? Yes if we haven't
417
  // encountered a function or if it matches the existing d_func.
418
7226
  if (lit.getKind() == APPLY_UF
419
7226
      && (d_func.isNull() || lit.getOperator() == d_func))
420
  {
421
402
    Node op = lit.getOperator();
422
    // initialize the variables
423
201
    if (d_trivial)
424
    {
425
70
      d_trivial = false;
426
70
      d_func = op;
427
70
      Trace("cegqi-inv-debug") << "Use " << op << " with args ";
428
70
      NodeManager* nm = NodeManager::currentNM();
429
498
      for (const Node& l : lit)
430
      {
431
856
        Node v = nm->mkSkolem("i", l.getType(), "template inference argument");
432
428
        d_vars.push_back(v);
433
428
        Trace("cegqi-inv-debug") << v << " ";
434
      }
435
70
      Trace("cegqi-inv-debug") << std::endl;
436
    }
437
201
    Assert(!d_func.isNull());
438
201
    if (topLevel)
439
    {
440
192
      if (terms.find(lit_pol) == terms.end())
441
      {
442
192
        terms[lit_pol] = lit;
443
192
        return true;
444
      }
445
      else
446
      {
447
        Trace("cegqi-inv-debug")
448
            << "...failed, repeated inv-app : " << lit << std::endl;
449
        return false;
450
      }
451
    }
452
18
    Trace("cegqi-inv-debug")
453
9
        << "...failed, non-entailed inv-app : " << lit << std::endl;
454
9
    return false;
455
  }
456
3412
  else if (topLevel && !childTopLevel)
457
  {
458
641
    disjuncts.push_back(n);
459
  }
460
9219
  for (const Node& nc : n)
461
  {
462
5823
    if (!processDisjunct(nc, terms, disjuncts, visited, childTopLevel))
463
    {
464
16
      return false;
465
    }
466
  }
467
3396
  return true;
468
}
469
470
45
TraceIncStatus TransitionInference::initializeTrace(DetTrace& dt,
471
                                                    Node loc,
472
                                                    bool fwd)
473
{
474
45
  Component& c = fwd ? d_pre : d_post;
475
45
  Assert(c.has(loc));
476
45
  std::map<Node, std::map<Node, Node> >::iterator it = c.d_const_eq.find(loc);
477
45
  if (it != c.d_const_eq.end())
478
  {
479
20
    std::vector<Node> next;
480
28
    for (const Node& v : d_vars)
481
    {
482
18
      Assert(it->second.find(v) != it->second.end());
483
18
      next.push_back(it->second[v]);
484
18
      dt.d_curr.push_back(it->second[v]);
485
    }
486
10
    Trace("cegqi-inv-debug2") << "dtrace : initial increment" << std::endl;
487
10
    bool ret = dt.increment(loc, next);
488
10
    AlwaysAssert(ret);
489
10
    return TRACE_INC_SUCCESS;
490
  }
491
35
  return TRACE_INC_INVALID;
492
}
493
494
529
TraceIncStatus TransitionInference::incrementTrace(DetTrace& dt,
495
                                                   Node loc,
496
                                                   bool fwd)
497
{
498
529
  Assert(d_trans.has(loc));
499
  // check if it satisfies the pre/post condition
500
1058
  Node cc = fwd ? getPostCondition() : getPreCondition();
501
529
  Assert(!cc.isNull());
502
1058
  Node ccr = Rewriter::rewrite(cc.substitute(
503
1058
      d_vars.begin(), d_vars.end(), dt.d_curr.begin(), dt.d_curr.end()));
504
529
  if (ccr.isConst())
505
  {
506
529
    if (ccr.getConst<bool>() == (fwd ? false : true))
507
    {
508
      Trace("cegqi-inv-debug2") << "dtrace : counterexample" << std::endl;
509
      return TRACE_INC_CEX;
510
    }
511
  }
512
513
  // terminates?
514
1058
  Node c = getTransitionRelation();
515
529
  Assert(!c.isNull());
516
517
529
  Assert(d_vars.size() == dt.d_curr.size());
518
1058
  Node cr = Rewriter::rewrite(c.substitute(
519
1058
      d_vars.begin(), d_vars.end(), dt.d_curr.begin(), dt.d_curr.end()));
520
529
  if (cr.isConst())
521
  {
522
5
    if (!cr.getConst<bool>())
523
    {
524
5
      Trace("cegqi-inv-debug2") << "dtrace : terminated" << std::endl;
525
5
      return TRACE_INC_TERMINATE;
526
    }
527
    return TRACE_INC_INVALID;
528
  }
529
524
  if (!fwd)
530
  {
531
    // only implemented in forward direction
532
    Assert(false);
533
    return TRACE_INC_INVALID;
534
  }
535
524
  Component& cm = d_trans;
536
524
  std::map<Node, std::map<Node, Node> >::iterator it = cm.d_const_eq.find(loc);
537
524
  if (it == cm.d_const_eq.end())
538
  {
539
    return TRACE_INC_INVALID;
540
  }
541
1048
  std::vector<Node> next;
542
1560
  for (const Node& pv : d_prime_vars)
543
  {
544
1036
    Assert(it->second.find(pv) != it->second.end());
545
2072
    Node pvs = it->second[pv];
546
1036
    Assert(d_vars.size() == dt.d_curr.size());
547
2072
    Node pvsr = Rewriter::rewrite(pvs.substitute(
548
2072
        d_vars.begin(), d_vars.end(), dt.d_curr.begin(), dt.d_curr.end()));
549
1036
    next.push_back(pvsr);
550
  }
551
524
  if (dt.increment(loc, next))
552
  {
553
524
    Trace("cegqi-inv-debug2") << "dtrace : success increment" << std::endl;
554
524
    return TRACE_INC_SUCCESS;
555
  }
556
  // looped
557
  Trace("cegqi-inv-debug2") << "dtrace : looped" << std::endl;
558
  return TRACE_INC_TERMINATE;
559
}
560
561
47
TraceIncStatus TransitionInference::initializeTrace(DetTrace& dt, bool fwd)
562
{
563
47
  Trace("cegqi-inv-debug2") << "Initialize trace" << std::endl;
564
47
  Component& c = fwd ? d_pre : d_post;
565
47
  if (c.d_conjuncts.size() == 1)
566
  {
567
45
    return initializeTrace(dt, c.d_conjuncts[0], fwd);
568
  }
569
2
  return TRACE_INC_INVALID;
570
}
571
572
529
TraceIncStatus TransitionInference::incrementTrace(DetTrace& dt, bool fwd)
573
{
574
529
  if (d_trans.d_conjuncts.size() == 1)
575
  {
576
529
    return incrementTrace(dt, d_trans.d_conjuncts[0], fwd);
577
  }
578
  return TRACE_INC_INVALID;
579
}
580
581
5
Node TransitionInference::constructFormulaTrace(DetTrace& dt) const
582
{
583
5
  return dt.constructFormula(d_vars);
584
}
585
586
}  // namespace quantifiers
587
}  // namespace theory
588
26685
}  // namespace CVC4