GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/sygus_inference.cpp Lines: 156 163 95.7 %
Date: 2021-09-17 Branches: 304 622 48.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Mathias Preiner, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Sygus inference module.
14
 */
15
16
#include "preprocessing/passes/sygus_inference.h"
17
18
#include "preprocessing/assertion_pipeline.h"
19
#include "preprocessing/preprocessing_pass_context.h"
20
#include "smt/smt_engine.h"
21
#include "smt/smt_engine_scope.h"
22
#include "smt/smt_statistics_registry.h"
23
#include "theory/quantifiers/quantifiers_attributes.h"
24
#include "theory/quantifiers/quantifiers_preprocess.h"
25
#include "theory/quantifiers/sygus/sygus_grammar_cons.h"
26
#include "theory/quantifiers/sygus/sygus_utils.h"
27
#include "theory/rewriter.h"
28
#include "theory/smt_engine_subsolver.h"
29
30
using namespace std;
31
using namespace cvc5::kind;
32
using namespace cvc5::theory;
33
34
namespace cvc5 {
35
namespace preprocessing {
36
namespace passes {
37
38
9942
/**
 * Construct the sygus inference preprocessing pass.
 *
 * @param preprocContext the preprocessing context, forwarded to the base
 *   PreprocessingPass together with the pass name "sygus-infer"
 */
SygusInference::SygusInference(PreprocessingPassContext* preprocContext)
    : PreprocessingPass(preprocContext, "sygus-infer")
{
}
40
41
57
/**
 * Try to solve the current assertions as a sygus problem via solveSygus.
 *
 * On success, every synthesized definition f -> sol is registered as a
 * substitution in the preprocessing context, and the substitution is applied
 * to each assertion. Every assertion it changes is rewritten and replaced in
 * the pipeline. On failure the assertions are left untouched.
 *
 * @param assertionsToPreprocess the assertion pipeline to transform in place
 * @return always PreprocessingPassResult::NO_CONFLICT
 */
PreprocessingPassResult SygusInference::applyInternal(
    AssertionPipeline* assertionsToPreprocess)
{
  Trace("sygus-infer") << "Run sygus inference..." << std::endl;
  std::vector<Node> funs;
  std::vector<Node> sols;
  // see if we can successfully solve the input as a sygus problem
  if (!solveSygus(assertionsToPreprocess->ref(), funs, sols))
  {
    // not solvable this way: leave the assertions unchanged
    return PreprocessingPassResult::NO_CONFLICT;
  }

  Trace("sygus-infer") << "...Solved:" << std::endl;
  Assert(funs.size() == sols.size());
  // sygus gives us function definitions, which we add as substitutions
  for (size_t i = 0, size = funs.size(); i < size; i++)
  {
    Trace("sygus-infer") << funs[i] << " -> " << sols[i] << std::endl;
    d_preprocContext->addSubstitution(funs[i], sols[i]);
  }

  // apply substitution to everything, should result in SAT
  for (size_t i = 0, size = assertionsToPreprocess->ref().size(); i < size;
       i++)
  {
    Node prev = (*assertionsToPreprocess)[i];
    Node curr =
        prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end());
    // only touch assertions the substitution actually changed
    if (curr != prev)
    {
      curr = rewrite(curr);
      Trace("sygus-infer-debug")
          << "...rewrote " << prev << " to " << curr << std::endl;
      assertionsToPreprocess->replace(i, curr);
    }
  }
  return PreprocessingPassResult::NO_CONFLICT;
}
77
78
57
bool SygusInference::solveSygus(const std::vector<Node>& assertions,
79
                                std::vector<Node>& funs,
80
                                std::vector<Node>& sols)
81
{
82
57
  if (assertions.empty())
83
  {
84
    Trace("sygus-infer") << "...fail: empty assertions." << std::endl;
85
    return false;
86
  }
87
88
57
  NodeManager* nm = NodeManager::currentNM();
89
90
  // collect free variables in all assertions
91
114
  std::vector<Node> qvars;
92
114
  std::map<TypeNode, std::vector<Node> > qtvars;
93
114
  std::vector<Node> free_functions;
94
95
114
  std::vector<TNode> visit;
96
114
  std::unordered_set<TNode> visited;
97
98
  // add top-level conjuncts to eassertions
99
114
  std::vector<Node> assertions_proc = assertions;
100
114
  std::vector<Node> eassertions;
101
57
  unsigned index = 0;
102
401
  while (index < assertions_proc.size())
103
  {
104
344
    Node ca = assertions_proc[index];
105
172
    if (ca.getKind() == AND)
106
    {
107
18
      for (const Node& ai : ca)
108
      {
109
12
        assertions_proc.push_back(ai);
110
      }
111
    }
112
    else
113
    {
114
166
      eassertions.push_back(ca);
115
    }
116
172
    index++;
117
  }
118
119
  // process eassertions
120
114
  std::vector<Node> processed_assertions;
121
114
  quantifiers::QuantifiersPreprocess qp(d_env);
122
219
  for (const Node& as : eassertions)
123
  {
124
    // substitution for this assertion
125
326
    std::vector<Node> vars;
126
326
    std::vector<Node> subs;
127
326
    std::map<TypeNode, unsigned> type_count;
128
326
    Node pas = as;
129
    // rewrite
130
164
    pas = rewrite(pas);
131
164
    Trace("sygus-infer") << "assertion : " << pas << std::endl;
132
164
    if (pas.getKind() == FORALL)
133
    {
134
      // preprocess the quantified formula
135
32
      TrustNode trn = qp.preprocess(pas);
136
16
      if (!trn.isNull())
137
      {
138
        pas = trn.getNode();
139
      }
140
16
      Trace("sygus-infer-debug") << "  ...preprocessed to " << pas << std::endl;
141
    }
142
164
    if (pas.getKind() == FORALL)
143
    {
144
      // it must be a standard quantifier
145
32
      theory::quantifiers::QAttributes qa;
146
16
      theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa);
147
16
      if (!qa.isStandard())
148
      {
149
        Trace("sygus-infer")
150
            << "...fail: non-standard top-level quantifier." << std::endl;
151
        return false;
152
      }
153
      // infer prefix
154
38
      for (const Node& v : pas[0])
155
      {
156
44
        TypeNode tnv = v.getType();
157
22
        unsigned vnum = type_count[tnv];
158
22
        type_count[tnv]++;
159
22
        vars.push_back(v);
160
22
        if (vnum < qtvars[tnv].size())
161
        {
162
4
          subs.push_back(qtvars[tnv][vnum]);
163
        }
164
        else
165
        {
166
18
          Assert(vnum == qtvars[tnv].size());
167
36
          Node bv = nm->mkBoundVar(tnv);
168
18
          qtvars[tnv].push_back(bv);
169
18
          qvars.push_back(bv);
170
18
          subs.push_back(bv);
171
        }
172
      }
173
16
      pas = pas[1];
174
16
      if (!vars.empty())
175
      {
176
16
        pas =
177
32
            pas.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
178
      }
179
    }
180
164
    Trace("sygus-infer-debug") << "  ...substituted to " << pas << std::endl;
181
182
    // collect free functions, ensure no quantified formulas
183
326
    TNode cur = pas;
184
    // compute free variables
185
164
    visit.push_back(cur);
186
902
    do
187
    {
188
1066
      cur = visit.back();
189
1066
      visit.pop_back();
190
1066
      if (visited.find(cur) == visited.end())
191
      {
192
805
        visited.insert(cur);
193
805
        if (cur.getKind() == APPLY_UF)
194
        {
195
64
          Node op = cur.getOperator();
196
          // visit the operator, which might not be a variable
197
32
          visit.push_back(op);
198
        }
199
773
        else if (cur.isVar() && cur.getKind() != BOUND_VARIABLE)
200
        {
201
          // We are either in the case of a free first-order constant or a
202
          // function in a higher-order context. We add to free_functions
203
          // in either case. Note that a free constant that is not in a
204
          // higher-order context is a 0-argument function-to-synthesize.
205
          // We should not have traversed here before due to our visited cache.
206
129
          Assert(std::find(free_functions.begin(), free_functions.end(), cur)
207
                 == free_functions.end());
208
129
          free_functions.push_back(cur);
209
        }
210
644
        else if (cur.isClosure())
211
        {
212
4
          Trace("sygus-infer")
213
2
              << "...fail: non-top-level quantifier." << std::endl;
214
2
          return false;
215
        }
216
1675
        for (const TNode& cn : cur)
217
        {
218
872
          visit.push_back(cn);
219
        }
220
      }
221
1064
    } while (!visit.empty());
222
162
    processed_assertions.push_back(pas);
223
  }
224
225
  // no functions to synthesize
226
55
  if (free_functions.empty())
227
  {
228
2
    Trace("sygus-infer") << "...fail: no free function symbols." << std::endl;
229
2
    return false;
230
  }
231
232
  // Ensure the type of all free functions is handled by the sygus grammar
233
  // constructor utility.
234
53
  bool typeSuccess = true;
235
178
  for (const Node& f : free_functions)
236
  {
237
252
    TypeNode tn = f.getType();
238
127
    if (!theory::quantifiers::CegGrammarConstructor::isHandledType(tn))
239
    {
240
2
      Trace("sygus-infer") << "...fail: unhandled type " << tn << std::endl;
241
2
      typeSuccess = false;
242
2
      break;
243
    }
244
  }
245
53
  if (!typeSuccess)
246
  {
247
2
    return false;
248
  }
249
250
51
  Assert(!processed_assertions.empty());
251
  // conjunction of the assertions
252
51
  Trace("sygus-infer") << "Construct body..." << std::endl;
253
102
  Node body;
254
51
  if (processed_assertions.size() == 1)
255
  {
256
    body = processed_assertions[0];
257
  }
258
  else
259
  {
260
51
    body = nm->mkNode(AND, processed_assertions);
261
  }
262
263
  // for each free function symbol, make a bound variable of the same type
264
51
  Trace("sygus-infer") << "Do free function substitution..." << std::endl;
265
102
  std::vector<Node> ff_vars;
266
102
  std::map<Node, Node> ff_var_to_ff;
267
176
  for (const Node& ff : free_functions)
268
  {
269
250
    Node ffv = nm->mkBoundVar(ff.getType());
270
125
    ff_vars.push_back(ffv);
271
125
    Trace("sygus-infer") << "  synth-fun: " << ff << " as " << ffv << std::endl;
272
125
    ff_var_to_ff[ffv] = ff;
273
  }
274
  // substitute free functions -> variables
275
51
  body = body.substitute(free_functions.begin(),
276
                         free_functions.end(),
277
                         ff_vars.begin(),
278
                         ff_vars.end());
279
51
  Trace("sygus-infer-debug") << "...got : " << body << std::endl;
280
281
  // quantify the body
282
51
  Trace("sygus-infer") << "Make inner sygus conjecture..." << std::endl;
283
51
  body = body.negate();
284
51
  if (!qvars.empty())
285
  {
286
24
    Node bvl = nm->mkNode(BOUND_VAR_LIST, qvars);
287
12
    body = nm->mkNode(EXISTS, bvl, body);
288
  }
289
290
  // sygus attribute to mark the conjecture as a sygus conjecture
291
51
  Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl;
292
293
51
  body = quantifiers::SygusUtils::mkSygusConjecture(ff_vars, body);
294
295
51
  Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl;
296
297
  // make a separate smt call
298
102
  std::unique_ptr<SmtEngine> rrSygus;
299
51
  theory::initializeSubsolver(rrSygus, options(), logicInfo());
300
51
  rrSygus->assertFormula(body);
301
51
  Trace("sygus-infer") << "*** Check sat..." << std::endl;
302
102
  Result r = rrSygus->checkSat();
303
51
  Trace("sygus-infer") << "...result : " << r << std::endl;
304
51
  if (r.asSatisfiabilityResult().isSat() != Result::UNSAT)
305
  {
306
    // failed, conjecture was infeasible
307
10
    return false;
308
  }
309
  // get the synthesis solutions
310
82
  std::map<Node, Node> synth_sols;
311
41
  rrSygus->getSynthSolutions(synth_sols);
312
313
82
  std::vector<Node> final_ff;
314
82
  std::vector<Node> final_ff_sol;
315
114
  for (std::map<Node, Node>::iterator it = synth_sols.begin();
316
114
       it != synth_sols.end();
317
       ++it)
318
  {
319
146
    Trace("sygus-infer") << "  synth sol : " << it->first << " -> "
320
73
                         << it->second << std::endl;
321
146
    Node ffv = it->first;
322
73
    std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv);
323
    // all synthesis solutions should correspond to a variable we introduced
324
73
    Assert(itffv != ff_var_to_ff.end());
325
73
    if (itffv != ff_var_to_ff.end())
326
    {
327
146
      Node ff = itffv->second;
328
146
      Node body2 = it->second;
329
73
      Trace("sygus-infer") << "Define " << ff << " as " << body2 << std::endl;
330
73
      funs.push_back(ff);
331
73
      sols.push_back(body2);
332
    }
333
  }
334
41
  return true;
335
}
336
337
338
}  // namespace passes
339
}  // namespace preprocessing
340
29577
}  // namespace cvc5