GCC Code Coverage Report
Directory: .
File: src/theory/quantifiers/cegqi/vts_term_cache.cpp
Date: 2021-03-22

            Exec   Total   Coverage
Lines:       130     145     89.7 %
Branches:    288     592     48.6 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file vts_term_cache.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Andres Noetzli, Tianyi Liang
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Implementation of virtual term substitution term cache.
13
 **/
14
15
#include "theory/quantifiers/cegqi/vts_term_cache.h"
16
17
#include "expr/node_algorithm.h"
18
#include "theory/arith/arith_msum.h"
19
#include "theory/quantifiers/quantifiers_inference_manager.h"
20
#include "theory/rewriter.h"
21
22
using namespace CVC4::kind;
23
24
namespace CVC4 {
25
namespace theory {
26
namespace quantifiers {
27
28
5975
// Keeps a reference to the inference manager (used by getVtsDelta to send
// the positivity lemma for the free delta) and caches the constant 0.
VtsTermCache::VtsTermCache(QuantifiersInferenceManager& qim) : d_qim(qim)
{
  NodeManager* nm = NodeManager::currentNM();
  d_zero = nm->mkConst(Rational(0));
}
32
33
6476
void VtsTermCache::getVtsTerms(std::vector<Node>& t,
34
                               bool isFree,
35
                               bool create,
36
                               bool inc_delta)
37
{
38
6476
  if (inc_delta)
39
  {
40
12768
    Node delta = getVtsDelta(isFree, create);
41
6384
    if (!delta.isNull())
42
    {
43
1502
      t.push_back(delta);
44
    }
45
  }
46
6476
  NodeManager* nm = NodeManager::currentNM();
47
19428
  for (unsigned r = 0; r < 2; r++)
48
  {
49
25904
    TypeNode tn = r == 0 ? nm->realType() : nm->integerType();
50
25904
    Node inf = getVtsInfinity(tn, isFree, create);
51
12952
    if (!inf.isNull())
52
    {
53
1292
      t.push_back(inf);
54
    }
55
  }
56
6476
}
57
58
18716
// Returns the delta skolem (the free one when isFree is true).
// When create is true, both the free and the non-free delta are allocated on
// first use. Allocating the free delta also sends the lemma
// delta_free > 0. The non-free delta is marked with
// VirtualTermSkolemAttribute. Without create, the result may be null.
Node VtsTermCache::getVtsDelta(bool isFree, bool create)
{
  if (!create)
  {
    // lookup only
    return isFree ? d_vts_delta_free : d_vts_delta;
  }
  NodeManager* nm = NodeManager::currentNM();
  if (d_vts_delta_free.isNull())
  {
    d_vts_delta_free = nm->mkSkolem("delta_free",
                                    nm->realType(),
                                    "free delta for virtual term substitution");
    // the free delta is constrained to be strictly positive
    Node positive = nm->mkNode(GT, d_vts_delta_free, d_zero);
    d_qim.lemma(positive, InferenceId::QUANTIFIERS_CEGQI_VTS_LB_DELTA);
  }
  if (d_vts_delta.isNull())
  {
    d_vts_delta = nm->mkSkolem(
        "delta", nm->realType(), "delta for virtual term substitution");
    // mark as a virtual term
    VirtualTermSkolemAttribute vtsAttr;
    d_vts_delta.setAttribute(vtsAttr, true);
  }
  return isFree ? d_vts_delta_free : d_vts_delta;
}
83
84
27533
// Returns the infinity skolem of type tn (the free one when isFree is true).
// When create is true, both the free and the non-free infinity for tn are
// allocated on first use. The non-free one is marked with
// VirtualTermSkolemAttribute. Note that the lookups use operator[], so a
// missing entry for tn is inserted as a null node; the result is then null.
Node VtsTermCache::getVtsInfinity(TypeNode tn, bool isFree, bool create)
{
  if (create)
  {
    NodeManager* nm = NodeManager::currentNM();
    Node& freeInf = d_vts_inf_free[tn];
    if (freeInf.isNull())
    {
      freeInf = nm->mkSkolem(
          "inf_free", tn, "free infinity for virtual term substitution");
    }
    Node& boundInf = d_vts_inf[tn];
    if (boundInf.isNull())
    {
      boundInf =
          nm->mkSkolem("inf", tn, "infinity for virtual term substitution");
      // mark as a virtual term
      VirtualTermSkolemAttribute vtsAttr;
      boundInf.setAttribute(vtsAttr, true);
    }
  }
  if (isFree)
  {
    return d_vts_inf_free[tn];
  }
  return d_vts_inf[tn];
}
105
106
220
// Replaces every (non-free) virtual term in n by its free counterpart.
// Both lists come from getVtsTerms without creation, so they pair up
// position by position (delta, Real infinity, Integer infinity).
// Returns n unchanged if no virtual terms have been allocated.
Node VtsTermCache::substituteVtsFreeTerms(Node n)
{
  std::vector<Node> from;
  getVtsTerms(from, false, false);
  std::vector<Node> to;
  getVtsTerms(to, true, false);
  Assert(from.size() == to.size());
  return from.empty()
             ? n
             : n.substitute(from.begin(), from.end(), to.begin(), to.end());
}
120
121
1823
// Eliminates the non-free virtual terms (infinity, delta) from n.
//
// - FORALL: not traversed; converted to free symbols.
// - EQUAL / GEQ atoms containing a non-free infinity or delta: the symbol is
//   isolated with ArithMSum and the atom is replaced by an equivalent literal
//   without it. Infinity takes precedence over delta. If both the Real and
//   the Integer infinity occur, the Integer one is equated to the Real one
//   and the atom is re-rewritten first. If isolation fails or the solved form
//   still contains a virtual term, the atom is converted to free symbols.
// - Any other term: rebuilt from its rewritten children if any changed.
Node VtsTermCache::rewriteVtsSymbols(Node n)
{
  NodeManager* nm = NodeManager::currentNM();
  if (n.getKind() == FORALL)
  {
    // cannot traverse beneath quantifiers
    return substituteVtsFreeTerms(n);
  }
  if (n.getKind() != EQUAL && n.getKind() != GEQ)
  {
    // Recurse into the children; rebuild only if one of them changed.
    std::vector<Node> kids;
    bool anyChanged = false;
    for (const Node& child : n)
    {
      Node rewritten = rewriteVtsSymbols(child);
      anyChanged = anyChanged || rewritten != child;
      kids.push_back(rewritten);
    }
    if (!anyChanged)
    {
      return n;
    }
    if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
    {
      // re-insert the operator as the first argument to mkNode
      kids.insert(kids.begin(), n.getOperator());
    }
    Node rebuilt = nm->mkNode(n.getKind(), kids);
    Trace("quant-vts-debug") << "...make node " << rebuilt << std::endl;
    return rebuilt;
  }

  Trace("quant-vts-debug") << "VTS : process " << n << std::endl;
  // rewriting infinity always takes precedence over rewriting delta
  Node infSym;
  for (bool intType : {false, true})
  {
    TypeNode ty = intType ? nm->integerType() : nm->realType();
    Node inf = getVtsInfinity(ty, false, false);
    if (inf.isNull() || !expr::hasSubterm(n, inf))
    {
      continue;
    }
    if (infSym.isNull())
    {
      infSym = inf;
      continue;
    }
    // for mixed int/real with multiple infinities
    Trace("quant-vts-debug") << "Multiple infinities...equate " << inf
                             << " = " << infSym << std::endl;
    std::vector<Node> from{inf};
    std::vector<Node> to{infSym};
    n = n.substitute(from.begin(), from.end(), to.begin(), to.end());
    n = Rewriter::rewrite(n);
    // may have cancelled
    if (!expr::hasSubterm(n, infSym))
    {
      infSym = Node::null();
    }
  }
  const bool useDelta = infSym.isNull() && !d_vts_delta.isNull()
                        && expr::hasSubterm(n, d_vts_delta);
  if (infSym.isNull() && !useDelta)
  {
    // nothing to eliminate
    return n;
  }

  std::map<Node, Node> msum;
  if (!ArithMSum::getMonomialSumLit(n, msum))
  {
    // NOTE(review): n is returned with the virtual term still present here,
    // unlike the "could not isolate" path below which substitutes free
    // symbols. Confirm this asymmetry is intended.
    return n;
  }
  if (Trace.isOn("quant-vts-debug"))
  {
    Trace("quant-vts-debug") << "VTS got monomial sum : " << std::endl;
    ArithMSum::debugPrintMonomialSum(msum, "quant-vts-debug");
  }
  Node symbol = infSym.isNull() ? d_vts_delta : infSym;
  Assert(!symbol.isNull());
  Node isolated;
  // n may have been rewritten above, so its current kind is used here
  int res = ArithMSum::isolate(symbol, msum, isolated, n.getKind(), true);
  if (res == 0)
  {
    Trace("quant-vts-warn")
        << "Bad vts literal : " << n << ", contains " << symbol
        << " but could not isolate." << std::endl;
    // safe case: just convert to free symbols
    Node safeLit = substituteVtsFreeTerms(n);
    Trace("quant-vts-debug") << "...return " << safeLit << std::endl;
    return safeLit;
  }
  Trace("quant-vts-debug") << "VTS isolated :  -> " << isolated
                           << ", res = " << res << std::endl;
  // the side of the isolated literal opposite to the symbol
  Node solved = isolated[res == 1 ? 1 : 0];
  // ensure the vts terms have been eliminated
  if (containsVtsTerm(solved))
  {
    Trace("quant-vts-warn")
        << "Bad vts literal : " << n << ", contains " << symbol
        << " but bad solved form " << solved << "." << std::endl;
    // safe case: just convert to free symbols
    Node safeLit = substituteVtsFreeTerms(n);
    Trace("quant-vts-debug") << "...return " << safeLit << std::endl;
    return safeLit;
  }
  Node result;
  if (!infSym.isNull())
  {
    // only a GEQ atom with res == 1 becomes true; everything else is false
    result = nm->mkConst(n.getKind() == GEQ && res == 1);
  }
  else
  {
    Assert(isolated[res == 1 ? 0 : 1] == d_vts_delta);
    if (n.getKind() == EQUAL)
    {
      result = nm->mkConst(false);
    }
    else if (res == 1)
    {
      result = nm->mkNode(GEQ, d_zero, solved);
    }
    else
    {
      result = nm->mkNode(GT, solved, d_zero);
    }
  }
  Trace("quant-vts-debug") << "Return " << result << std::endl;
  return result;
}
265
266
926
// Returns true if n contains one of the currently allocated virtual terms
// (delta or an infinity; free variants when isFree is true). Terms are not
// created by this query.
bool VtsTermCache::containsVtsTerm(Node n, bool isFree)
{
  std::vector<Node> vtsTerms;
  getVtsTerms(vtsTerms, isFree, false);
  // with no virtual terms allocated, n cannot contain one
  return !vtsTerms.empty() && expr::hasSubterm(n, vtsTerms);
}
272
273
5018
bool VtsTermCache::containsVtsTerm(std::vector<Node>& n, bool isFree)
274
{
275
10036
  std::vector<Node> t;
276
5018
  getVtsTerms(t, isFree, false);
277
5018
  if (!t.empty())
278
  {
279
221
    for (const Node& nc : n)
280
    {
281
144
      if (expr::hasSubterm(nc, t))
282
      {
283
67
        return true;
284
      }
285
    }
286
  }
287
4951
  return false;
288
}
289
290
// Returns true if n contains one of the currently allocated infinities
// (free variants when isFree is true). Delta is excluded, and terms are not
// created by this query.
bool VtsTermCache::containsVtsInfinity(Node n, bool isFree)
{
  std::vector<Node> infinities;
  // create = false, inc_delta = false
  getVtsTerms(infinities, isFree, false, false);
  return !infinities.empty() && expr::hasSubterm(n, infinities);
}
296
297
}  // namespace quantifiers
298
}  // namespace theory
299
26676
}  // namespace CVC4