GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/cegqi/vts_term_cache.cpp Lines: 137 146 93.8 %
Date: 2021-09-10 Branches: 306 590 51.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Tianyi Liang
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of virtual term substitution term cache.
14
 */
15
16
#include "theory/quantifiers/cegqi/vts_term_cache.h"
17
18
#include "expr/node_algorithm.h"
19
#include "expr/skolem_manager.h"
20
#include "theory/arith/arith_msum.h"
21
#include "theory/quantifiers/quantifiers_inference_manager.h"
22
#include "theory/rewriter.h"
23
#include "util/rational.h"
24
25
using namespace cvc5::kind;
26
27
namespace cvc5 {
28
namespace theory {
29
namespace quantifiers {
30
31
6708
VtsTermCache::VtsTermCache(QuantifiersInferenceManager& qim) : d_qim(qim)
32
{
33
6708
  d_zero = NodeManager::currentNM()->mkConst(Rational(0));
34
6708
}
35
36
11393
void VtsTermCache::getVtsTerms(std::vector<Node>& t,
37
                               bool isFree,
38
                               bool create,
39
                               bool inc_delta)
40
{
41
11393
  if (inc_delta)
42
  {
43
22612
    Node delta = getVtsDelta(isFree, create);
44
11306
    if (!delta.isNull())
45
    {
46
1545
      t.push_back(delta);
47
    }
48
  }
49
11393
  NodeManager* nm = NodeManager::currentNM();
50
34179
  for (unsigned r = 0; r < 2; r++)
51
  {
52
45572
    TypeNode tn = r == 0 ? nm->realType() : nm->integerType();
53
45572
    Node inf = getVtsInfinity(tn, isFree, create);
54
22786
    if (!inf.isNull())
55
    {
56
1292
      t.push_back(inf);
57
    }
58
  }
59
11393
}
60
61
40202
Node VtsTermCache::getVtsDelta(bool isFree, bool create)
62
{
63
40202
  if (create)
64
  {
65
103
    NodeManager* nm = NodeManager::currentNM();
66
103
    SkolemManager* sm = nm->getSkolemManager();
67
103
    if (d_vts_delta_free.isNull())
68
    {
69
23
      d_vts_delta_free =
70
92
          sm->mkDummySkolem("delta_free",
71
46
                            nm->realType(),
72
23
                            "free delta for virtual term substitution");
73
46
      Node delta_lem = nm->mkNode(GT, d_vts_delta_free, d_zero);
74
23
      d_qim.lemma(delta_lem, InferenceId::QUANTIFIERS_CEGQI_VTS_LB_DELTA);
75
    }
76
103
    if (d_vts_delta.isNull())
77
    {
78
92
      d_vts_delta = sm->mkDummySkolem(
79
69
          "delta", nm->realType(), "delta for virtual term substitution");
80
      // mark as a virtual term
81
      VirtualTermSkolemAttribute vtsa;
82
23
      d_vts_delta.setAttribute(vtsa, true);
83
    }
84
  }
85
40202
  return isFree ? d_vts_delta_free : d_vts_delta;
86
}
87
88
53988
Node VtsTermCache::getVtsInfinity(TypeNode tn, bool isFree, bool create)
89
{
90
53988
  if (create)
91
  {
92
152
    NodeManager* nm = NodeManager::currentNM();
93
152
    SkolemManager* sm = nm->getSkolemManager();
94
152
    if (d_vts_inf_free[tn].isNull())
95
    {
96
12
      d_vts_inf_free[tn] = sm->mkDummySkolem(
97
          "inf_free", tn, "free infinity for virtual term substitution");
98
    }
99
152
    if (d_vts_inf[tn].isNull())
100
    {
101
12
      d_vts_inf[tn] = sm->mkDummySkolem(
102
          "inf", tn, "infinity for virtual term substitution");
103
      // mark as a virtual term
104
      VirtualTermSkolemAttribute vtsa;
105
12
      d_vts_inf[tn].setAttribute(vtsa, true);
106
    }
107
  }
108
53988
  return isFree ? d_vts_inf_free[tn] : d_vts_inf[tn];
109
}
110
111
228
Node VtsTermCache::substituteVtsFreeTerms(Node n)
112
{
113
456
  std::vector<Node> vars;
114
228
  getVtsTerms(vars, false, false);
115
456
  std::vector<Node> vars_free;
116
228
  getVtsTerms(vars_free, true, false);
117
228
  Assert(vars.size() == vars_free.size());
118
228
  if (vars.empty())
119
  {
120
    return n;
121
  }
122
  return n.substitute(
123
228
      vars.begin(), vars.end(), vars_free.begin(), vars_free.end());
124
}
125
126
1895
Node VtsTermCache::rewriteVtsSymbols(Node n)
127
{
128
1895
  NodeManager* nm = NodeManager::currentNM();
129
1895
  if ((n.getKind() == EQUAL || n.getKind() == GEQ))
130
  {
131
1172
    Trace("quant-vts-debug") << "VTS : process " << n << std::endl;
132
2344
    Node rew_vts_inf;
133
1172
    bool rew_delta = false;
134
    // rewriting infinity always takes precedence over rewriting delta
135
3516
    for (unsigned r = 0; r < 2; r++)
136
    {
137
4688
      TypeNode tn = r == 0 ? nm->realType() : nm->integerType();
138
4688
      Node inf = getVtsInfinity(tn, false, false);
139
2344
      if (!inf.isNull() && expr::hasSubterm(n, inf))
140
      {
141
632
        if (rew_vts_inf.isNull())
142
        {
143
628
          rew_vts_inf = inf;
144
        }
145
        else
146
        {
147
          // for mixed int/real with multiple infinities
148
8
          Trace("quant-vts-debug") << "Multiple infinities...equate " << inf
149
4
                                   << " = " << rew_vts_inf << std::endl;
150
8
          std::vector<Node> subs_lhs;
151
4
          subs_lhs.push_back(inf);
152
8
          std::vector<Node> subs_rhs;
153
4
          subs_rhs.push_back(rew_vts_inf);
154
4
          n = n.substitute(subs_lhs.begin(),
155
                           subs_lhs.end(),
156
                           subs_rhs.begin(),
157
                           subs_rhs.end());
158
4
          n = Rewriter::rewrite(n);
159
          // may have cancelled
160
4
          if (!expr::hasSubterm(n, rew_vts_inf))
161
          {
162
            rew_vts_inf = Node::null();
163
          }
164
        }
165
      }
166
    }
167
1172
    if (rew_vts_inf.isNull())
168
    {
169
544
      if (!d_vts_delta.isNull() && expr::hasSubterm(n, d_vts_delta))
170
      {
171
330
        rew_delta = true;
172
      }
173
    }
174
1172
    if (!rew_vts_inf.isNull() || rew_delta)
175
    {
176
958
      std::map<Node, Node> msum;
177
958
      if (ArithMSum::getMonomialSumLit(n, msum))
178
      {
179
958
        if (Trace.isOn("quant-vts-debug"))
180
        {
181
          Trace("quant-vts-debug") << "VTS got monomial sum : " << std::endl;
182
          ArithMSum::debugPrintMonomialSum(msum, "quant-vts-debug");
183
        }
184
1916
        Node vts_sym = !rew_vts_inf.isNull() ? rew_vts_inf : d_vts_delta;
185
958
        Assert(!vts_sym.isNull());
186
1916
        Node iso_n;
187
1916
        Node nlit;
188
958
        int res = ArithMSum::isolate(vts_sym, msum, iso_n, n.getKind(), true);
189
958
        if (res != 0)
190
        {
191
1900
          Trace("quant-vts-debug") << "VTS isolated :  -> " << iso_n
192
950
                                   << ", res = " << res << std::endl;
193
1900
          Node slv = iso_n[res == 1 ? 1 : 0];
194
          // ensure the vts terms have been eliminated
195
950
          if (containsVtsTerm(slv))
196
          {
197
368
            Trace("quant-vts-warn")
198
184
                << "Bad vts literal : " << n << ", contains " << vts_sym
199
184
                << " but bad solved form " << slv << "." << std::endl;
200
            // safe case: just convert to free symbols
201
184
            nlit = substituteVtsFreeTerms(n);
202
184
            Trace("quant-vts-debug") << "...return " << nlit << std::endl;
203
184
            return nlit;
204
          }
205
          else
206
          {
207
766
            if (!rew_vts_inf.isNull())
208
            {
209
444
              nlit = nm->mkConst(n.getKind() == GEQ && res == 1);
210
            }
211
            else
212
            {
213
322
              Assert(iso_n[res == 1 ? 0 : 1] == d_vts_delta);
214
322
              if (n.getKind() == EQUAL)
215
              {
216
38
                nlit = nm->mkConst(false);
217
              }
218
284
              else if (res == 1)
219
              {
220
154
                nlit = nm->mkNode(GEQ, d_zero, slv);
221
              }
222
              else
223
              {
224
130
                nlit = nm->mkNode(GT, slv, d_zero);
225
              }
226
            }
227
          }
228
766
          Trace("quant-vts-debug") << "Return " << nlit << std::endl;
229
766
          return nlit;
230
        }
231
        else
232
        {
233
16
          Trace("quant-vts-warn")
234
8
              << "Bad vts literal : " << n << ", contains " << vts_sym
235
8
              << " but could not isolate." << std::endl;
236
          // safe case: just convert to free symbols
237
8
          nlit = substituteVtsFreeTerms(n);
238
8
          Trace("quant-vts-debug") << "...return " << nlit << std::endl;
239
8
          return nlit;
240
        }
241
      }
242
    }
243
214
    return n;
244
  }
245
723
  else if (n.getKind() == FORALL)
246
  {
247
    // cannot traverse beneath quantifiers
248
36
    return substituteVtsFreeTerms(n);
249
  }
250
687
  bool childChanged = false;
251
1374
  std::vector<Node> children;
252
2503
  for (const Node& nc : n)
253
  {
254
3632
    Node nn = rewriteVtsSymbols(nc);
255
1816
    children.push_back(nn);
256
1816
    childChanged = childChanged || nn != nc;
257
  }
258
687
  if (childChanged)
259
  {
260
671
    if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
261
    {
262
      children.insert(children.begin(), n.getOperator());
263
    }
264
1342
    Node ret = nm->mkNode(n.getKind(), children);
265
671
    Trace("quant-vts-debug") << "...make node " << ret << std::endl;
266
671
    return ret;
267
  }
268
16
  return n;
269
}
270
271
950
bool VtsTermCache::containsVtsTerm(Node n, bool isFree)
272
{
273
1900
  std::vector<Node> t;
274
950
  getVtsTerms(t, isFree, false);
275
1900
  return expr::hasSubterm(n, t);
276
}
277
278
9900
bool VtsTermCache::containsVtsTerm(std::vector<Node>& n, bool isFree)
279
{
280
19800
  std::vector<Node> t;
281
9900
  getVtsTerms(t, isFree, false);
282
9900
  if (!t.empty())
283
  {
284
229
    for (const Node& nc : n)
285
    {
286
165
      if (expr::hasSubterm(nc, t))
287
      {
288
83
        return true;
289
      }
290
    }
291
  }
292
9817
  return false;
293
}
294
295
bool VtsTermCache::containsVtsInfinity(Node n, bool isFree)
296
{
297
  std::vector<Node> t;
298
  getVtsTerms(t, isFree, false, false);
299
  return expr::hasSubterm(n, t);
300
}
301
302
}  // namespace quantifiers
303
}  // namespace theory
304
29502
}  // namespace cvc5