GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/cegqi/vts_term_cache.cpp Lines: 137 146 93.8 %
Date: 2021-09-29 Branches: 306 590 51.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Tianyi Liang
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of virtual term substitution term cache.
14
 */
15
16
#include "theory/quantifiers/cegqi/vts_term_cache.h"
17
18
#include "expr/node_algorithm.h"
19
#include "expr/skolem_manager.h"
20
#include "theory/arith/arith_msum.h"
21
#include "theory/quantifiers/quantifiers_inference_manager.h"
22
#include "theory/rewriter.h"
23
#include "util/rational.h"
24
25
using namespace cvc5::kind;
26
27
namespace cvc5 {
28
namespace theory {
29
namespace quantifiers {
30
31
4694
/**
 * Construct the cache. The inference manager is stored so that lemmas
 * (e.g. the lower bound on the free delta) can be sent later, and the
 * rational constant 0 is built once and cached in d_zero.
 */
VtsTermCache::VtsTermCache(QuantifiersInferenceManager& qim) : d_qim(qim)
{
  NodeManager* const nodeMgr = NodeManager::currentNM();
  d_zero = nodeMgr->mkConst(Rational(0));
}
35
36
6020
void VtsTermCache::getVtsTerms(std::vector<Node>& t,
37
                               bool isFree,
38
                               bool create,
39
                               bool inc_delta)
40
{
41
6020
  if (inc_delta)
42
  {
43
11968
    Node delta = getVtsDelta(isFree, create);
44
5984
    if (!delta.isNull())
45
    {
46
393
      t.push_back(delta);
47
    }
48
  }
49
6020
  NodeManager* nm = NodeManager::currentNM();
50
18060
  for (unsigned r = 0; r < 2; r++)
51
  {
52
24080
    TypeNode tn = r == 0 ? nm->realType() : nm->integerType();
53
24080
    Node inf = getVtsInfinity(tn, isFree, create);
54
12040
    if (!inf.isNull())
55
    {
56
323
      t.push_back(inf);
57
    }
58
  }
59
6020
}
60
61
21302
/**
 * Return the delta symbol used by virtual term substitution.
 *
 * When create is true, missing symbols are made first:
 *   - the free delta is a real-typed dummy skolem, and the lemma
 *     (delta_free > 0) is sent through the inference manager;
 *   - the (non-free) delta is a real-typed dummy skolem marked with
 *     VirtualTermSkolemAttribute.
 * Returns the free delta if isFree is true, otherwise the marked delta; either
 * may be null if it was never created.
 */
Node VtsTermCache::getVtsDelta(bool isFree, bool create)
{
  if (!create)
  {
    return isFree ? d_vts_delta_free : d_vts_delta;
  }
  NodeManager* const nodeMgr = NodeManager::currentNM();
  SkolemManager* const skm = nodeMgr->getSkolemManager();
  if (d_vts_delta_free.isNull())
  {
    d_vts_delta_free =
        skm->mkDummySkolem("delta_free",
                           nodeMgr->realType(),
                           "free delta for virtual term substitution");
    // the free delta is constrained to be strictly positive
    Node positivity = nodeMgr->mkNode(GT, d_vts_delta_free, d_zero);
    d_qim.lemma(positivity, InferenceId::QUANTIFIERS_CEGQI_VTS_LB_DELTA);
  }
  if (d_vts_delta.isNull())
  {
    d_vts_delta = skm->mkDummySkolem(
        "delta", nodeMgr->realType(), "delta for virtual term substitution");
    // mark as a virtual term
    VirtualTermSkolemAttribute vtsMark;
    d_vts_delta.setAttribute(vtsMark, true);
  }
  return isFree ? d_vts_delta_free : d_vts_delta;
}
87
88
27930
/**
 * Return the infinity symbol of type tn used by virtual term substitution.
 *
 * When create is true, missing symbols for tn are made first:
 *   - the free infinity is a dummy skolem of type tn;
 *   - the (non-free) infinity is a dummy skolem of type tn marked with
 *     VirtualTermSkolemAttribute.
 * Returns the free infinity if isFree is true, otherwise the marked one. The
 * result is null if no symbol exists for tn.
 *
 * Read-only queries (create == false) do not modify the caches.
 */
Node VtsTermCache::getVtsInfinity(TypeNode tn, bool isFree, bool create)
{
  if (create)
  {
    NodeManager* nm = NodeManager::currentNM();
    SkolemManager* sm = nm->getSkolemManager();
    // Bind each cache slot once instead of looking it up repeatedly.
    Node& freeInf = d_vts_inf_free[tn];
    if (freeInf.isNull())
    {
      freeInf = sm->mkDummySkolem(
          "inf_free", tn, "free infinity for virtual term substitution");
    }
    Node& inf = d_vts_inf[tn];
    if (inf.isNull())
    {
      inf = sm->mkDummySkolem(
          "inf", tn, "infinity for virtual term substitution");
      // mark as a virtual term
      VirtualTermSkolemAttribute vtsa;
      inf.setAttribute(vtsa, true);
    }
  }
  // Lookup via find: operator[] would default-insert a null entry for tn on
  // every query, growing the caches even when nothing is created.
  if (isFree)
  {
    auto itFree = d_vts_inf_free.find(tn);
    return itFree != d_vts_inf_free.end() ? itFree->second : Node::null();
  }
  auto it = d_vts_inf.find(tn);
  return it != d_vts_inf.end() ? it->second : Node::null();
}
110
111
57
/**
 * Replace every existing virtual term symbol in n by its free counterpart.
 * No symbols are created. The marked and free symbol lists are gathered in
 * the same order, so they are paired positionally by the substitution.
 * Returns n unchanged if no virtual term symbols exist yet.
 */
Node VtsTermCache::substituteVtsFreeTerms(Node n)
{
  std::vector<Node> markedSyms;
  std::vector<Node> freeSyms;
  getVtsTerms(markedSyms, false, false);
  getVtsTerms(freeSyms, true, false);
  Assert(markedSyms.size() == freeSyms.size());
  if (markedSyms.empty())
  {
    return n;
  }
  return n.substitute(markedSyms.begin(),
                      markedSyms.end(),
                      freeSyms.begin(),
                      freeSyms.end());
}
125
126
497
/**
 * Eliminate virtual term symbols (infinity and delta) from n.
 *
 * - FORALL: virtual terms are replaced by their free counterparts, since the
 *   traversal does not go beneath quantifiers.
 * - EQUAL / GEQ literals: if the literal mentions an infinity symbol (which
 *   takes precedence) or the delta symbol, the symbol is isolated from the
 *   literal's monomial sum.
 *   - Infinity literals become true/false.
 *   - Delta literals become false (for EQUAL), (0 >= s) or (s > 0), where s
 *     is the solved form.
 *   - If isolation fails, or the solved form still contains virtual terms,
 *     the literal is returned with free symbols substituted instead.
 * - Any other node: children are processed recursively, and the node is
 *   rebuilt if any child changed.
 */
Node VtsTermCache::rewriteVtsSymbols(Node n)
{
  NodeManager* const nodeMgr = NodeManager::currentNM();
  const Kind topKind = n.getKind();

  if (topKind == FORALL)
  {
    // cannot traverse beneath quantifiers
    return substituteVtsFreeTerms(n);
  }

  if (topKind != EQUAL && topKind != GEQ)
  {
    // Not an arithmetic literal: rewrite the children and rebuild if needed.
    std::vector<Node> newKids;
    bool anyChanged = false;
    for (const Node& kid : n)
    {
      Node newKid = rewriteVtsSymbols(kid);
      anyChanged = anyChanged || newKid != kid;
      newKids.push_back(newKid);
    }
    if (!anyChanged)
    {
      return n;
    }
    if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
    {
      newKids.insert(newKids.begin(), n.getOperator());
    }
    Node rebuilt = nodeMgr->mkNode(n.getKind(), newKids);
    Trace("quant-vts-debug") << "...make node " << rebuilt << std::endl;
    return rebuilt;
  }

  Trace("quant-vts-debug") << "VTS : process " << n << std::endl;

  // Find the infinity symbol to eliminate; rewriting infinity always takes
  // precedence over rewriting delta.
  Node infSym;
  for (int idx = 0; idx < 2; idx++)
  {
    TypeNode ty = idx == 0 ? nodeMgr->realType() : nodeMgr->integerType();
    Node candidate = getVtsInfinity(ty, false, false);
    if (candidate.isNull() || !expr::hasSubterm(n, candidate))
    {
      continue;
    }
    if (infSym.isNull())
    {
      infSym = candidate;
      continue;
    }
    // for mixed int/real with multiple infinities: equate them
    Trace("quant-vts-debug") << "Multiple infinities...equate " << candidate
                             << " = " << infSym << std::endl;
    std::vector<Node> from{candidate};
    std::vector<Node> to{infSym};
    n = n.substitute(from.begin(), from.end(), to.begin(), to.end());
    n = Rewriter::rewrite(n);
    // may have cancelled
    if (!expr::hasSubterm(n, infSym))
    {
      infSym = Node::null();
    }
  }

  const bool useDelta = infSym.isNull() && !d_vts_delta.isNull()
                        && expr::hasSubterm(n, d_vts_delta);
  if (infSym.isNull() && !useDelta)
  {
    // no virtual term symbol to eliminate
    return n;
  }

  std::map<Node, Node> monomials;
  if (!ArithMSum::getMonomialSumLit(n, monomials))
  {
    return n;
  }
  if (Trace.isOn("quant-vts-debug"))
  {
    Trace("quant-vts-debug") << "VTS got monomial sum : " << std::endl;
    ArithMSum::debugPrintMonomialSum(monomials, "quant-vts-debug");
  }

  Node target = infSym.isNull() ? d_vts_delta : infSym;
  Assert(!target.isNull());
  Node isolated;
  // Note: n.getKind() is re-read here since n may have been rewritten above.
  int side = ArithMSum::isolate(target, monomials, isolated, n.getKind(), true);
  if (side == 0)
  {
    Trace("quant-vts-warn")
        << "Bad vts literal : " << n << ", contains " << target
        << " but could not isolate." << std::endl;
    // safe case: just convert to free symbols
    Node freeLit = substituteVtsFreeTerms(n);
    Trace("quant-vts-debug") << "...return " << freeLit << std::endl;
    return freeLit;
  }

  Trace("quant-vts-debug") << "VTS isolated :  -> " << isolated
                           << ", res = " << side << std::endl;
  Node solved = isolated[side == 1 ? 1 : 0];
  // ensure the vts terms have been eliminated
  if (containsVtsTerm(solved))
  {
    Trace("quant-vts-warn")
        << "Bad vts literal : " << n << ", contains " << target
        << " but bad solved form " << solved << "." << std::endl;
    // safe case: just convert to free symbols
    Node freeLit = substituteVtsFreeTerms(n);
    Trace("quant-vts-debug") << "...return " << freeLit << std::endl;
    return freeLit;
  }

  Node result;
  if (!infSym.isNull())
  {
    result = nodeMgr->mkConst(n.getKind() == GEQ && side == 1);
  }
  else
  {
    Assert(isolated[side == 1 ? 0 : 1] == d_vts_delta);
    if (n.getKind() == EQUAL)
    {
      result = nodeMgr->mkConst(false);
    }
    else if (side == 1)
    {
      result = nodeMgr->mkNode(GEQ, d_zero, solved);
    }
    else
    {
      result = nodeMgr->mkNode(GT, solved, d_zero);
    }
  }
  Trace("quant-vts-debug") << "Return " << result << std::endl;
  return result;
}
270
271
242
/**
 * Return true if n contains any existing virtual term symbol (delta or an
 * infinity). isFree selects the free variants. No symbols are created.
 */
bool VtsTermCache::containsVtsTerm(Node n, bool isFree)
{
  std::vector<Node> vtsSyms;
  getVtsTerms(vtsSyms, isFree, false);
  return expr::hasSubterm(n, vtsSyms);
}
277
278
5628
bool VtsTermCache::containsVtsTerm(std::vector<Node>& n, bool isFree)
279
{
280
11256
  std::vector<Node> t;
281
5628
  getVtsTerms(t, isFree, false);
282
5628
  if (!t.empty())
283
  {
284
73
    for (const Node& nc : n)
285
    {
286
57
      if (expr::hasSubterm(nc, t))
287
      {
288
23
        return true;
289
      }
290
    }
291
  }
292
5605
  return false;
293
}
294
295
/**
 * Return true if n contains an existing infinity symbol (real or integer).
 * The delta symbol is excluded (inc_delta = false). isFree selects the free
 * variants. No symbols are created.
 */
bool VtsTermCache::containsVtsInfinity(Node n, bool isFree)
{
  std::vector<Node> infSyms;
  getVtsTerms(infSyms, isFree, false, false);
  return expr::hasSubterm(n, infSyms);
}
301
302
}  // namespace quantifiers
303
}  // namespace theory
304
22746
}  // namespace cvc5