GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/cegqi/vts_term_cache.cpp Lines: 137 146 93.8 %
Date: 2021-05-22 Branches: 306 592 51.7 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Tianyi Liang
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of virtual term substitution term cache.
14
 */
15
16
#include "theory/quantifiers/cegqi/vts_term_cache.h"
17
18
#include "expr/node_algorithm.h"
19
#include "expr/skolem_manager.h"
20
#include "theory/arith/arith_msum.h"
21
#include "theory/quantifiers/quantifiers_inference_manager.h"
22
#include "theory/rewriter.h"
23
24
using namespace cvc5::kind;
25
26
namespace cvc5 {
27
namespace theory {
28
namespace quantifiers {
29
30
6323
VtsTermCache::VtsTermCache(QuantifiersInferenceManager& qim) : d_qim(qim)
{
  // Cache the rational constant 0; it is used for the positivity lemma on the
  // free delta and when rewriting literals that mention delta.
  NodeManager* nm = NodeManager::currentNM();
  d_zero = nm->mkConst(Rational(0));
}
34
35
6503
void VtsTermCache::getVtsTerms(std::vector<Node>& t,
36
                               bool isFree,
37
                               bool create,
38
                               bool inc_delta)
39
{
40
6503
  if (inc_delta)
41
  {
42
12824
    Node delta = getVtsDelta(isFree, create);
43
6412
    if (!delta.isNull())
44
    {
45
1575
      t.push_back(delta);
46
    }
47
  }
48
6503
  NodeManager* nm = NodeManager::currentNM();
49
19509
  for (unsigned r = 0; r < 2; r++)
50
  {
51
26012
    TypeNode tn = r == 0 ? nm->realType() : nm->integerType();
52
26012
    Node inf = getVtsInfinity(tn, isFree, create);
53
13006
    if (!inf.isNull())
54
    {
55
1292
      t.push_back(inf);
56
    }
57
  }
58
6503
}
59
60
18514
Node VtsTermCache::getVtsDelta(bool isFree, bool create)
{
  // Returns the free or the non-free delta symbol (possibly null). When
  // create is set, both symbols are built first if they do not exist yet.
  if (!create)
  {
    return isFree ? d_vts_delta_free : d_vts_delta;
  }
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();
  if (d_vts_delta_free.isNull())
  {
    d_vts_delta_free =
        sm->mkDummySkolem("delta_free",
                          nm->realType(),
                          "free delta for virtual term substitution");
    // The free delta is only constrained to be strictly positive.
    Node positive = nm->mkNode(GT, d_vts_delta_free, d_zero);
    d_qim.lemma(positive, InferenceId::QUANTIFIERS_CEGQI_VTS_LB_DELTA);
  }
  if (d_vts_delta.isNull())
  {
    d_vts_delta = sm->mkDummySkolem(
        "delta", nm->realType(), "delta for virtual term substitution");
    // Tag the non-free delta as a virtual term skolem.
    VirtualTermSkolemAttribute vtsa;
    d_vts_delta.setAttribute(vtsa, true);
  }
  return isFree ? d_vts_delta_free : d_vts_delta;
}
86
87
27422
Node VtsTermCache::getVtsInfinity(TypeNode tn, bool isFree, bool create)
{
  // Returns the free or the non-free infinity symbol of type tn (possibly
  // null). When create is set, both symbols for tn are built first if they
  // do not exist yet.
  if (!create)
  {
    return isFree ? d_vts_inf_free[tn] : d_vts_inf[tn];
  }
  NodeManager* nm = NodeManager::currentNM();
  SkolemManager* sm = nm->getSkolemManager();
  // References into the caches stay valid across later insertions.
  Node& freeInf = d_vts_inf_free[tn];
  if (freeInf.isNull())
  {
    freeInf = sm->mkDummySkolem(
        "inf_free", tn, "free infinity for virtual term substitution");
  }
  Node& boundInf = d_vts_inf[tn];
  if (boundInf.isNull())
  {
    boundInf = sm->mkDummySkolem(
        "inf", tn, "infinity for virtual term substitution");
    // Tag the non-free infinity as a virtual term skolem.
    VirtualTermSkolemAttribute vtsa;
    boundInf.setAttribute(vtsa, true);
  }
  return isFree ? freeInf : boundInf;
}
109
110
228
Node VtsTermCache::substituteVtsFreeTerms(Node n)
{
  // Replaces every existing non-free virtual term in n by its free
  // counterpart. No new symbols are created here.
  std::vector<Node> bound;
  getVtsTerms(bound, false, false);
  std::vector<Node> unbound;
  getVtsTerms(unbound, true, false);
  // Free and non-free symbols are created in pairs, so the lists align.
  Assert(bound.size() == unbound.size());
  if (!bound.empty())
  {
    return n.substitute(
        bound.begin(), bound.end(), unbound.begin(), unbound.end());
  }
  return n;
}
124
125
1911
/**
 * Rewrites the (non-free) virtual term symbols, delta and the infinities,
 * occurring in n.
 *
 * Arithmetic literals (GEQ, and EQUAL between non-Boolean terms) that
 * contain a virtual term are handled as follows:
 * - The literal is solved for that term and replaced by a literal free of
 *   it. For example, a literal solved as (>= delta t) becomes (>= 0 t), and
 *   one solved as (>= inf t) becomes true.
 * - Infinity takes precedence over delta. If both the real and the integer
 *   infinity occur, they are first identified with each other.
 * - If no acceptable solved form is found, the literal's virtual terms are
 *   replaced by their free counterparts instead.
 *
 * Quantified formulas are not traversed; their virtual terms are replaced by
 * free ones. Every other node is rebuilt from its recursively rewritten
 * children.
 */
Node VtsTermCache::rewriteVtsSymbols(Node n)
{
  NodeManager* nm = NodeManager::currentNM();
  // An equality between Boolean terms is a logical connective (iff), not an
  // arithmetic literal. It is traversed below like AND/OR, so that the
  // arithmetic literals nested in its sides are solved. Otherwise it would be
  // handed to ArithMSum as if it were an arithmetic (in)equality.
  bool isArithLit = n.getKind() == GEQ
                    || (n.getKind() == EQUAL && !n[0].getType().isBoolean());
  if (isArithLit)
  {
    Trace("quant-vts-debug") << "VTS : process " << n << std::endl;
    Node rew_vts_inf;
    bool rew_delta = false;
    // rewriting infinity always takes precedence over rewriting delta
    for (unsigned r = 0; r < 2; r++)
    {
      TypeNode tn = r == 0 ? nm->realType() : nm->integerType();
      Node inf = getVtsInfinity(tn, false, false);
      if (!inf.isNull() && expr::hasSubterm(n, inf))
      {
        if (rew_vts_inf.isNull())
        {
          rew_vts_inf = inf;
        }
        else
        {
          // for mixed int/real with multiple infinities
          Trace("quant-vts-debug") << "Multiple infinities...equate " << inf
                                   << " = " << rew_vts_inf << std::endl;
          std::vector<Node> subs_lhs;
          subs_lhs.push_back(inf);
          std::vector<Node> subs_rhs;
          subs_rhs.push_back(rew_vts_inf);
          n = n.substitute(subs_lhs.begin(),
                           subs_lhs.end(),
                           subs_rhs.begin(),
                           subs_rhs.end());
          n = Rewriter::rewrite(n);
          // may have cancelled
          if (!expr::hasSubterm(n, rew_vts_inf))
          {
            rew_vts_inf = Node::null();
          }
        }
      }
    }
    if (rew_vts_inf.isNull())
    {
      if (!d_vts_delta.isNull() && expr::hasSubterm(n, d_vts_delta))
      {
        rew_delta = true;
      }
    }
    if (!rew_vts_inf.isNull() || rew_delta)
    {
      std::map<Node, Node> msum;
      // NOTE(review): if no monomial sum is found, n is returned below
      // unchanged, still containing its virtual term.
      if (ArithMSum::getMonomialSumLit(n, msum))
      {
        if (Trace.isOn("quant-vts-debug"))
        {
          Trace("quant-vts-debug") << "VTS got monomial sum : " << std::endl;
          ArithMSum::debugPrintMonomialSum(msum, "quant-vts-debug");
        }
        Node vts_sym = !rew_vts_inf.isNull() ? rew_vts_inf : d_vts_delta;
        Assert(!vts_sym.isNull());
        Node iso_n;
        Node nlit;
        int res = ArithMSum::isolate(vts_sym, msum, iso_n, n.getKind(), true);
        if (res != 0)
        {
          Trace("quant-vts-debug") << "VTS isolated :  -> " << iso_n
                                   << ", res = " << res << std::endl;
          // res == 1: vts_sym is the left side of iso_n, else the right side.
          Node slv = iso_n[res == 1 ? 1 : 0];
          // ensure the vts terms have been eliminated
          if (containsVtsTerm(slv))
          {
            Trace("quant-vts-warn")
                << "Bad vts literal : " << n << ", contains " << vts_sym
                << " but bad solved form " << slv << "." << std::endl;
            // safe case: just convert to free symbols
            nlit = substituteVtsFreeTerms(n);
            Trace("quant-vts-debug") << "...return " << nlit << std::endl;
            return nlit;
          }
          else
          {
            if (!rew_vts_inf.isNull())
            {
              // Only (>= inf slv) holds; equalities and (>= slv inf) do not.
              nlit = nm->mkConst(n.getKind() == GEQ && res == 1);
            }
            else
            {
              Assert(iso_n[res == 1 ? 0 : 1] == d_vts_delta);
              if (n.getKind() == EQUAL)
              {
                nlit = nm->mkConst(false);
              }
              else if (res == 1)
              {
                // delta >= slv becomes 0 >= slv
                nlit = nm->mkNode(GEQ, d_zero, slv);
              }
              else
              {
                // slv >= delta becomes slv > 0
                nlit = nm->mkNode(GT, slv, d_zero);
              }
            }
          }
          Trace("quant-vts-debug") << "Return " << nlit << std::endl;
          return nlit;
        }
        else
        {
          Trace("quant-vts-warn")
              << "Bad vts literal : " << n << ", contains " << vts_sym
              << " but could not isolate." << std::endl;
          // safe case: just convert to free symbols
          nlit = substituteVtsFreeTerms(n);
          Trace("quant-vts-debug") << "...return " << nlit << std::endl;
          return nlit;
        }
      }
    }
    return n;
  }
  else if (n.getKind() == FORALL)
  {
    // cannot traverse beneath quantifiers
    return substituteVtsFreeTerms(n);
  }
  // Generic case: rewrite the children and rebuild the node if any changed.
  bool childChanged = false;
  std::vector<Node> children;
  for (const Node& nc : n)
  {
    Node nn = rewriteVtsSymbols(nc);
    children.push_back(nn);
    childChanged = childChanged || nn != nc;
  }
  if (childChanged)
  {
    if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
    {
      children.insert(children.begin(), n.getOperator());
    }
    Node ret = nm->mkNode(n.getKind(), children);
    Trace("quant-vts-debug") << "...make node " << ret << std::endl;
    return ret;
  }
  return n;
}
269
270
958
bool VtsTermCache::containsVtsTerm(Node n, bool isFree)
{
  // True iff n contains one of the existing delta/infinity symbols (free or
  // non-free, depending on isFree). No symbols are created.
  std::vector<Node> vtsSymbols;
  getVtsTerms(vtsSymbols, isFree, false);
  const bool found = expr::hasSubterm(n, vtsSymbols);
  return found;
}
276
277
4998
bool VtsTermCache::containsVtsTerm(std::vector<Node>& n, bool isFree)
278
{
279
9996
  std::vector<Node> t;
280
4998
  getVtsTerms(t, isFree, false);
281
4998
  if (!t.empty())
282
  {
283
269
    for (const Node& nc : n)
284
    {
285
187
      if (expr::hasSubterm(nc, t))
286
      {
287
87
        return true;
288
      }
289
    }
290
  }
291
4911
  return false;
292
}
293
294
bool VtsTermCache::containsVtsInfinity(Node n, bool isFree)
{
  // Collects only the infinity symbols: the trailing false excludes delta.
  std::vector<Node> infinities;
  getVtsTerms(infinities, isFree, false, false);
  const bool found = expr::hasSubterm(n, infinities);
  return found;
}
300
301
}  // namespace quantifiers
302
}  // namespace theory
303
28194
}  // namespace cvc5