GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/sygus_inference.cpp Lines: 160 167 95.8 %
Date: 2021-03-22 Branches: 306 624 49.0 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file sygus_inference.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Mathias Preiner, Andres Noetzli
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Sygus inference module
13
 **/
14
15
#include "preprocessing/passes/sygus_inference.h"
16
17
#include "preprocessing/assertion_pipeline.h"
18
#include "preprocessing/preprocessing_pass_context.h"
19
#include "smt/smt_engine.h"
20
#include "smt/smt_engine_scope.h"
21
#include "smt/smt_statistics_registry.h"
22
#include "theory/quantifiers/quantifiers_attributes.h"
23
#include "theory/quantifiers/quantifiers_rewriter.h"
24
#include "theory/quantifiers/sygus/sygus_grammar_cons.h"
25
#include "theory/quantifiers/sygus/sygus_utils.h"
26
#include "theory/rewriter.h"
27
#include "theory/smt_engine_subsolver.h"
28
29
using namespace std;
30
using namespace CVC4::kind;
31
using namespace CVC4::theory;
32
33
namespace CVC4 {
34
namespace preprocessing {
35
namespace passes {
36
37
8995
/**
 * Constructs the sygus inference preprocessing pass.
 *
 * @param preprocContext the preprocessing pass context, forwarded to the
 *        PreprocessingPass base class together with the pass name
 *        "sygus-infer".
 */
SygusInference::SygusInference(PreprocessingPassContext* preprocContext)
    : PreprocessingPass(preprocContext, "sygus-infer")
{
}
39
40
62
/**
 * Attempts to solve the current assertions as a sygus problem.
 *
 * If solveSygus succeeds, each synthesized solution is installed in the
 * master SMT engine via defineFunction. Every assertion is then rewritten
 * with the solutions substituted for the corresponding free function
 * symbols. If solveSygus fails, the assertions are left unchanged.
 *
 * @param assertionsToPreprocess the assertions; modified in place on success
 * @return PreprocessingPassResult::NO_CONFLICT in all cases
 */
PreprocessingPassResult SygusInference::applyInternal(
    AssertionPipeline* assertionsToPreprocess)
{
  Trace("sygus-infer") << "Run sygus inference..." << std::endl;
  std::vector<Node> funs;
  std::vector<Node> sols;
  // see if we can successfully solve the input as a sygus problem
  if (solveSygus(assertionsToPreprocess->ref(), funs, sols))
  {
    Assert(funs.size() == sols.size());
    // if so, sygus gives us function definitions
    SmtEngine* master_smte = d_preprocContext->getSmt();
    for (size_t i = 0, size = funs.size(); i < size; i++)
    {
      std::vector<Node> args;
      Node sol = sols[i];
      // If it is a non-constant function, split the lambda into its bound
      // variables and its body; these are passed separately to
      // defineFunction.
      if (sol.getKind() == LAMBDA)
      {
        for (const Node& v : sol[0])
        {
          args.push_back(v);
        }
        sol = sol[1];
      }
      master_smte->defineFunction(funs[i], args, sol);
    }

    // apply substitution to everything, should result in SAT
    for (size_t i = 0, size = assertionsToPreprocess->ref().size(); i < size;
         i++)
    {
      Node prev = (*assertionsToPreprocess)[i];
      Node curr =
          prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end());
      // only touch assertions that actually mention a solved function
      if (curr != prev)
      {
        curr = theory::Rewriter::rewrite(curr);
        Trace("sygus-infer-debug")
            << "...rewrote " << prev << " to " << curr << std::endl;
        assertionsToPreprocess->replace(i, curr);
      }
    }
  }
  return PreprocessingPassResult::NO_CONFLICT;
}
86
87
62
bool SygusInference::solveSygus(const std::vector<Node>& assertions,
88
                                std::vector<Node>& funs,
89
                                std::vector<Node>& sols)
90
{
91
62
  if (assertions.empty())
92
  {
93
    Trace("sygus-infer") << "...fail: empty assertions." << std::endl;
94
    return false;
95
  }
96
97
62
  NodeManager* nm = NodeManager::currentNM();
98
99
  // collect free variables in all assertions
100
124
  std::vector<Node> qvars;
101
124
  std::map<TypeNode, std::vector<Node> > qtvars;
102
124
  std::vector<Node> free_functions;
103
104
124
  std::vector<TNode> visit;
105
124
  std::unordered_set<TNode, TNodeHashFunction> visited;
106
107
  // add top-level conjuncts to eassertions
108
124
  std::vector<Node> assertions_proc = assertions;
109
124
  std::vector<Node> eassertions;
110
62
  unsigned index = 0;
111
462
  while (index < assertions_proc.size())
112
  {
113
400
    Node ca = assertions_proc[index];
114
200
    if (ca.getKind() == AND)
115
    {
116
24
      for (const Node& ai : ca)
117
      {
118
16
        assertions_proc.push_back(ai);
119
      }
120
    }
121
    else
122
    {
123
192
      eassertions.push_back(ca);
124
    }
125
200
    index++;
126
  }
127
128
  // process eassertions
129
124
  std::vector<Node> processed_assertions;
130
248
  for (const Node& as : eassertions)
131
  {
132
    // substitution for this assertion
133
375
    std::vector<Node> vars;
134
375
    std::vector<Node> subs;
135
375
    std::map<TypeNode, unsigned> type_count;
136
375
    Node pas = as;
137
    // rewrite
138
189
    pas = theory::Rewriter::rewrite(pas);
139
189
    Trace("sygus-infer") << "assertion : " << pas << std::endl;
140
189
    if (pas.getKind() == FORALL)
141
    {
142
      // preprocess the quantified formula
143
34
      TrustNode trn = quantifiers::QuantifiersRewriter::preprocess(pas);
144
17
      if (!trn.isNull())
145
      {
146
        pas = trn.getNode();
147
      }
148
17
      Trace("sygus-infer-debug") << "  ...preprocessed to " << pas << std::endl;
149
    }
150
189
    if (pas.getKind() == FORALL)
151
    {
152
      // it must be a standard quantifier
153
34
      theory::quantifiers::QAttributes qa;
154
17
      theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa);
155
17
      if (!qa.isStandard())
156
      {
157
        Trace("sygus-infer")
158
            << "...fail: non-standard top-level quantifier." << std::endl;
159
        return false;
160
      }
161
      // infer prefix
162
40
      for (const Node& v : pas[0])
163
      {
164
46
        TypeNode tnv = v.getType();
165
23
        unsigned vnum = type_count[tnv];
166
23
        type_count[tnv]++;
167
23
        vars.push_back(v);
168
23
        if (vnum < qtvars[tnv].size())
169
        {
170
5
          subs.push_back(qtvars[tnv][vnum]);
171
        }
172
        else
173
        {
174
18
          Assert(vnum == qtvars[tnv].size());
175
36
          Node bv = nm->mkBoundVar(tnv);
176
18
          qtvars[tnv].push_back(bv);
177
18
          qvars.push_back(bv);
178
18
          subs.push_back(bv);
179
        }
180
      }
181
17
      pas = pas[1];
182
17
      if (!vars.empty())
183
      {
184
17
        pas =
185
34
            pas.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
186
      }
187
    }
188
189
    Trace("sygus-infer-debug") << "  ...substituted to " << pas << std::endl;
189
190
    // collect free functions, ensure no quantified formulas
191
375
    TNode cur = pas;
192
    // compute free variables
193
189
    visit.push_back(cur);
194
983
    do
195
    {
196
1172
      cur = visit.back();
197
1172
      visit.pop_back();
198
1172
      if (visited.find(cur) == visited.end())
199
      {
200
888
        visited.insert(cur);
201
888
        if (cur.getKind() == APPLY_UF)
202
        {
203
62
          Node op = cur.getOperator();
204
          // visit the operator, which might not be a variable
205
31
          visit.push_back(op);
206
        }
207
857
        else if (cur.isVar() && cur.getKind() != BOUND_VARIABLE)
208
        {
209
          // We are either in the case of a free first-order constant or a
210
          // function in a higher-order context. We add to free_functions
211
          // in either case. Note that a free constant that is not in a
212
          // higher-order context is a 0-argument function-to-synthesize.
213
          // We should not have traversed here before due to our visited cache.
214
150
          Assert(std::find(free_functions.begin(), free_functions.end(), cur)
215
                 == free_functions.end());
216
150
          free_functions.push_back(cur);
217
        }
218
707
        else if (cur.isClosure())
219
        {
220
6
          Trace("sygus-infer")
221
3
              << "...fail: non-top-level quantifier." << std::endl;
222
3
          return false;
223
        }
224
1840
        for (const TNode& cn : cur)
225
        {
226
955
          visit.push_back(cn);
227
        }
228
      }
229
1169
    } while (!visit.empty());
230
186
    processed_assertions.push_back(pas);
231
  }
232
233
  // no functions to synthesize
234
59
  if (free_functions.empty())
235
  {
236
3
    Trace("sygus-infer") << "...fail: no free function symbols." << std::endl;
237
3
    return false;
238
  }
239
240
  // Ensure the type of all free functions is handled by the sygus grammar
241
  // constructor utility.
242
56
  bool typeSuccess = true;
243
202
  for (const Node& f : free_functions)
244
  {
245
294
    TypeNode tn = f.getType();
246
148
    if (!theory::quantifiers::CegGrammarConstructor::isHandledType(tn))
247
    {
248
2
      Trace("sygus-infer") << "...fail: unhandled type " << tn << std::endl;
249
2
      typeSuccess = false;
250
2
      break;
251
    }
252
  }
253
56
  if (!typeSuccess)
254
  {
255
2
    return false;
256
  }
257
258
54
  Assert(!processed_assertions.empty());
259
  // conjunction of the assertions
260
54
  Trace("sygus-infer") << "Construct body..." << std::endl;
261
108
  Node body;
262
54
  if (processed_assertions.size() == 1)
263
  {
264
    body = processed_assertions[0];
265
  }
266
  else
267
  {
268
54
    body = nm->mkNode(AND, processed_assertions);
269
  }
270
271
  // for each free function symbol, make a bound variable of the same type
272
54
  Trace("sygus-infer") << "Do free function substitution..." << std::endl;
273
108
  std::vector<Node> ff_vars;
274
108
  std::map<Node, Node> ff_var_to_ff;
275
200
  for (const Node& ff : free_functions)
276
  {
277
292
    Node ffv = nm->mkBoundVar(ff.getType());
278
146
    ff_vars.push_back(ffv);
279
146
    Trace("sygus-infer") << "  synth-fun: " << ff << " as " << ffv << std::endl;
280
146
    ff_var_to_ff[ffv] = ff;
281
  }
282
  // substitute free functions -> variables
283
54
  body = body.substitute(free_functions.begin(),
284
                         free_functions.end(),
285
                         ff_vars.begin(),
286
                         ff_vars.end());
287
54
  Trace("sygus-infer-debug") << "...got : " << body << std::endl;
288
289
  // quantify the body
290
54
  Trace("sygus-infer") << "Make inner sygus conjecture..." << std::endl;
291
54
  body = body.negate();
292
54
  if (!qvars.empty())
293
  {
294
24
    Node bvl = nm->mkNode(BOUND_VAR_LIST, qvars);
295
12
    body = nm->mkNode(EXISTS, bvl, body);
296
  }
297
298
  // sygus attribute to mark the conjecture as a sygus conjecture
299
54
  Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl;
300
301
54
  body = quantifiers::SygusUtils::mkSygusConjecture(ff_vars, body);
302
303
54
  Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl;
304
305
  // make a separate smt call
306
108
  std::unique_ptr<SmtEngine> rrSygus;
307
54
  theory::initializeSubsolver(rrSygus);
308
54
  rrSygus->assertFormula(body);
309
54
  Trace("sygus-infer") << "*** Check sat..." << std::endl;
310
108
  Result r = rrSygus->checkSat();
311
54
  Trace("sygus-infer") << "...result : " << r << std::endl;
312
54
  if (r.asSatisfiabilityResult().isSat() != Result::UNSAT)
313
  {
314
    // failed, conjecture was infeasible
315
14
    return false;
316
  }
317
  // get the synthesis solutions
318
80
  std::map<Node, Node> synth_sols;
319
40
  rrSygus->getSynthSolutions(synth_sols);
320
321
80
  std::vector<Node> final_ff;
322
80
  std::vector<Node> final_ff_sol;
323
113
  for (std::map<Node, Node>::iterator it = synth_sols.begin();
324
113
       it != synth_sols.end();
325
       ++it)
326
  {
327
146
    Trace("sygus-infer") << "  synth sol : " << it->first << " -> "
328
73
                         << it->second << std::endl;
329
146
    Node ffv = it->first;
330
73
    std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv);
331
    // all synthesis solutions should correspond to a variable we introduced
332
73
    Assert(itffv != ff_var_to_ff.end());
333
73
    if (itffv != ff_var_to_ff.end())
334
    {
335
146
      Node ff = itffv->second;
336
146
      Node body2 = it->second;
337
73
      Trace("sygus-infer") << "Define " << ff << " as " << body2 << std::endl;
338
73
      funs.push_back(ff);
339
73
      sols.push_back(body2);
340
    }
341
  }
342
40
  return true;
343
}
344
345
346
}  // namespace passes
347
}  // namespace preprocessing
348
26676
}  // namespace CVC4