GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/smt/abduction_solver.cpp Lines: 82 98 83.7 %
Date: 2021-09-07 Branches: 161 414 38.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Morgan Deters, Mathias Preiner
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * The solver for abduction queries.
14
 */
15
16
#include "smt/abduction_solver.h"
17
18
#include <sstream>
19
20
#include "base/modal_exception.h"
21
#include "options/smt_options.h"
22
#include "smt/env.h"
23
#include "smt/smt_engine.h"
24
#include "theory/quantifiers/quantifiers_attributes.h"
25
#include "theory/quantifiers/sygus/sygus_abduct.h"
26
#include "theory/quantifiers/sygus/sygus_grammar_cons.h"
27
#include "theory/smt_engine_subsolver.h"
28
#include "theory/trust_substitutions.h"
29
30
using namespace cvc5::theory;
31
32
namespace cvc5 {
33
namespace smt {
34
35
326
// Construct the abduction solver; all environment access goes through the
// EnvObj base (d_env).
AbductionSolver::AbductionSolver(Env& env)
    : EnvObj(env)
{
}

// Defined out of line so that owned members are destroyed in this
// translation unit.
AbductionSolver::~AbductionSolver() = default;
38
16
bool AbductionSolver::getAbduct(const std::vector<Node>& axioms,
39
                                const Node& goal,
40
                                const TypeNode& grammarType,
41
                                Node& abd)
42
{
43
16
  if (!options::produceAbducts())
44
  {
45
    const char* msg = "Cannot get abduct when produce-abducts options is off.";
46
    throw ModalException(msg);
47
  }
48
16
  Trace("sygus-abduct") << "SmtEngine::getAbduct: goal " << goal << std::endl;
49
32
  std::vector<Node> asserts(axioms.begin(), axioms.end());
50
  // must expand definitions
51
32
  Node conjn = d_env.getTopLevelSubstitutions().apply(goal);
52
  // now negate
53
16
  conjn = conjn.negate();
54
16
  d_abdConj = conjn;
55
16
  asserts.push_back(conjn);
56
32
  std::string name("A");
57
  Node aconj = quantifiers::SygusAbduct::mkAbductionConjecture(
58
32
      name, asserts, axioms, grammarType);
59
  // should be a quantified conjecture with one function-to-synthesize
60
16
  Assert(aconj.getKind() == kind::FORALL && aconj[0].getNumChildren() == 1);
61
  // remember the abduct-to-synthesize
62
16
  d_sssf = aconj[0][0];
63
32
  Trace("sygus-abduct") << "SmtEngine::getAbduct: made conjecture : " << aconj
64
16
                        << ", solving for " << d_sssf << std::endl;
65
  // we generate a new smt engine to do the abduction query
66
16
  initializeSubsolver(d_subsolver, d_env);
67
  // get the logic
68
32
  LogicInfo l = d_subsolver->getLogicInfo().getUnlockedCopy();
69
  // enable everything needed for sygus
70
16
  l.enableSygus();
71
16
  d_subsolver->setLogic(l);
72
  // assert the abduction query
73
16
  d_subsolver->assertFormula(aconj);
74
30
  return getAbductInternal(axioms, abd);
75
}
76
77
bool AbductionSolver::getAbduct(const std::vector<Node>& axioms,
78
                                const Node& goal,
79
                                Node& abd)
80
{
81
  TypeNode grammarType;
82
  return getAbduct(axioms, goal, grammarType, abd);
83
}
84
85
16
bool AbductionSolver::getAbductInternal(const std::vector<Node>& axioms,
86
                                        Node& abd)
87
{
88
  // should have initialized the subsolver by now
89
16
  Assert(d_subsolver != nullptr);
90
16
  Trace("sygus-abduct") << "  SmtEngine::getAbduct check sat..." << std::endl;
91
30
  Result r = d_subsolver->checkSat();
92
14
  Trace("sygus-abduct") << "  SmtEngine::getAbduct result: " << r << std::endl;
93
14
  if (r.asSatisfiabilityResult().isSat() == Result::UNSAT)
94
  {
95
    // get the synthesis solution
96
28
    std::map<Node, Node> sols;
97
14
    d_subsolver->getSynthSolutions(sols);
98
14
    Assert(sols.size() == 1);
99
14
    std::map<Node, Node>::iterator its = sols.find(d_sssf);
100
14
    if (its != sols.end())
101
    {
102
28
      Trace("sygus-abduct")
103
14
          << "SmtEngine::getAbduct: solution is " << its->second << std::endl;
104
14
      abd = its->second;
105
14
      if (abd.getKind() == kind::LAMBDA)
106
      {
107
13
        abd = abd[1];
108
      }
109
      // get the grammar type for the abduct
110
28
      Node agdtbv = d_sssf.getAttribute(SygusSynthFunVarListAttribute());
111
14
      if(!agdtbv.isNull())
112
      {
113
13
        Assert(agdtbv.getKind() == kind::BOUND_VAR_LIST);
114
        // convert back to original
115
        // must replace formal arguments of abd with the free variables in the
116
        // input problem that they correspond to.
117
26
        std::vector<Node> vars;
118
26
        std::vector<Node> syms;
119
        SygusVarToTermAttribute sta;
120
76
        for (const Node& bv : agdtbv)
121
        {
122
63
          vars.push_back(bv);
123
63
          syms.push_back(bv.hasAttribute(sta) ? bv.getAttribute(sta) : bv);
124
        }
125
13
        abd = abd.substitute(vars.begin(), vars.end(), syms.begin(), syms.end());
126
      }
127
128
      // if check abducts option is set, we check the correctness
129
14
      if (options::checkAbducts())
130
      {
131
10
        checkAbduct(axioms, abd);
132
      }
133
14
      return true;
134
    }
135
    Trace("sygus-abduct") << "SmtEngine::getAbduct: could not find solution!"
136
                          << std::endl;
137
    throw RecoverableModalException("Could not find solution for get-abduct.");
138
  }
139
  return false;
140
}
141
142
10
void AbductionSolver::checkAbduct(const std::vector<Node>& axioms, Node a)
143
{
144
10
  Assert(a.getType().isBoolean());
145
20
  Trace("check-abduct") << "SmtEngine::checkAbduct: get expanded assertions"
146
10
                        << std::endl;
147
148
20
  std::vector<Node> asserts(axioms.begin(), axioms.end());
149
10
  asserts.push_back(a);
150
151
  // two checks: first, consistent with assertions, second, implies negated goal
152
  // is unsatisfiable.
153
30
  for (unsigned j = 0; j < 2; j++)
154
  {
155
40
    Trace("check-abduct") << "SmtEngine::checkAbduct: phase " << j
156
20
                          << ": make new SMT engine" << std::endl;
157
    // Start new SMT engine to check solution
158
40
    std::unique_ptr<SmtEngine> abdChecker;
159
20
    initializeSubsolver(abdChecker, d_env);
160
40
    Trace("check-abduct") << "SmtEngine::checkAbduct: phase " << j
161
20
                          << ": asserting formulas" << std::endl;
162
178
    for (const Node& e : asserts)
163
    {
164
158
      abdChecker->assertFormula(e);
165
    }
166
40
    Trace("check-abduct") << "SmtEngine::checkAbduct: phase " << j
167
20
                          << ": check the assertions" << std::endl;
168
40
    Result r = abdChecker->checkSat();
169
40
    Trace("check-abduct") << "SmtEngine::checkAbduct: phase " << j
170
20
                          << ": result is " << r << std::endl;
171
40
    std::stringstream serr;
172
20
    bool isError = false;
173
20
    if (j == 0)
174
    {
175
10
      if (r.asSatisfiabilityResult().isSat() != Result::SAT)
176
      {
177
        isError = true;
178
        serr << "SmtEngine::checkAbduct(): produced solution cannot be shown "
179
                "to be consisconsistenttent with assertions, result was "
180
             << r;
181
      }
182
20
      Trace("check-abduct")
183
10
          << "SmtEngine::checkAbduct: goal is " << d_abdConj << std::endl;
184
      // add the goal to the set of assertions
185
10
      Assert(!d_abdConj.isNull());
186
10
      asserts.push_back(d_abdConj);
187
    }
188
    else
189
    {
190
10
      if (r.asSatisfiabilityResult().isSat() != Result::UNSAT)
191
      {
192
        isError = true;
193
        serr << "SmtEngine::checkAbduct(): negated goal cannot be shown "
194
                "unsatisfiable with produced solution, result was "
195
             << r;
196
      }
197
    }
198
    // did we get an unexpected result?
199
20
    if (isError)
200
    {
201
      InternalError() << serr.str();
202
    }
203
  }
204
10
}
205
206
}  // namespace smt
207
29502
}  // namespace cvc5