GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/synth_rew_rules.cpp Lines: 3 250 1.2 %
Date: 2021-05-24 Branches: 4 1002 0.4 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mathias Preiner, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * A technique for synthesizing candidate rewrites of the form t1 = t2,
14
 * where t1 and t2 are subterms of the input.
15
 */
16
17
#include "preprocessing/passes/synth_rew_rules.h"
18
19
#include <sstream>
20
21
#include "expr/sygus_datatype.h"
22
#include "expr/term_canonize.h"
23
#include "options/base_options.h"
24
#include "options/quantifiers_options.h"
25
#include "preprocessing/assertion_pipeline.h"
26
#include "printer/printer.h"
27
#include "theory/quantifiers/candidate_rewrite_database.h"
28
#include "theory/quantifiers/quantifiers_attributes.h"
29
#include "theory/quantifiers/sygus/sygus_grammar_cons.h"
30
#include "theory/quantifiers/sygus/sygus_utils.h"
31
#include "theory/quantifiers/term_util.h"
32
33
using namespace std;
34
using namespace cvc5::kind;
35
36
namespace cvc5 {
37
namespace preprocessing {
38
namespace passes {
39
40
9459
/**
 * Construct the pass. It is registered with the preprocessing framework
 * under the name "synth-rr".
 *
 * @param preprocContext The preprocessing pass context, forwarded to the
 * PreprocessingPass base class.
 */
SynthRewRulesPass::SynthRewRulesPass(PreprocessingPassContext* preprocContext)
    : PreprocessingPass(preprocContext, "synth-rr")
{
}
42
43
/**
 * Replace the input assertions by a sygus conjecture whose grammars generate
 * abstractions of the subterms of the input.
 *
 * The pass proceeds in the following steps:
 * (1) collect all quantifier-free subterms, free variables and types of the
 *     input assertions,
 * (2) allocate a fixed number of named bound variables per type,
 * (3) convert the collected terms to canonical (alpha-renamed) form,
 * (4) build one sygus datatype per canonical subterm, plus one "top-level"
 *     sygus datatype per type,
 * (5) build one synthesis conjecture (with body false) per type, and
 * (6) replace assertion 0 by the conjunction of these conjectures and all
 *     remaining assertions by true.
 *
 * @param assertionsToPreprocess The assertions to process; they are
 * overwritten as described in step (6). If there are no assertions, nothing
 * is changed.
 * @return Always PreprocessingPassResult::NO_CONFLICT.
 */
PreprocessingPassResult SynthRewRulesPass::applyInternal(
    AssertionPipeline* assertionsToPreprocess)
{
  Trace("srs-input") << "Synthesize rewrite rules from assertions..."
                     << std::endl;
  const std::vector<Node>& assertions = assertionsToPreprocess->ref();
  if (assertions.empty())
  {
    return PreprocessingPassResult::NO_CONFLICT;
  }

  NodeManager* nm = NodeManager::currentNM();

  // initialize the candidate rewrite
  // visited[t] is false while t is being processed (or if t is not usable),
  // and true once t has been registered as a usable term.
  std::unordered_map<TNode, bool> visited;
  std::unordered_map<TNode, bool>::iterator it;
  std::vector<TNode> visit;
  // Get all usable terms from the input. A term is usable if it does not
  // contain a quantified subterm
  std::vector<Node> terms;
  // all variables (free constants) appearing in the input
  std::vector<Node> vars;
  // does the input contain a Boolean variable?
  bool hasBoolVar = false;
  // the types of subterms of our input
  std::map<TypeNode, bool> typesFound;
  // standard constants for each type (e.g. true, false for Bool)
  std::map<TypeNode, std::vector<Node> > consts;

  TNode cur;
  Trace("srs-input") << "Collect terms in assertions..." << std::endl;
  for (const Node& a : assertions)
  {
    Trace("srs-input-debug") << "Assertion : " << a << std::endl;
    visit.push_back(a);
    do
    {
      cur = visit.back();
      visit.pop_back();
      it = visited.find(cur);
      if (it == visited.end())
      {
        Trace("srs-input-debug") << "...preprocess " << cur << std::endl;
        visited[cur] = false;
        bool isQuant = cur.isClosure();
        // we recurse on this node if it is not a quantified formula
        if (!isQuant)
        {
          // revisit cur after its children for post-processing
          visit.push_back(cur);
          for (const Node& cc : cur)
          {
            visit.push_back(cc);
          }
        }
      }
      else if (!it->second)
      {
        Trace("srs-input-debug") << "...postprocess " << cur << std::endl;
        // check if all of the children are valid
        // this ensures we do not register terms that have e.g. quantified
        // formulas as subterms
        bool childrenValid = true;
        for (const Node& cc : cur)
        {
          Assert(visited.find(cc) != visited.end());
          if (!visited[cc])
          {
            childrenValid = false;
          }
        }
        if (childrenValid)
        {
          Trace("srs-input-debug") << "...children are valid" << std::endl;
          Trace("srs-input-debug") << "Add term " << cur << std::endl;
          TypeNode tn = cur.getType();
          if (cur.isVar())
          {
            vars.push_back(cur);
            if (tn.isBoolean())
            {
              hasBoolVar = true;
            }
          }
          // register type information
          if (typesFound.find(tn) == typesFound.end())
          {
            typesFound[tn] = true;
            // add the standard constants for this type
            theory::quantifiers::CegGrammarConstructor::mkSygusConstantsForType(
                tn, consts[tn]);
            // We prepend them so that they come first in the grammar
            // construction. The motivation is we'd prefer seeing e.g. "true"
            // instead of (= x x) as a canonical term.
            terms.insert(terms.begin(), consts[tn].begin(), consts[tn].end());
          }
          terms.push_back(cur);
        }
        visited[cur] = childrenValid;
      }
    } while (!visit.empty());
  }
  Trace("srs-input") << "...finished." << std::endl;

  Trace("srs-input") << "Make synth variables for types..." << std::endl;
  // We will generate a fixed number of variables per type. These are the
  // variables that appear as free variables in the rewrites we generate.
  unsigned nvars = options::sygusRewSynthInputNVars();
  // must have at least one variable per type
  nvars = nvars < 1 ? 1 : nvars;
  std::map<TypeNode, std::vector<Node> > tvars;
  std::vector<TypeNode> allVarTypes;
  std::vector<Node> allVars;
  unsigned varCounter = 0;
  for (const std::pair<const TypeNode, bool>& tfp : typesFound)
  {
    TypeNode tn = tfp.first;
    // Always create the (possibly empty) variable list for this type, so
    // that every type found in the input has an entry in tvars, even when
    // no variables are allocated for it below.
    std::vector<Node>& typeVars = tvars[tn];
    // If we are not interested in purely propositional rewrites, we only
    // need to make one Boolean variable if the input has a Boolean variable.
    // This ensures that no type in our grammar has zero constructors. If
    // our input does not contain a Boolean variable, we need not allocate any
    // Boolean variables here.
    unsigned useNVars =
        (options::sygusRewSynthInputUseBool() || !tn.isBoolean())
            ? nvars
            : (hasBoolVar ? 1 : 0);
    for (unsigned i = 0; i < useNVars; i++)
    {
      // We must have a good name for these variables, these are
      // the ones output in rewrite rules. We choose
      // a,b,c,...,y,z,x0,x1,...
      std::stringstream ssv;
      if (varCounter < 26)
      {
        // the varCounter-th lowercase letter
        ssv << static_cast<char>('a' + varCounter);
      }
      else
      {
        ssv << "x" << (varCounter - 26);
      }
      varCounter++;
      Node v = nm->mkBoundVar(ssv.str(), tn);
      typeVars.push_back(v);
      allVars.push_back(v);
      allVarTypes.push_back(tn);
    }
  }
  Trace("srs-input") << "...finished." << std::endl;

  Trace("srs-input") << "Convert subterms to free variable form..."
                     << std::endl;
  // Replace all free variables with bound variables. This ensures that
  // we can perform term canonization on subterms.
  std::vector<Node> vsubs;
  for (const Node& v : vars)
  {
    TypeNode tnv = v.getType();
    Node vs = nm->mkBoundVar(tnv);
    vsubs.push_back(vs);
  }
  if (!vars.empty())
  {
    for (unsigned i = 0, nterms = terms.size(); i < nterms; i++)
    {
      terms[i] = terms[i].substitute(
          vars.begin(), vars.end(), vsubs.begin(), vsubs.end());
    }
  }
  Trace("srs-input") << "...finished." << std::endl;

  Trace("srs-input") << "Process " << terms.size() << " subterms..."
                     << std::endl;
  // We've collected all terms in the input. In the following, we construct a
  // sygus grammar which generates terms that correspond to abstractions of
  // the terms in the input.

  // We map terms to a canonical (ordered variable) form. This ensures that
  // we don't generate distinct grammar types for distinct alpha-equivalent
  // terms, which would produce grammars of identical shape.
  std::map<Node, Node> term_to_cterm;
  std::map<Node, Node> cterm_to_term;
  std::vector<Node> cterms;
  // canonical terms for each type
  std::map<TypeNode, std::vector<Node> > t_cterms;
  expr::TermCanonize tcanon;
  for (unsigned i = 0, nterms = terms.size(); i < nterms; i++)
  {
    Node n = terms[i];
    Node cn = tcanon.getCanonicalTerm(n);
    term_to_cterm[n] = cn;
    Trace("srs-input-debug") << "Canon : " << n << " -> " << cn << std::endl;
    std::map<Node, Node>::iterator itc = cterm_to_term.find(cn);
    if (itc == cterm_to_term.end())
    {
      // first term with this canonical form: it becomes the representative
      cterm_to_term[cn] = n;
      cterms.push_back(cn);
      t_cterms[cn.getType()].push_back(cn);
    }
  }
  Trace("srs-input") << "...finished." << std::endl;
  // the sygus variable list
  Node sygusVarList = nm->mkNode(BOUND_VAR_LIST, allVars);
  Trace("srs-input") << "Have " << cterms.size() << " canonical subterms."
                     << std::endl;

  Trace("srs-input") << "Construct unresolved types..." << std::endl;
  // each canonical subterm corresponds to a grammar type
  std::set<TypeNode> unres;
  std::vector<SygusDatatype> sdts;
  // make unresolved types for each canonical term
  std::map<Node, TypeNode> cterm_to_utype;
  for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
  {
    Node ct = cterms[i];
    std::stringstream ss;
    ss << "T" << i;
    std::string tname = ss.str();
    TypeNode tnu = nm->mkSort(tname, NodeManager::SORT_FLAG_PLACEHOLDER);
    cterm_to_utype[ct] = tnu;
    unres.insert(tnu);
    sdts.push_back(SygusDatatype(tname));
  }
  Trace("srs-input") << "...finished." << std::endl;

  Trace("srs-input") << "Construct sygus datatypes..." << std::endl;
  for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
  {
    Node ct = cterms[i];
    Node t = cterm_to_term[ct];

    // add the variables for the type
    TypeNode ctt = ct.getType();
    Assert(tvars.find(ctt) != tvars.end());
    std::vector<TypeNode> argList;
    // we add variable constructors if we are not Boolean, we are interested
    // in purely propositional rewrites (via the option), or this term is
    // a Boolean variable.
    if (!ctt.isBoolean() || options::sygusRewSynthInputUseBool()
        || ct.getKind() == BOUND_VARIABLE)
    {
      for (const Node& v : tvars[ctt])
      {
        std::stringstream ssc;
        ssc << "C_" << i << "_" << v;
        // argList is still empty here: variables are nullary constructors
        sdts[i].addConstructor(v, ssc.str(), argList);
      }
    }
    // add the constructor for the operator if it is not a variable
    if (ct.getKind() != BOUND_VARIABLE)
    {
      Assert(!ct.isVar());
      Node op = ct.hasOperator() ? ct.getOperator() : ct;
      // iterate over the original term
      for (const Node& tc : t)
      {
        // map its arguments back to canonical
        Assert(term_to_cterm.find(tc) != term_to_cterm.end());
        Node ctc = term_to_cterm[tc];
        Assert(cterm_to_utype.find(ctc) != cterm_to_utype.end());
        // get the type
        argList.push_back(cterm_to_utype[ctc]);
      }
      // check if we should chain
      bool do_chain = false;
      if (argList.size() > 2)
      {
        Kind k = NodeManager::operatorToKind(op);
        do_chain = theory::quantifiers::TermUtil::isAssoc(k)
                   && theory::quantifiers::TermUtil::isComm(k);
        // eliminate duplicate child types
        std::vector<TypeNode> argListTmp = argList;
        argList.clear();
        std::map<TypeNode, bool> hasArgType;
        for (unsigned j = 0, size = argListTmp.size(); j < size; j++)
        {
          TypeNode tn = argListTmp[j];
          if (hasArgType.find(tn) == hasArgType.end())
          {
            hasArgType[tn] = true;
            argList.push_back(tn);
          }
        }
      }
      if (do_chain)
      {
        // we make one type per child
        // the operator of each constructor is a no-op
        Node tbv = nm->mkBoundVar(ctt);
        Node lambdaOp =
            nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
        std::vector<TypeNode> argListc;
        // the following construction admits any number of repeated factors,
        // so for instance, t1+t2+t3, we generate the grammar:
        // T_{t1+t2+t3} ->
        //   +( T_{t1+t2+t3}, T_{t1+t2+t3} ) | T_{t1} | T_{t2} | T_{t3}
        // where we write T_t to denote "the type that abstracts term t".
        // Notice this construction allows to abstract subsets of the factors
        // of t1+t2+t3. This is particularly helpful for terms t1+...+tn for
        // large n, where we would like to consider binary applications of +.
        for (unsigned j = 0, size = argList.size(); j < size; j++)
        {
          argListc.clear();
          argListc.push_back(argList[j]);
          std::stringstream sscs;
          sscs << "C_factor_" << i << "_" << j;
          // the identity function is given weight 0, so that it does not
          // count towards the weight of terms
          sdts[i].addConstructor(lambdaOp,
                                 sscs.str(),
                                 argListc,
                                 0);
        }
        // recursive apply
        TypeNode recType = cterm_to_utype[ct];
        argListc.clear();
        argListc.push_back(recType);
        argListc.push_back(recType);
        std::stringstream ssc;
        ssc << "C_" << i << "_rec_" << op;
        sdts[i].addConstructor(op, ssc.str(), argListc);
      }
      else
      {
        std::stringstream ssc;
        ssc << "C_" << i << "_" << op;
        sdts[i].addConstructor(op, ssc.str(), argList);
      }
    }
    Assert(sdts[i].getNumConstructors() > 0);
    sdts[i].initializeDatatype(ctt, sygusVarList, false, false);
  }
  Trace("srs-input") << "...finished." << std::endl;

  Trace("srs-input") << "Make mutual datatype types for subterms..."
                     << std::endl;
  // extract the datatypes
  std::vector<DType> datatypes;
  for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
  {
    datatypes.push_back(sdts[i].getDatatype());
  }
  std::vector<TypeNode> types = nm->mkMutualDatatypeTypes(
      datatypes, unres, NodeManager::DATATYPE_FLAG_PLACEHOLDER);
  Trace("srs-input") << "...finished." << std::endl;
  Assert(types.size() == unres.size());
  // the resolved datatype type for each canonical subterm
  std::map<Node, TypeNode> subtermTypes;
  for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
  {
    subtermTypes[cterms[i]] = types[i];
  }

  Trace("srs-input") << "Construct the top-level types..." << std::endl;
  // we now are ready to create the "top-level" types
  std::map<TypeNode, TypeNode> tlGrammarTypes;
  for (std::pair<const TypeNode, std::vector<Node> >& tcp : t_cterms)
  {
    TypeNode t = tcp.first;
    std::stringstream ss;
    ss << "T_" << t;
    SygusDatatype sdttl(ss.str());
    Node tbv = nm->mkBoundVar(t);
    // the operator of each constructor is a no-op
    Node lambdaOp = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
    Trace("srs-input") << "  We have " << tcp.second.size()
                       << " subterms of type " << t << std::endl;
    for (unsigned i = 0, size = tcp.second.size(); i < size; i++)
    {
      Node n = tcp.second[i];
      // add constructor that encodes abstractions of this subterm
      std::vector<TypeNode> argList;
      Assert(subtermTypes.find(n) != subtermTypes.end());
      argList.push_back(subtermTypes[n]);
      std::stringstream ssc;
      ssc << "Ctl_" << i;
      // the no-op is given weight 0
      sdttl.addConstructor(lambdaOp,
                           ssc.str(),
                           argList,
                           0);
      Trace("srs-input-debug")
          << "Grammar for subterm " << n << " is: " << std::endl;
      Trace("srs-input-debug") << subtermTypes[n].getDType() << std::endl;
    }
    // set that this is a sygus datatype
    sdttl.initializeDatatype(t, sygusVarList, false, false);
    DType dttl = sdttl.getDatatype();
    TypeNode tlt =
        nm->mkDatatypeType(dttl, NodeManager::DATATYPE_FLAG_PLACEHOLDER);
    tlGrammarTypes[t] = tlt;
    Trace("srs-input") << "Grammar is: " << std::endl;
    Trace("srs-input") << tlt.getDType() << std::endl;
  }
  Trace("srs-input") << "...finished." << std::endl;

  // sygus attribute to mark the conjecture as a sygus conjecture
  Trace("srs-input") << "Make sygus conjecture..." << std::endl;
  // we are "synthesizing" functions for each type of subterm
  std::vector<Node> synthConj;
  unsigned fCounter = 1;
  theory::SygusSynthGrammarAttribute ssg;
  for (const std::pair<const TypeNode, TypeNode>& ttp : tlGrammarTypes)
  {
    Node gvar = nm->mkBoundVar("sfproxy", ttp.second);
    TypeNode ft = nm->mkFunctionType(allVarTypes, ttp.first);
    // likewise, it is helpful if these have good names, we choose f1, f2, ...
    std::stringstream ssf;
    ssf << "f" << fCounter;
    fCounter++;
    Node sfun = nm->mkBoundVar(ssf.str(), ft);
    // this marks that the grammar used for solutions for sfun is the type of
    // gvar, which is the sygus datatype type constructed above.
    sfun.setAttribute(ssg, gvar);

    Node body = nm->mkConst(false);
    body = theory::quantifiers::SygusUtils::mkSygusConjecture({sfun}, body);
    synthConj.push_back(body);
  }
  Node trueNode = nm->mkConst(true);
  Node res =
      synthConj.empty()
          ? trueNode
          : (synthConj.size() == 1 ? synthConj[0] : nm->mkNode(AND, synthConj));

  Trace("srs-input") << "got : " << res << std::endl;
  Trace("srs-input") << "...finished." << std::endl;

  // The first assertion becomes the synthesis conjecture; all others are
  // replaced by true. Index 0 exists since we returned early on empty input.
  assertionsToPreprocess->replace(0, res);
  for (unsigned i = 1, size = assertionsToPreprocess->size(); i < size; ++i)
  {
    assertionsToPreprocess->replace(i, trueNode);
  }

  return PreprocessingPassResult::NO_CONFLICT;
}
475
476
}  // namespace passes
477
}  // namespace preprocessing
478
28191
}  // namespace cvc5