GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/sygus_inference.cpp Lines: 155 162 95.7 %
Date: 2021-05-22 Branches: 301 618 48.7 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mathias Preiner, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Sygus inference module.
14
 */
15
16
#include "preprocessing/passes/sygus_inference.h"
17
18
#include "preprocessing/assertion_pipeline.h"
19
#include "preprocessing/preprocessing_pass_context.h"
20
#include "smt/smt_engine.h"
21
#include "smt/smt_engine_scope.h"
22
#include "smt/smt_statistics_registry.h"
23
#include "theory/quantifiers/quantifiers_attributes.h"
24
#include "theory/quantifiers/quantifiers_rewriter.h"
25
#include "theory/quantifiers/sygus/sygus_grammar_cons.h"
26
#include "theory/quantifiers/sygus/sygus_utils.h"
27
#include "theory/rewriter.h"
28
#include "theory/smt_engine_subsolver.h"
29
30
using namespace std;
31
using namespace cvc5::kind;
32
using namespace cvc5::theory;
33
34
namespace cvc5 {
35
namespace preprocessing {
36
namespace passes {
37
38
9459
/**
 * Constructs the sygus inference preprocessing pass, registered with the
 * pass name "sygus-infer".
 *
 * @param preprocContext The preprocessing context shared by all passes.
 */
SygusInference::SygusInference(PreprocessingPassContext* preprocContext)
    : PreprocessingPass(preprocContext, "sygus-infer")
{
}
40
41
58
/**
 * Attempts to solve the current assertions as a sygus problem.
 *
 * On success, every inferred function definition is registered as a
 * substitution with the preprocessing context and is then applied to all
 * assertions; assertions that change are rewritten and replaced in the
 * pipeline. On failure the assertions are left untouched.
 *
 * @param assertionsToPreprocess The assertion pipeline to process in place.
 * @return Always PreprocessingPassResult::NO_CONFLICT.
 */
PreprocessingPassResult SygusInference::applyInternal(
    AssertionPipeline* assertionsToPreprocess)
{
  Trace("sygus-infer") << "Run sygus inference..." << std::endl;
  std::vector<Node> funs;
  std::vector<Node> sols;
  // see if we can successfully solve the input as a sygus problem
  if (!solveSygus(assertionsToPreprocess->ref(), funs, sols))
  {
    return PreprocessingPassResult::NO_CONFLICT;
  }
  Trace("sygus-infer") << "...Solved:" << std::endl;
  Assert(funs.size() == sols.size());
  // sygus gives us function definitions, which we add as substitutions
  for (size_t i = 0, nfuns = funs.size(); i < nfuns; ++i)
  {
    Trace("sygus-infer") << funs[i] << " -> " << sols[i] << std::endl;
    d_preprocContext->addSubstitution(funs[i], sols[i]);
  }

  // apply substitution to everything, should result in SAT
  const size_t nassertions = assertionsToPreprocess->ref().size();
  for (size_t i = 0; i < nassertions; ++i)
  {
    // take a copy: replace() below overwrites the stored assertion
    Node prev = (*assertionsToPreprocess)[i];
    Node curr =
        prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end());
    if (curr != prev)
    {
      curr = theory::Rewriter::rewrite(curr);
      Trace("sygus-infer-debug")
          << "...rewrote " << prev << " to " << curr << std::endl;
      assertionsToPreprocess->replace(i, curr);
    }
  }
  return PreprocessingPassResult::NO_CONFLICT;
}
77
78
58
bool SygusInference::solveSygus(const std::vector<Node>& assertions,
79
                                std::vector<Node>& funs,
80
                                std::vector<Node>& sols)
81
{
82
58
  if (assertions.empty())
83
  {
84
    Trace("sygus-infer") << "...fail: empty assertions." << std::endl;
85
    return false;
86
  }
87
88
58
  NodeManager* nm = NodeManager::currentNM();
89
90
  // collect free variables in all assertions
91
116
  std::vector<Node> qvars;
92
116
  std::map<TypeNode, std::vector<Node> > qtvars;
93
116
  std::vector<Node> free_functions;
94
95
116
  std::vector<TNode> visit;
96
116
  std::unordered_set<TNode> visited;
97
98
  // add top-level conjuncts to eassertions
99
116
  std::vector<Node> assertions_proc = assertions;
100
116
  std::vector<Node> eassertions;
101
58
  unsigned index = 0;
102
408
  while (index < assertions_proc.size())
103
  {
104
350
    Node ca = assertions_proc[index];
105
175
    if (ca.getKind() == AND)
106
    {
107
18
      for (const Node& ai : ca)
108
      {
109
12
        assertions_proc.push_back(ai);
110
      }
111
    }
112
    else
113
    {
114
169
      eassertions.push_back(ca);
115
    }
116
175
    index++;
117
  }
118
119
  // process eassertions
120
116
  std::vector<Node> processed_assertions;
121
223
  for (const Node& as : eassertions)
122
  {
123
    // substitution for this assertion
124
332
    std::vector<Node> vars;
125
332
    std::vector<Node> subs;
126
332
    std::map<TypeNode, unsigned> type_count;
127
332
    Node pas = as;
128
    // rewrite
129
167
    pas = theory::Rewriter::rewrite(pas);
130
167
    Trace("sygus-infer") << "assertion : " << pas << std::endl;
131
167
    if (pas.getKind() == FORALL)
132
    {
133
      // preprocess the quantified formula
134
36
      TrustNode trn = quantifiers::QuantifiersRewriter::preprocess(pas);
135
18
      if (!trn.isNull())
136
      {
137
        pas = trn.getNode();
138
      }
139
18
      Trace("sygus-infer-debug") << "  ...preprocessed to " << pas << std::endl;
140
    }
141
167
    if (pas.getKind() == FORALL)
142
    {
143
      // it must be a standard quantifier
144
36
      theory::quantifiers::QAttributes qa;
145
18
      theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa);
146
18
      if (!qa.isStandard())
147
      {
148
        Trace("sygus-infer")
149
            << "...fail: non-standard top-level quantifier." << std::endl;
150
        return false;
151
      }
152
      // infer prefix
153
42
      for (const Node& v : pas[0])
154
      {
155
48
        TypeNode tnv = v.getType();
156
24
        unsigned vnum = type_count[tnv];
157
24
        type_count[tnv]++;
158
24
        vars.push_back(v);
159
24
        if (vnum < qtvars[tnv].size())
160
        {
161
4
          subs.push_back(qtvars[tnv][vnum]);
162
        }
163
        else
164
        {
165
20
          Assert(vnum == qtvars[tnv].size());
166
40
          Node bv = nm->mkBoundVar(tnv);
167
20
          qtvars[tnv].push_back(bv);
168
20
          qvars.push_back(bv);
169
20
          subs.push_back(bv);
170
        }
171
      }
172
18
      pas = pas[1];
173
18
      if (!vars.empty())
174
      {
175
18
        pas =
176
36
            pas.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
177
      }
178
    }
179
167
    Trace("sygus-infer-debug") << "  ...substituted to " << pas << std::endl;
180
181
    // collect free functions, ensure no quantified formulas
182
332
    TNode cur = pas;
183
    // compute free variables
184
167
    visit.push_back(cur);
185
965
    do
186
    {
187
1132
      cur = visit.back();
188
1132
      visit.pop_back();
189
1132
      if (visited.find(cur) == visited.end())
190
      {
191
854
        visited.insert(cur);
192
854
        if (cur.getKind() == APPLY_UF)
193
        {
194
68
          Node op = cur.getOperator();
195
          // visit the operator, which might not be a variable
196
34
          visit.push_back(op);
197
        }
198
820
        else if (cur.isVar() && cur.getKind() != BOUND_VARIABLE)
199
        {
200
          // We are either in the case of a free first-order constant or a
201
          // function in a higher-order context. We add to free_functions
202
          // in either case. Note that a free constant that is not in a
203
          // higher-order context is a 0-argument function-to-synthesize.
204
          // We should not have traversed here before due to our visited cache.
205
132
          Assert(std::find(free_functions.begin(), free_functions.end(), cur)
206
                 == free_functions.end());
207
132
          free_functions.push_back(cur);
208
        }
209
688
        else if (cur.isClosure())
210
        {
211
4
          Trace("sygus-infer")
212
2
              << "...fail: non-top-level quantifier." << std::endl;
213
2
          return false;
214
        }
215
1785
        for (const TNode& cn : cur)
216
        {
217
933
          visit.push_back(cn);
218
        }
219
      }
220
1130
    } while (!visit.empty());
221
165
    processed_assertions.push_back(pas);
222
  }
223
224
  // no functions to synthesize
225
56
  if (free_functions.empty())
226
  {
227
2
    Trace("sygus-infer") << "...fail: no free function symbols." << std::endl;
228
2
    return false;
229
  }
230
231
  // Ensure the type of all free functions is handled by the sygus grammar
232
  // constructor utility.
233
54
  bool typeSuccess = true;
234
182
  for (const Node& f : free_functions)
235
  {
236
258
    TypeNode tn = f.getType();
237
130
    if (!theory::quantifiers::CegGrammarConstructor::isHandledType(tn))
238
    {
239
2
      Trace("sygus-infer") << "...fail: unhandled type " << tn << std::endl;
240
2
      typeSuccess = false;
241
2
      break;
242
    }
243
  }
244
54
  if (!typeSuccess)
245
  {
246
2
    return false;
247
  }
248
249
52
  Assert(!processed_assertions.empty());
250
  // conjunction of the assertions
251
52
  Trace("sygus-infer") << "Construct body..." << std::endl;
252
104
  Node body;
253
52
  if (processed_assertions.size() == 1)
254
  {
255
    body = processed_assertions[0];
256
  }
257
  else
258
  {
259
52
    body = nm->mkNode(AND, processed_assertions);
260
  }
261
262
  // for each free function symbol, make a bound variable of the same type
263
52
  Trace("sygus-infer") << "Do free function substitution..." << std::endl;
264
104
  std::vector<Node> ff_vars;
265
104
  std::map<Node, Node> ff_var_to_ff;
266
180
  for (const Node& ff : free_functions)
267
  {
268
256
    Node ffv = nm->mkBoundVar(ff.getType());
269
128
    ff_vars.push_back(ffv);
270
128
    Trace("sygus-infer") << "  synth-fun: " << ff << " as " << ffv << std::endl;
271
128
    ff_var_to_ff[ffv] = ff;
272
  }
273
  // substitute free functions -> variables
274
52
  body = body.substitute(free_functions.begin(),
275
                         free_functions.end(),
276
                         ff_vars.begin(),
277
                         ff_vars.end());
278
52
  Trace("sygus-infer-debug") << "...got : " << body << std::endl;
279
280
  // quantify the body
281
52
  Trace("sygus-infer") << "Make inner sygus conjecture..." << std::endl;
282
52
  body = body.negate();
283
52
  if (!qvars.empty())
284
  {
285
28
    Node bvl = nm->mkNode(BOUND_VAR_LIST, qvars);
286
14
    body = nm->mkNode(EXISTS, bvl, body);
287
  }
288
289
  // sygus attribute to mark the conjecture as a sygus conjecture
290
52
  Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl;
291
292
52
  body = quantifiers::SygusUtils::mkSygusConjecture(ff_vars, body);
293
294
52
  Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl;
295
296
  // make a separate smt call
297
104
  std::unique_ptr<SmtEngine> rrSygus;
298
52
  theory::initializeSubsolver(rrSygus);
299
52
  rrSygus->assertFormula(body);
300
52
  Trace("sygus-infer") << "*** Check sat..." << std::endl;
301
104
  Result r = rrSygus->checkSat();
302
52
  Trace("sygus-infer") << "...result : " << r << std::endl;
303
52
  if (r.asSatisfiabilityResult().isSat() != Result::UNSAT)
304
  {
305
    // failed, conjecture was infeasible
306
10
    return false;
307
  }
308
  // get the synthesis solutions
309
84
  std::map<Node, Node> synth_sols;
310
42
  rrSygus->getSynthSolutions(synth_sols);
311
312
84
  std::vector<Node> final_ff;
313
84
  std::vector<Node> final_ff_sol;
314
118
  for (std::map<Node, Node>::iterator it = synth_sols.begin();
315
118
       it != synth_sols.end();
316
       ++it)
317
  {
318
152
    Trace("sygus-infer") << "  synth sol : " << it->first << " -> "
319
76
                         << it->second << std::endl;
320
152
    Node ffv = it->first;
321
76
    std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv);
322
    // all synthesis solutions should correspond to a variable we introduced
323
76
    Assert(itffv != ff_var_to_ff.end());
324
76
    if (itffv != ff_var_to_ff.end())
325
    {
326
152
      Node ff = itffv->second;
327
152
      Node body2 = it->second;
328
76
      Trace("sygus-infer") << "Define " << ff << " as " << body2 << std::endl;
329
76
      funs.push_back(ff);
330
76
      sols.push_back(body2);
331
    }
332
  }
333
42
  return true;
334
}
335
336
337
}  // namespace passes
338
}  // namespace preprocessing
339
28191
}  // namespace cvc5