GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/synth_rew_rules.cpp Lines: 3 250 1.2 %
Date: 2021-03-22 Branches: 4 1002 0.4 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file synth_rew_rules.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Mathias Preiner
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief A technique for synthesizing candidate rewrites of the form t1 = t2,
13
 ** where t1 and t2 are subterms of the input.
14
 **/
15
16
#include "preprocessing/passes/synth_rew_rules.h"
17
18
#include <sstream>
19
20
#include "expr/sygus_datatype.h"
21
#include "expr/term_canonize.h"
22
#include "options/base_options.h"
23
#include "options/quantifiers_options.h"
24
#include "preprocessing/assertion_pipeline.h"
25
#include "printer/printer.h"
26
#include "theory/quantifiers/candidate_rewrite_database.h"
27
#include "theory/quantifiers/quantifiers_attributes.h"
28
#include "theory/quantifiers/sygus/sygus_grammar_cons.h"
29
#include "theory/quantifiers/sygus/sygus_utils.h"
30
#include "theory/quantifiers/term_util.h"
31
32
using namespace std;
33
using namespace CVC4::kind;
34
35
namespace CVC4 {
36
namespace preprocessing {
37
namespace passes {
38
39
8995
/**
 * Constructs the pass by forwarding the given preprocessing context and the
 * pass name "synth-rr" to the PreprocessingPass base-class constructor.
 *
 * @param preprocContext the preprocessing pass context handed to the base
 * class
 */
SynthRewRulesPass::SynthRewRulesPass(PreprocessingPassContext* preprocContext)
    : PreprocessingPass(preprocContext, "synth-rr")
{
}
41
42
/**
 * Replaces the input assertions by a sygus conjecture whose grammars
 * generate abstractions of the subterms of the input.
 *
 * The pass proceeds in the following steps:
 * (1) collect all subterms of the assertions that do not contain a closure
 *     (quantified) subterm, along with the variables and types that occur,
 * (2) allocate a fixed number of named bound variables per type; these are
 *     the free variables of the rewrites that are generated,
 * (3) replace the input variables by fresh bound variables and map every
 *     term to its canonical form,
 * (4) build one sygus datatype per canonical term and one "top-level" sygus
 *     datatype per type,
 * (5) build one function-to-synthesize per type whose grammar is the
 *     top-level datatype, and a sygus conjecture (with body false) for it.
 *
 * On exit, assertion 0 is replaced by the (conjunction of the) conjectures
 * and all other assertions are replaced by true. If there are no assertions,
 * the pipeline is left untouched.
 *
 * @param assertionsToPreprocess the assertion pipeline, modified in place
 * @return always PreprocessingPassResult::NO_CONFLICT
 */
PreprocessingPassResult SynthRewRulesPass::applyInternal(
    AssertionPipeline* assertionsToPreprocess)
{
  Trace("srs-input") << "Synthesize rewrite rules from assertions..."
                     << std::endl;
  const std::vector<Node>& assertions = assertionsToPreprocess->ref();
  if (assertions.empty())
  {
    return PreprocessingPassResult::NO_CONFLICT;
  }

  NodeManager* nm = NodeManager::currentNM();

  // initialize the candidate rewrite
  // visited[t] is false while t is pending (or was found unusable), and true
  // once t has been registered as a usable term
  std::unordered_map<TNode, bool, TNodeHashFunction> visited;
  std::unordered_map<TNode, bool, TNodeHashFunction>::iterator it;
  std::vector<TNode> visit;
  // Get all usable terms from the input. A term is usable if it does not
  // contain a quantified subterm
  std::vector<Node> terms;
  // all variables (free constants) appearing in the input
  std::vector<Node> vars;
  // does the input contain a Boolean variable?
  bool hasBoolVar = false;
  // the types of subterms of our input
  std::map<TypeNode, bool> typesFound;
  // standard constants for each type (e.g. true, false for Bool)
  std::map<TypeNode, std::vector<Node> > consts;

  TNode cur;
  Trace("srs-input") << "Collect terms in assertions..." << std::endl;
  for (const Node& a : assertions)
  {
    Trace("srs-input-debug") << "Assertion : " << a << std::endl;
    visit.push_back(a);
    do
    {
      cur = visit.back();
      visit.pop_back();
      it = visited.find(cur);
      if (it == visited.end())
      {
        Trace("srs-input-debug") << "...preprocess " << cur << std::endl;
        visited[cur] = false;
        bool isQuant = cur.isClosure();
        // we recurse on this node if it is not a quantified formula
        if (!isQuant)
        {
          // cur is pushed again so that it is post-processed after all of
          // its children have been processed
          visit.push_back(cur);
          for (const Node& cc : cur)
          {
            visit.push_back(cc);
          }
        }
      }
      else if (!it->second)
      {
        Trace("srs-input-debug") << "...postprocess " << cur << std::endl;
        // check if all of the children are valid
        // this ensures we do not register terms that have e.g. quantified
        // formulas as subterms
        bool childrenValid = true;
        for (const Node& cc : cur)
        {
          Assert(visited.find(cc) != visited.end());
          if (!visited[cc])
          {
            childrenValid = false;
          }
        }
        if (childrenValid)
        {
          Trace("srs-input-debug") << "...children are valid" << std::endl;
          Trace("srs-input-debug") << "Add term " << cur << std::endl;
          TypeNode tn = cur.getType();
          if (cur.isVar())
          {
            vars.push_back(cur);
            if (tn.isBoolean())
            {
              hasBoolVar = true;
            }
          }
          // register type information
          if (typesFound.find(tn) == typesFound.end())
          {
            typesFound[tn] = true;
            // add the standard constants for this type
            theory::quantifiers::CegGrammarConstructor::mkSygusConstantsForType(
                tn, consts[tn]);
            // We prepend them so that they come first in the grammar
            // construction. The motivation is we'd prefer seeing e.g. "true"
            // instead of (= x x) as a canonical term.
            terms.insert(terms.begin(), consts[tn].begin(), consts[tn].end());
          }
          terms.push_back(cur);
        }
        visited[cur] = childrenValid;
      }
    } while (!visit.empty());
  }
  Trace("srs-input") << "...finished." << std::endl;

  Trace("srs-input") << "Make synth variables for types..." << std::endl;
  // We will generate a fixed number of variables per type. These are the
  // variables that appear as free variables in the rewrites we generate.
  unsigned nvars = options::sygusRewSynthInputNVars();
  // must have at least one variable per type
  nvars = nvars < 1 ? 1 : nvars;
  std::map<TypeNode, std::vector<Node> > tvars;
  std::vector<TypeNode> allVarTypes;
  std::vector<Node> allVars;
  unsigned varCounter = 0;
  for (const std::pair<const TypeNode, bool>& tfp : typesFound)
  {
    const TypeNode& tn = tfp.first;
    // If we are not interested in purely propositional rewrites, we only
    // need to make one Boolean variable if the input has a Boolean variable.
    // This ensures that no type in our grammar has zero constructors. If
    // our input does not contain a Boolean variable, we need not allocate any
    // Boolean variables here.
    unsigned useNVars =
        (options::sygusRewSynthInputUseBool() || !tn.isBoolean())
            ? nvars
            : (hasBoolVar ? 1 : 0);
    // Create the entry for this type up front, even when no variables are
    // allocated, so that later lookups of tvars for a type of the input
    // always find a (possibly empty) list.
    std::vector<Node>& typeVars = tvars[tn];
    for (unsigned i = 0; i < useNVars; i++)
    {
      // We must have a good name for these variables, these are
      // the ones output in rewrite rules. We choose
      // a,b,c,...,y,z,x0,x1,...
      std::stringstream ssv;
      if (varCounter < 26)
      {
        // offset from the character literal 'a'; a decimal offset of 61
        // would start at '=' instead of the intended lowercase letters
        ssv << static_cast<char>('a' + varCounter);
      }
      else
      {
        ssv << "x" << (varCounter - 26);
      }
      varCounter++;
      Node v = nm->mkBoundVar(ssv.str(), tn);
      typeVars.push_back(v);
      allVars.push_back(v);
      allVarTypes.push_back(tn);
    }
  }
  Trace("srs-input") << "...finished." << std::endl;

  Trace("srs-input") << "Convert subterms to free variable form..."
                     << std::endl;
  // Replace all free variables with bound variables. This ensures that
  // we can perform term canonization on subterms.
  std::vector<Node> vsubs;
  for (const Node& v : vars)
  {
    TypeNode tnv = v.getType();
    Node vs = nm->mkBoundVar(tnv);
    vsubs.push_back(vs);
  }
  if (!vars.empty())
  {
    for (Node& t : terms)
    {
      t = t.substitute(vars.begin(), vars.end(), vsubs.begin(), vsubs.end());
    }
  }
  Trace("srs-input") << "...finished." << std::endl;

  Trace("srs-input") << "Process " << terms.size() << " subterms..."
                     << std::endl;
  // We've collected all terms in the input. In the following, we construct
  // a sygus grammar that generates terms that correspond to abstractions of
  // the terms in the input.

  // We map terms to a canonical (ordered variable) form. This ensures that
  // we don't generate distinct grammar types for distinct alpha-equivalent
  // terms, which would produce grammars of identical shape.
  std::map<Node, Node> term_to_cterm;
  // maps each canonical term to the first input term that produced it
  std::map<Node, Node> cterm_to_term;
  std::vector<Node> cterms;
  // canonical terms for each type
  std::map<TypeNode, std::vector<Node> > t_cterms;
  expr::TermCanonize tcanon;
  for (const Node& n : terms)
  {
    Node cn = tcanon.getCanonicalTerm(n);
    term_to_cterm[n] = cn;
    Trace("srs-input-debug") << "Canon : " << n << " -> " << cn << std::endl;
    std::map<Node, Node>::iterator itc = cterm_to_term.find(cn);
    if (itc == cterm_to_term.end())
    {
      cterm_to_term[cn] = n;
      cterms.push_back(cn);
      t_cterms[cn.getType()].push_back(cn);
    }
  }
  Trace("srs-input") << "...finished." << std::endl;
  // the sygus variable list
  Node sygusVarList = nm->mkNode(BOUND_VAR_LIST, allVars);
  Trace("srs-input") << "Have " << cterms.size() << " canonical subterms."
                     << std::endl;

  Trace("srs-input") << "Construct unresolved types..." << std::endl;
  // each canonical subterm corresponds to a grammar type
  std::set<TypeNode> unres;
  std::vector<SygusDatatype> sdts;
  // make unresolved types for each canonical term
  std::map<Node, TypeNode> cterm_to_utype;
  for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
  {
    Node ct = cterms[i];
    std::stringstream ss;
    ss << "T" << i;
    std::string tname = ss.str();
    TypeNode tnu = nm->mkSort(tname, NodeManager::SORT_FLAG_PLACEHOLDER);
    cterm_to_utype[ct] = tnu;
    unres.insert(tnu);
    sdts.push_back(SygusDatatype(tname));
  }
  Trace("srs-input") << "...finished." << std::endl;

  Trace("srs-input") << "Construct sygus datatypes..." << std::endl;
  for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
  {
    Node ct = cterms[i];
    Node t = cterm_to_term[ct];

    // add the variables for the type
    TypeNode ctt = ct.getType();
    Assert(tvars.find(ctt) != tvars.end());
    std::vector<TypeNode> argList;
    // we add variable constructors if we are not Boolean, we are interested
    // in purely propositional rewrites (via the option), or this term is
    // a Boolean variable.
    if (!ctt.isBoolean() || options::sygusRewSynthInputUseBool()
        || ct.getKind() == BOUND_VARIABLE)
    {
      // argList is still empty here, hence these are nullary constructors
      for (const Node& v : tvars[ctt])
      {
        std::stringstream ssc;
        ssc << "C_" << i << "_" << v;
        sdts[i].addConstructor(v, ssc.str(), argList);
      }
    }
    // add the constructor for the operator if it is not a variable
    if (ct.getKind() != BOUND_VARIABLE)
    {
      Assert(!ct.isVar());
      Node op = ct.hasOperator() ? ct.getOperator() : ct;
      // iterate over the original term
      for (const Node& tc : t)
      {
        // map its arguments back to canonical
        Assert(term_to_cterm.find(tc) != term_to_cterm.end());
        Node ctc = term_to_cterm[tc];
        Assert(cterm_to_utype.find(ctc) != cterm_to_utype.end());
        // get the type
        argList.push_back(cterm_to_utype[ctc]);
      }
      // check if we should chain
      bool do_chain = false;
      if (argList.size() > 2)
      {
        Kind k = NodeManager::operatorToKind(op);
        do_chain = theory::quantifiers::TermUtil::isAssoc(k)
                   && theory::quantifiers::TermUtil::isComm(k);
        // eliminate duplicate child types
        // NOTE(review): this elimination runs for every operator with more
        // than two arguments, not only for chained ones, so a non-chained
        // constructor below may receive fewer argument types than the term
        // has children. Confirm this is intended.
        std::vector<TypeNode> argListTmp = argList;
        argList.clear();
        std::map<TypeNode, bool> hasArgType;
        for (unsigned j = 0, size = argListTmp.size(); j < size; j++)
        {
          TypeNode tn = argListTmp[j];
          if (hasArgType.find(tn) == hasArgType.end())
          {
            hasArgType[tn] = true;
            argList.push_back(tn);
          }
        }
      }
      if (do_chain)
      {
        // we make one type per child
        // the operator of each constructor is a no-op
        Node tbv = nm->mkBoundVar(ctt);
        Node lambdaOp =
            nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
        std::vector<TypeNode> argListc;
        // the following construction admits any number of repeated factors,
        // so for instance, t1+t2+t3, we generate the grammar:
        // T_{t1+t2+t3} ->
        //   +( T_{t1+t2+t3}, T_{t1+t2+t3} ) | T_{t1} | T_{t2} | T_{t3}
        // where we write T_t to denote "the type that abstracts term t".
        // Notice this construction allows to abstract subsets of the factors
        // of t1+t2+t3. This is particularly helpful for terms t1+...+tn for
        // large n, where we would like to consider binary applications of +.
        for (unsigned j = 0, size = argList.size(); j < size; j++)
        {
          argListc.clear();
          argListc.push_back(argList[j]);
          std::stringstream sscs;
          sscs << "C_factor_" << i << "_" << j;
          // ID function is not printed and does not count towards weight
          sdts[i].addConstructor(lambdaOp,
                                 sscs.str(),
                                 argListc,
                                 0);
        }
        // recursive apply
        TypeNode recType = cterm_to_utype[ct];
        argListc.clear();
        argListc.push_back(recType);
        argListc.push_back(recType);
        std::stringstream ssc;
        ssc << "C_" << i << "_rec_" << op;
        sdts[i].addConstructor(op, ssc.str(), argListc);
      }
      else
      {
        std::stringstream ssc;
        ssc << "C_" << i << "_" << op;
        sdts[i].addConstructor(op, ssc.str(), argList);
      }
    }
    Assert(sdts[i].getNumConstructors() > 0);
    sdts[i].initializeDatatype(ctt, sygusVarList, false, false);
  }
  Trace("srs-input") << "...finished." << std::endl;

  Trace("srs-input") << "Make mutual datatype types for subterms..."
                     << std::endl;
  // extract the datatypes
  std::vector<DType> datatypes;
  for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
  {
    datatypes.push_back(sdts[i].getDatatype());
  }
  std::vector<TypeNode> types = nm->mkMutualDatatypeTypes(
      datatypes, unres, NodeManager::DATATYPE_FLAG_PLACEHOLDER);
  Trace("srs-input") << "...finished." << std::endl;
  Assert(types.size() == unres.size());
  // types[i] is taken to be the resolved datatype type for cterms[i]
  std::map<Node, TypeNode> subtermTypes;
  for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
  {
    subtermTypes[cterms[i]] = types[i];
  }

  Trace("srs-input") << "Construct the top-level types..." << std::endl;
  // we now are ready to create the "top-level" types
  std::map<TypeNode, TypeNode> tlGrammarTypes;
  for (std::pair<const TypeNode, std::vector<Node> >& tcp : t_cterms)
  {
    TypeNode t = tcp.first;
    std::stringstream ss;
    ss << "T_" << t;
    SygusDatatype sdttl(ss.str());
    Node tbv = nm->mkBoundVar(t);
    // the operator of each constructor is a no-op
    Node lambdaOp = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
    Trace("srs-input") << "  We have " << tcp.second.size()
                       << " subterms of type " << t << std::endl;
    for (unsigned i = 0, size = tcp.second.size(); i < size; i++)
    {
      Node n = tcp.second[i];
      // add constructor that encodes abstractions of this subterm
      std::vector<TypeNode> argList;
      Assert(subtermTypes.find(n) != subtermTypes.end());
      argList.push_back(subtermTypes[n]);
      std::stringstream ssc;
      ssc << "Ctl_" << i;
      // the no-op constructor is added with a final argument of 0, in the
      // same way as the factor constructors of chained operators above
      sdttl.addConstructor(lambdaOp,
                           ssc.str(),
                           argList,
                           0);
      Trace("srs-input-debug")
          << "Grammar for subterm " << n << " is: " << std::endl;
      Trace("srs-input-debug") << subtermTypes[n].getDType() << std::endl;
    }
    // set that this is a sygus datatype
    sdttl.initializeDatatype(t, sygusVarList, false, false);
    DType dttl = sdttl.getDatatype();
    TypeNode tlt =
        nm->mkDatatypeType(dttl, NodeManager::DATATYPE_FLAG_PLACEHOLDER);
    tlGrammarTypes[t] = tlt;
    Trace("srs-input") << "Grammar is: " << std::endl;
    Trace("srs-input") << tlt.getDType() << std::endl;
  }
  Trace("srs-input") << "...finished." << std::endl;

  // sygus attribute to mark the conjecture as a sygus conjecture
  Trace("srs-input") << "Make sygus conjecture..." << std::endl;
  // we are "synthesizing" functions for each type of subterm
  std::vector<Node> synthConj;
  unsigned fCounter = 1;
  theory::SygusSynthGrammarAttribute ssg;
  for (const std::pair<const TypeNode, TypeNode>& ttp : tlGrammarTypes)
  {
    Node gvar = nm->mkBoundVar("sfproxy", ttp.second);
    TypeNode ft = nm->mkFunctionType(allVarTypes, ttp.first);
    // likewise, it is helpful if these have good names, we choose f1, f2, ...
    std::stringstream ssf;
    ssf << "f" << fCounter;
    fCounter++;
    Node sfun = nm->mkBoundVar(ssf.str(), ft);
    // this marks that the grammar used for solutions for sfun is the type of
    // gvar, which is the sygus datatype type constructed above.
    sfun.setAttribute(ssg, gvar);

    // the conjecture for sfun has the constant false as its body
    Node body = nm->mkConst(false);
    body = theory::quantifiers::SygusUtils::mkSygusConjecture({sfun}, body);
    synthConj.push_back(body);
  }
  Node trueNode = nm->mkConst(true);
  Node res =
      synthConj.empty()
          ? trueNode
          : (synthConj.size() == 1 ? synthConj[0] : nm->mkNode(AND, synthConj));

  Trace("srs-input") << "got : " << res << std::endl;
  Trace("srs-input") << "...finished." << std::endl;

  // the result replaces the first assertion; every other assertion is
  // replaced by true
  assertionsToPreprocess->replace(0, res);
  for (unsigned i = 1, size = assertionsToPreprocess->size(); i < size; ++i)
  {
    assertionsToPreprocess->replace(i, trueNode);
  }

  return PreprocessingPassResult::NO_CONFLICT;
}
474
475
}  // namespace passes
476
}  // namespace preprocessing
477
26676
}  // namespace CVC4