GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/sygus/sygus_reconstruct.cpp Lines: 183 203 90.1 %
Date: 2021-11-07 Branches: 394 782 50.4 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Abdalrhman Mohamed, Andrew Reynolds, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of SygusReconstruct, which reconstructs builtin terms as
 * terms of a sygus grammar.
14
 */
15
16
#include "theory/quantifiers/sygus/sygus_reconstruct.h"
17
18
#include "expr/node_algorithm.h"
19
#include "smt/command.h"
20
#include "theory/datatypes/sygus_datatype_utils.h"
21
#include "theory/rewriter.h"
22
23
using namespace cvc5::kind;
24
25
namespace cvc5 {
26
namespace theory {
27
namespace quantifiers {
28
29
1900
/**
 * Constructs the reconstruction module.
 *
 * The environment is handed to the EnvObj base class. The sygus term
 * database and the statistics object are kept in d_tds and d_stats, which
 * later calls use when creating proxy variables and initializing per-type
 * information.
 */
SygusReconstruct::SygusReconstruct(Env& env,
                                   TermDbSygus* tds,
                                   SygusStatistics& s)
    : EnvObj(env),
      d_tds(tds),
      d_stats(s)
{
}
35
36
28
/**
 * Try to reconstruct the builtin term `sol` as a term of sygus type `stn`.
 *
 * The algorithm repeats rounds, at most `enumLimit` of them, until the main
 * obligation is solved. Each round has two phases:
 *  - Enumeration phase: every sygus type that has received terms to
 *    reconstruct enumerates one new term.
 *    - Ground enumerated terms immediately solve an obligation, creating one
 *      if necessary.
 *    - Other new enumerated terms are added to the pool and matched against
 *      the terms of that type.
 *  - Match phase: obligations created by matching are themselves matched
 *    against the pool until no new obligations appear.
 *
 * @param sol the builtin term to reconstruct
 * @param stn the sygus datatype type of the target grammar
 * @param reconstructed set to 1 on success and to -1 on failure
 * @param enumLimit the maximum number of rounds
 * @return the reconstructed sygus term on success, the null node otherwise
 */
Node SygusReconstruct::reconstructSolution(Node sol,
                                           TypeNode stn,
                                           int8_t& reconstructed,
                                           uint64_t enumLimit)
{
  Trace("sygus-rcons") << "SygusReconstruct::reconstructSolution: " << sol
                       << std::endl;
  Trace("sygus-rcons") << "- target sygus type is " << stn << std::endl;
  Trace("sygus-rcons") << "- enumeration limit is " << enumLimit << std::endl;

  // this method may get called multiple times with the same object. We need to
  // reset the state to avoid conflicts
  clear();

  initialize(stn);

  /** the builtin terms still to be reconstructed, for each sygus datatype */
  TypeBuiltinSetMap termsToRecons;

  // add the main obligation to the set of obligations
  // parameters stn and sol constitute the main obligation to satisfy
  d_obs.push_back(std::make_unique<RConsObligation>(stn, sol));
  RConsObligation* ob0 = d_obs[0].get();
  termsToRecons[stn].emplace(sol);
  d_stnInfo[stn].setBuiltinToOb(sol, ob0);
  Node k0 = ob0->getSkolem();

  // We need to add the main obligation to the crd in case it cannot be broken
  // down by matching. By doing so, we can solve the obligation using
  // enumeration and crd (if it is in the grammar)
  d_stnInfo[stn].addTerm(sol);

  // the set of unique (up to rewriting) patterns/shapes in the grammar used by
  // matching
  std::unordered_map<TypeNode, std::vector<Node>> pool;

  uint64_t count = 0;

  // algorithm
  while (d_sol[k0].isNull() && count < enumLimit)
  {
    // enumeration phase
    // a temporary set of new terms to reconstruct cached for processing in the
    // match phase
    TypeBuiltinSetMap termsToReconsPrime;
    for (const std::pair<const TypeNode, BuiltinSet>& pair : termsToRecons)
    {
      // enumerate a new term of type pair.first. This is not necessarily the
      // main type stn, so print the type actually being enumerated
      Trace("sygus-rcons") << "enum: " << pair.first << ": ";
      Node sz = d_stnInfo[pair.first].nextEnum();
      if (sz.isNull())
      {
        continue;
      }
      Node builtin = rewrite(datatypes::utils::sygusToBuiltin(sz));
      // if enumerated term does not contain free variables, then its
      // corresponding obligation can be solved immediately
      if (sz.isConst())
      {
        Node rep = d_stnInfo[pair.first].addTerm(builtin);
        RConsObligation* ob = d_stnInfo[pair.first].builtinToOb(rep);
        // check if the enumerated term solves an obligation
        if (ob == nullptr)
        {
          // if not, create an "artificial" obligation whose solution would be
          // the enumerated term
          d_obs.push_back(
              std::make_unique<RConsObligation>(pair.first, builtin));
          d_stnInfo[pair.first].setBuiltinToOb(builtin, d_obs.back().get());
          ob = d_obs.back().get();
        }
        // mark the obligation as solved
        markSolved(ob, sz);
        // Since we added the term to the candidate rewrite database, there is
        // no point in adding it to the pool too
        continue;
      }
      // if there are no matches (all calls to notify return true)...
      if (d_poolTrie.getMatches(builtin, this))
      {
        // then, this is a new term and we should add it to pool
        d_poolTrie.addTerm(builtin);
        pool[pair.first].push_back(sz);
        for (const Node& t : pair.second)
        {
          RConsObligation* ob = d_stnInfo[pair.first].builtinToOb(t);
          if (d_sol[ob->getSkolem()].isNull())
          {
            Trace("sygus-rcons") << "ob: " << *ob << std::endl;
            // try to match builtin term `t` with the enumerated term sz
            TypeBuiltinSetMap temp = matchNewObs(t, sz);
            // cache the new obligations for processing in the match phase
            for (const std::pair<const TypeNode, BuiltinSet>& tempPair : temp)
            {
              termsToReconsPrime[tempPair.first].insert(
                  tempPair.second.cbegin(), tempPair.second.cend());
            }
          }
        }
      }
    }
    // match phase
    while (!termsToReconsPrime.empty())
    {
      // a temporary set of new terms to reconstruct cached for later processing
      TypeBuiltinSetMap obsDPrime;
      for (const std::pair<const TypeNode, BuiltinSet>& pair :
           termsToReconsPrime)
      {
        for (const Node& t : pair.second)
        {
          termsToRecons[pair.first].emplace(t);
          RConsObligation* ob = d_stnInfo[pair.first].builtinToOb(t);
          if (d_sol[ob->getSkolem()].isNull())
          {
            Trace("sygus-rcons") << "ob: " << *ob << std::endl;
            for (const Node& sz : pool[pair.first])
            {
              // try to match each newly generated and cached term with patterns
              // in pool
              TypeBuiltinSetMap temp = matchNewObs(t, sz);
              // cache the new terms for later processing
              for (const std::pair<const TypeNode, BuiltinSet>& tempPair : temp)
              {
                obsDPrime[tempPair.first].insert(tempPair.second.cbegin(),
                                                 tempPair.second.cend());
              }
            }
          }
        }
      }
      termsToReconsPrime = std::move(obsDPrime);
    }
    // remove reconstructed terms from termsToRecons
    removeReconstructedTerms(termsToRecons);
    ++count;
  }

  if (Trace("sygus-rcons").isConnected())
  {
    RConsObligation::printCandSols(ob0, d_obs);
    printPool(pool);
  }

  // if the main obligation is solved, return the solution
  if (!d_sol[k0].isNull())
  {
    reconstructed = 1;
    // The algorithm mostly works with rewritten terms and may not notice that
    // the original terms contain variables eliminated by the rewriter. For
    // example, rewrite((ite true 0 z)) = 0. In such cases, we replace those
    // variables with ground values.
    return d_sol[k0].isConst() ? Node(d_sol[k0]) : mkGround(d_sol[k0]);
  }

  // we ran out of elements, return null
  reconstructed = -1;
  warning() << CommandFailure(
      "Cannot get synth function: reconstruction to syntax failed.");
  return Node::null();
}
197
198
1187
/**
 * Try to match the builtin term `t` against the rewritten builtin form of the
 * enumerated sygus term `sz`.
 *
 * On success, each pattern variable of `sz` bound by the match yields a
 * sub-obligation for the builtin term it was bound to:
 *  - If d_stnInfo knows an equivalent term that already has an obligation,
 *    that obligation is reused.
 *  - Otherwise a new obligation is created. It is solved immediately when the
 *    term is a constant and the grammar allows constants.
 *
 * The pattern variables of `sz` are then replaced by the skolems of those
 * obligations.
 *  - If every sub-obligation is solved, the completed term solves the
 *    obligation of `t`.
 *  - Otherwise the term becomes a candidate solution of that obligation and is
 *    watched by its last unsolved sub-obligation.
 *
 * @param t a builtin term expected to already have an obligation in
 *          d_stnInfo (it is dereferenced unconditionally after a match)
 * @param sz an enumerated sygus term used as a pattern
 * @return the builtin terms that now need to be reconstructed, grouped by
 *         sygus type; empty if `t` does not match `sz`
 */
TypeBuiltinSetMap SygusReconstruct::matchNewObs(Node t, Node sz)
{
  TypeBuiltinSetMap newTerms;

  // Substitutions produced by match. They may correspond to obligations we
  // have already seen (and even solved). The builtin counterparts of the
  // grammar's sygus variables are bound variables, but `match` must treat them
  // as ground terms, so we seed it with identity substitutions for them.
  NodePairMap matches;
  matches.insert(d_sygusVars.cbegin(), d_sygusVars.cend());

  Node pattern = rewrite(datatypes::utils::sygusToBuiltin(sz));
  if (!expr::match(pattern, t, matches))
  {
    return newTerms;
  }

  Trace("sygus-rcons") << "-- ct: " << sz << std::endl;

  // drop the identity entries seeded above; what remains binds the pattern
  // variables of sz
  for (const std::pair<const Node, Node>& sv : d_sygusVars)
  {
    matches.erase(sv.first);
  }

  // The bound variables generated by the enumerators are shared between
  // enumerated terms, so they are replaced with the skolems of our own
  // obligations.
  NodePairMap subs;
  for (const std::pair<const Node, Node>& m : matches)
  {
    Node var = datatypes::utils::builtinVarToSygus(m.first);
    TypeNode stn = var.getType();
    const Node& term = m.second;
    RConsObligation* subOb = nullptr;

    // is there an obligation for an equivalent term already?
    Node rep = d_stnInfo[stn].addTerm(term);
    RConsObligation* repOb = d_stnInfo[stn].builtinToOb(rep);
    if (repOb == nullptr)
    {
      // no: create a new obligation of the corresponding sygus type
      d_obs.push_back(std::make_unique<RConsObligation>(stn, term));
      subOb = d_obs.back().get();
      d_stnInfo[stn].setBuiltinToOb(term, subOb);
      if (term.isConst() && stn.getDType().getSygusAllowConst())
      {
        // constants are directly expressible, so solve it right away
        markSolved(subOb, d_tds->getProxyVariable(stn, term));
      }
      else
      {
        newTerms[stn].emplace(term);
      }
    }
    else
    {
      // yes: reuse the existing obligation.
      //
      // Although `term` is equivalent to `rep`, it may be easier to
      // reconstruct by matching. For example:
      //
      // Grammar: S -> p | q | (not S) | (and S S) | (or S S)
      // rep = (= p q)
      // term = (or (and p q) (and (not p) (not q)))
      //
      // `term` only uses operators that are in the grammar, whereas `rep` can
      // only be solved by enumeration (currently). Since we cannot tell which
      // one is easier, `term` is recorded as another builtin of `repOb` and
      // both are matched.
      subOb = repOb;
      if (repOb->getBuiltins().find(term) == repOb->getBuiltins().cend())
      {
        repOb->addBuiltin(term);
        d_stnInfo[stn].setBuiltinToOb(term, repOb);
        newTerms[stn].emplace(term);
      }
    }

    subs.emplace(var, subOb->getSkolem());
  }

  // replace the enumerator's free variables in sz with our skolems
  if (!subs.empty())
  {
    sz = sz.substitute(subs.cbegin(), subs.cend());
  }

  // sz is solved if it has no sub-obligations or if all of them are solved
  bool isSolved = true;
  for (const std::pair<const Node, Node>& m : matches)
  {
    TypeNode stn = datatypes::utils::builtinVarToSygus(m.first).getType();
    RConsObligation* subOb = d_stnInfo[stn].builtinToOb(m.second);
    if (d_sol[subOb->getSkolem()].isNull())
    {
      isSolved = false;
      d_subObs[sz].push_back(subOb);
    }
  }

  RConsObligation* ob = d_stnInfo[sz.getType()].builtinToOb(t);
  if (isSolved)
  {
    // substitute fills its cache with TNodes that this module does not keep
    // alive and that may be destroyed after the call. Passing iterators of
    // d_sol avoids keeping such unsafe TNodes around.
    Node s = sz.substitute(d_sol.cbegin(), d_sol.cend());
    markSolved(ob, s);
  }
  else
  {
    // record sz as a candidate solution of ob, to be completed later
    ob->addCandidateSolution(sz);
    d_parentOb[sz] = ob;
    d_subObs[sz].back()->addCandidateSolutionToWatchSet(sz);
  }

  return newTerms;
}
318
319
712
/**
 * Record `s` as the solution of obligation `ob` and propagate it upward.
 *
 * Does nothing if `ob` already has a solution. Otherwise `s` is stored in
 * d_sol. A worklist then visits the watch set of each newly solved
 * obligation. For every watched partial solution, the trailing solved entries
 * are dropped from its list of sub-obligations, and then:
 *  - If none remain, the partial solution is completed by substituting the
 *    known solutions, and it solves its parent obligation (when that one is
 *    still unsolved). The parent obligation is then processed in turn.
 *  - Otherwise, the partial solution is handed to the watch set of its last
 *    remaining sub-obligation.
 */
void SygusReconstruct::markSolved(RConsObligation* ob, Node s)
{
  // nothing to do if ob is already solved
  if (!d_sol[ob->getSkolem()].isNull())
  {
    return;
  }

  Trace("sygus-rcons") << "sol: " << s << std::endl;
  Trace("sygus-rcons") << "builtin sol: " << datatypes::utils::sygusToBuiltin(s)
                       << std::endl;

  // record s as the solution of ob
  ob->addCandidateSolution(s);
  d_sol[ob->getSkolem()] = s;
  d_parentOb[s] = ob;

  // obligations that were just solved and whose watchers must be revisited
  std::vector<RConsObligation*> worklist{ob};

  while (!worklist.empty())
  {
    RConsObligation* solved = worklist.back();
    worklist.pop_back();

    for (const Node& partial : solved->getWatchSet())
    {
      auto& children = d_subObs[partial];

      // drop `solved` and any other trailing sub-obligations that are solved
      while (!children.empty()
             && !d_sol[children.back()->getSkolem()].isNull())
      {
        children.pop_back();
      }

      if (!children.empty())
      {
        // still waiting on an unsolved sub-obligation: let it watch us
        children.back()->addCandidateSolutionToWatchSet(partial);
        continue;
      }

      // every sub-obligation is solved, so the partial solution is complete.
      // Iterators of d_sol are passed to avoid populating it with unsafe
      // TNodes.
      Node completed = partial.substitute(d_sol.cbegin(), d_sol.cend());
      RConsObligation* parentOb = d_parentOb[partial];
      if (!d_sol[parentOb->getSkolem()].isNull())
      {
        // the parent obligation has already been solved
        continue;
      }
      parentOb->addCandidateSolution(completed);
      d_sol[parentOb->getSkolem()] = completed;
      d_parentOb[completed] = parentOb;
      worklist.push_back(parentOb);
    }
  }
}
382
383
28
/**
 * Prepare the per-type state for reconstructing terms of sygus type `stn`.
 *
 * Caches the builtin counterparts of the sygus variables of `stn` in
 * d_sygusVars, mapping each one to itself. It also initializes d_stnInfo for
 * every sygus type reported by SygusTypeInfo::getSubfieldTypes for `stn`.
 *
 * @param stn the sygus datatype type of the grammar being reconstructed
 */
void SygusReconstruct::initialize(TypeNode stn)
{
  std::vector<Node> builtinVars;

  // Cache the sygus variables introduced by the problem (which we treat as
  // ground terms when calling the `match` method) as opposed to the sygus
  // variables introduced by the enumerators (which we treat as bound
  // variables).
  for (const Node& sv : stn.getDType().getSygusVarList())
  {
    // convert each variable once instead of once per use
    Node bv = datatypes::utils::sygusToBuiltin(sv);
    builtinVars.push_back(bv);
    d_sygusVars.emplace(bv, bv);
  }

  SygusTypeInfo stnInfo;
  stnInfo.initialize(d_tds, stn);

  // find the non-terminals of the grammar
  std::vector<TypeNode> sfTypes;
  stnInfo.getSubfieldTypes(sfTypes);

  // initialize the enumerators and candidate rewrite databases. Notice here
  // that we treat the sygus variables introduced by the problem as bound
  // variables (needed for making sure that obligations are equal). This is fine
  // as we will never add variables that were introduced by the enumerators to
  // the database.
  for (const TypeNode& tn : sfTypes)
  {
    d_stnInfo[tn].initialize(d_env, d_tds, d_stats, tn, builtinVars);
  }
}
415
416
10317
void SygusReconstruct::removeReconstructedTerms(
417
    TypeBuiltinSetMap& termsToRecons)
418
{
419
22012
  for (std::pair<const TypeNode, BuiltinSet>& pair : termsToRecons)
420
  {
421
11695
    BuiltinSet::iterator it = pair.second.begin();
422
39785
    while (it != pair.second.end())
423
    {
424
14045
      if (d_sol[d_stnInfo[pair.first].builtinToOb(*it)->getSkolem()].isNull())
425
      {
426
13991
        ++it;
427
      }
428
      else
429
      {
430
54
        it = pair.second.erase(it);
431
      }
432
    }
433
  }
434
10317
}
435
436
10
/**
 * Replace every variable occurring in `n` by a ground value of its type.
 *
 * The replacement values are created with TypeNode::mkGroundValue.
 *
 * @param n the term whose variables should be replaced
 * @return `n` with each of its variables substituted by a ground value
 */
Node SygusReconstruct::mkGround(Node n) const
{
  // get the set of variables in n (they are kept alive by n itself, so
  // TNodes are safe here)
  std::unordered_set<TNode> vars;
  expr::getVariables(n, vars);

  // The values returned by mkGroundValue are temporaries that may not be
  // referenced anywhere else. Storing them as (non-reference-counted) TNodes
  // could leave dangling references once the temporaries are destroyed, so
  // the substitution holds reference-counted Nodes.
  std::unordered_map<Node, Node> subs;

  // generate a ground value for each one of those variables
  for (const TNode& var : vars)
  {
    subs.emplace(var, var.getType().mkGroundValue());
  }

  // substitute the variables with ground values
  return n.substitute(subs.cbegin(), subs.cend());
}
453
454
4371
bool SygusReconstruct::notify(Node s,
455
                              Node n,
456
                              std::vector<Node>& vars,
457
                              std::vector<Node>& subs)
458
{
459
14758
  for (size_t i = 0; i < vars.size(); ++i)
460
  {
461
    // We consider sygus variables as ground terms. So, if they are not equal to
462
    // their substitution, then s is not matchable with n and we try the next
463
    // term s. Example: If s = (+ z x) and n = (+ z y), then s is not matchable
464
    // with n and we return true
465
10522
    if (d_sygusVars.find(vars[i]) != d_sygusVars.cend() && vars[i] != subs[i])
466
    {
467
135
      return true;
468
    }
469
  }
470
  // Note: false here means that we finally found an s that is matchable with n,
471
  // so we should not add n to the pool
472
4236
  return false;
473
}
474
475
28
void SygusReconstruct::clear()
476
{
477
28
  d_obs.clear();
478
28
  d_stnInfo.clear();
479
28
  d_sol.clear();
480
28
  d_subObs.clear();
481
28
  d_parentOb.clear();
482
28
  d_sygusVars.clear();
483
28
  d_poolTrie.clear();
484
28
}
485
486
void SygusReconstruct::printPool(
487
    const std::unordered_map<TypeNode, std::vector<Node>>& pool) const
488
{
489
  Trace("sygus-rcons") << std::endl << "Pool:" << std::endl << '{';
490
491
  for (const std::pair<const TypeNode, std::vector<Node>>& pair : pool)
492
  {
493
    Trace("sygus-rcons") << std::endl
494
                         << "  " << pair.first << ':' << std::endl
495
                         << "  [" << std::endl;
496
497
    for (const Node& sygusTerm : pair.second)
498
    {
499
      Trace("sygus-rcons")
500
          << "    "
501
          << rewrite(datatypes::utils::sygusToBuiltin(sygusTerm)).toString()
502
          << std::endl;
503
    }
504
505
    Trace("sygus-rcons") << "  ]" << std::endl;
506
  }
507
508
  Trace("sygus-rcons") << '}' << std::endl;
509
}
510
511
}  // namespace quantifiers
512
}  // namespace theory
513
31137
}  // namespace cvc5