GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/bv_inverter.cpp Lines: 186 200 93.0 %
Date: 2021-03-22 Branches: 475 1060 44.8 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file bv_inverter.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Aina Niemetz, Andrew Reynolds, Mathias Preiner
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief inverse rules for bit-vector operators
13
 **/
14
15
#include "theory/quantifiers/bv_inverter.h"
16
17
#include <algorithm>
18
19
#include "options/quantifiers_options.h"
20
#include "theory/bv/theory_bv_utils.h"
21
#include "theory/quantifiers/bv_inverter_utils.h"
22
#include "theory/quantifiers/term_util.h"
23
#include "theory/rewriter.h"
24
25
using namespace CVC4::kind;
26
27
namespace CVC4 {
28
namespace theory {
29
namespace quantifiers {
30
31
/*---------------------------------------------------------------------------*/
32
33
10064
/**
 * Return the solve variable (a skolem named "slv") associated with type tn.
 * The skolem is created on first request for a given type and cached in
 * d_solve_var, so later calls with the same type return the same node.
 */
Node BvInverter::getSolveVariable(TypeNode tn)
{
  std::map<TypeNode, Node>::const_iterator cached = d_solve_var.find(tn);
  if (cached != d_solve_var.end())
  {
    // Already allocated for this type: reuse it.
    return cached->second;
  }

  Node fresh = NodeManager::currentNM()->mkSkolem("slv", tn);
  d_solve_var[tn] = fresh;
  return fresh;
}
47
48
/*---------------------------------------------------------------------------*/
49
50
1446
/**
 * Build a term denoting a value of type tn that satisfies cond, where cond
 * is phrased over the solve variable for tn (see getSolveVariable).
 *
 * cond is first rewritten. If the rewritten condition is an equality with
 * the solve variable on one side, the other side is returned directly
 * (no witness term is introduced). Otherwise, if a query m is given, the
 * result is (witness x. cond[x/solve_var]) with x = m->getBoundVariable(tn).
 *
 * @param cond condition over the solve variable of type tn
 * @param tn   type of the solve variable
 * @param m    query supplying bound variables; may be null
 * @return the inversion term, or the null node if no trivial solution exists
 *         and m is null.
 */
Node BvInverter::getInversionNode(Node cond, TypeNode tn, BvInverterQuery* m)
{
  // Keep a ref-counted handle. Binding a TNode to the temporary returned by
  // getSolveVariable would only be valid because d_solve_var happens to keep
  // the node alive.
  Node solve_var = getSolveVariable(tn);

  // condition should be rewritten
  Node new_cond = Rewriter::rewrite(cond);
  if (new_cond != cond)
  {
    Trace("cegqi-bv-skvinv-debug")
        << "Condition " << cond << " was rewritten to " << new_cond
        << std::endl;
  }
  // optimization : if condition is ( x = solve_var ) should just return
  // solve_var and not introduce a Skolem this can happen when we ask for
  // the multiplicative inversion with bv1
  Node c;
  if (new_cond.getKind() == EQUAL)
  {
    for (unsigned i = 0; i < 2; i++)
    {
      if (new_cond[i] == solve_var)
      {
        c = new_cond[1 - i];
        Trace("cegqi-bv-skvinv")
            << "SKVINV : " << c << " is trivially associated with condition "
            << new_cond << std::endl;
        break;
      }
    }
  }

  if (c.isNull())
  {
    NodeManager* nm = NodeManager::currentNM();
    if (m)
    {
      Node x = m->getBoundVariable(tn);
      Node ccond = new_cond.substitute(solve_var, x);
      c = nm->mkNode(kind::WITNESS, nm->mkNode(BOUND_VAR_LIST, x), ccond);
      Trace("cegqi-bv-skvinv")
          << "SKVINV : Make " << c << " for " << new_cond << std::endl;
    }
    else
    {
      Trace("bv-invert") << "...fail for " << cond << " : no inverter query!"
                         << std::endl;
    }
  }
  // currently shouldn't cache since
  // the return value depends on the
  // state of m (which bound variable is returned).
  return c;
}
103
104
/*---------------------------------------------------------------------------*/
105
106
16466
/**
 * Return true if getPathToPv may descend through a term of kind k.
 * Restricting the search to these kinds keeps solve paths from passing
 * through, e.g., applications of skolem functions.
 * The child index is currently not taken into account.
 */
static bool isInvertible(Kind k, unsigned index)
{
  (void)index;  // reserved for index-dependent invertibility
  switch (k)
  {
    // Boolean / predicate layer
    case NOT:
    case EQUAL:
    case BITVECTOR_ULT:
    case BITVECTOR_SLT:
    // Bit-vector term layer
    case BITVECTOR_COMP:
    case BITVECTOR_NOT:
    case BITVECTOR_NEG:
    case BITVECTOR_CONCAT:
    case BITVECTOR_SIGN_EXTEND:
    case BITVECTOR_PLUS:
    case BITVECTOR_MULT:
    case BITVECTOR_UREM:
    case BITVECTOR_UDIV:
    case BITVECTOR_AND:
    case BITVECTOR_OR:
    case BITVECTOR_XOR:
    case BITVECTOR_LSHR:
    case BITVECTOR_ASHR:
    case BITVECTOR_SHL:
      return true;
    default:
      return false;
  }
}
116
117
18555
/**
 * Depth-first search for an occurrence of pv in lit that is reachable
 * through invertible kinds only (see isInvertible).
 *
 * On success, returns lit with that single occurrence of pv replaced by sv,
 * and appends to path the child indices taken, innermost first (so the
 * outermost index ends up last). Terms already in visited are not explored
 * again. Returns the null node if no such occurrence is found.
 */
Node BvInverter::getPathToPv(
    Node lit,
    Node pv,
    Node sv,
    std::vector<unsigned>& path,
    std::unordered_set<TNode, TNodeHashFunction>& visited)
{
  // Already explored (or currently being explored): nothing new here.
  if (!visited.insert(lit).second)
  {
    return Node::null();
  }
  if (lit == pv)
  {
    return sv;
  }

  const Kind litKind = lit.getKind();
  const size_t arity = lit.getNumChildren();
  const size_t offset = 0;  // TODO : randomize?
  for (size_t step = 0; step < arity; ++step)
  {
    const size_t child = (step + offset) % arity;
    // only recurse if the kind is invertible
    // this allows us to avoid paths that go through skolem functions
    if (!isInvertible(litKind, child))
    {
      continue;
    }
    Node replaced = getPathToPv(lit[child], pv, sv, path, visited);
    if (replaced.isNull())
    {
      continue;
    }

    // path is outermost term index last
    path.push_back(child);

    // Rebuild lit with the child on the path swapped for its replacement.
    std::vector<Node> rebuilt;
    if (lit.getMetaKind() == kind::metakind::PARAMETERIZED)
    {
      rebuilt.push_back(lit.getOperator());
    }
    for (size_t j = 0; j < arity; ++j)
    {
      rebuilt.push_back(j == child ? replaced : lit[j]);
    }
    return NodeManager::currentNM()->mkNode(litKind, rebuilt);
  }
  return Node::null();
}
164
165
4029
/**
 * Find an invertible path from lit to an occurrence of pv, replacing that
 * occurrence by sv and recording the path (outermost index last).
 *
 * If pvs is non-null, every remaining occurrence of pv in the result is
 * replaced by pvs. If that replacement changed anything, pv occurs off the
 * solve path (lit is non-linear in pv). In that case the null node is
 * returned unless projectNl is set.
 */
Node BvInverter::getPathToPv(Node lit,
                             Node pv,
                             Node sv,
                             Node pvs,
                             std::vector<unsigned>& path,
                             bool projectNl)
{
  std::unordered_set<TNode, TNodeHashFunction> visited;
  Node found = getPathToPv(lit, pv, sv, path, visited);

  // No invertible path, or nothing to substitute: return as is.
  if (found.isNull() || pvs.isNull())
  {
    return found;
  }

  TNode from = pv;
  TNode to = pvs;
  Node projected = found.substitute(from, to);
  const bool nonLinear = (projected != found);
  if (nonLinear && !projectNl)
  {
    // found another occurrence of pv that was not on the solve path,
    // hence lit is non-linear wrt pv and we return null.
    return Node::null();
  }
  return projected;
}
191
192
/*---------------------------------------------------------------------------*/
193
194
/* Drop child at given index from expression.
195
 * E.g., dropChild((x + y + z), 1) -> (x + z)  */
196
3465
static Node dropChild(Node n, unsigned index)
197
{
198
3465
  unsigned nchildren = n.getNumChildren();
199
3465
  Assert(nchildren > 0);
200
3465
  Assert(index < nchildren);
201
202
3465
  if (nchildren < 2) return Node::null();
203
204
3359
  Kind k = n.getKind();
205
6718
  NodeBuilder<> nb(k);
206
10145
  for (unsigned i = 0; i < nchildren; ++i)
207
  {
208
6786
    if (i == index) continue;
209
3427
    nb << n[i];
210
  }
211
3359
  Assert(nb.getNumChildren() > 0);
212
3359
  return nb.getNumChildren() == 1 ? nb[0] : nb.constructNode();
213
}
214
215
3339
/**
 * Solve the (possibly negated) literal lit for sv along a precomputed path.
 *
 * path holds child indices from lit down to sv with the outermost index
 * last (as produced by getPathToPv); it is consumed from the back. After an
 * optional NOT, lit must be an EQUAL, BITVECTOR_ULT or BITVECTOR_SLT.
 *
 * The literal is maintained as (sv_t <litk> t) with polarity pol, and each
 * operator on the path is peeled off in turn:
 * - Operators with an exact inverse under positive or negative equality
 *   (NOT, NEG, PLUS, XOR, MULT by an odd constant) update t directly.
 * - CONCAT under equality may use an extract when --cegqi-bv-concat-inv
 *   is enabled.
 * - Otherwise an invertibility condition ic is built via utils::getIC*.
 *   t is then replaced by an inversion (witness) term, and the literal
 *   becomes a positive equality.
 *
 * @param sv   the variable to solve for
 * @param lit  literal containing sv
 * @param path child indices from lit to sv, outermost index last (modified)
 * @param m    query forwarded to getInversionNode; may be null
 * @return a term for sv, or the null node if an operator on the path is not
 *         handled or an inversion node could not be built.
 */
Node BvInverter::solveBvLit(Node sv,
                            Node lit,
                            std::vector<unsigned>& path,
                            BvInverterQuery* m)
{
  Assert(!path.empty());

  bool pol = true;
  unsigned index;
  NodeManager* nm = NodeManager::currentNM();
  Kind k, litk;

  index = path.back();
  Assert(index < lit.getNumChildren());
  path.pop_back();
  litk = k = lit.getKind();

  /* Note: option --bool-to-bv is currently disabled when CBQI BV
   *       is enabled and the logic is quantified.
   *       We currently do not support Boolean operators
   *       that are interpreted as bit-vector operators of width 1.  */

  /* Boolean layer ----------------------------------------------- */

  if (k == NOT)
  {
    pol = !pol;
    lit = lit[index];
    Assert(!path.empty());
    index = path.back();
    Assert(index < lit.getNumChildren());
    path.pop_back();
    litk = k = lit.getKind();
  }

  Assert(k == EQUAL || k == BITVECTOR_ULT || k == BITVECTOR_SLT);

  Node sv_t = lit[index];
  Node t = lit[1 - index];
  // Keep the side containing sv on the left: (t < sv_t) becomes (sv_t > t).
  if (litk == BITVECTOR_ULT && index == 1)
  {
    litk = BITVECTOR_UGT;
  }
  else if (litk == BITVECTOR_SLT && index == 1)
  {
    litk = BITVECTOR_SGT;
  }

  /* Bit-vector layer -------------------------------------------- */

  while (!path.empty())
  {
    unsigned nchildren = sv_t.getNumChildren();
    Assert(nchildren > 0);
    index = path.back();
    Assert(index < nchildren);
    path.pop_back();
    k = sv_t.getKind();

    /* Note: All n-ary kinds except for CONCAT (i.e., BITVECTOR_AND,
     *       BITVECTOR_OR, MULT, PLUS) are commutative (no case split
     *       based on index). */
    Node s = dropChild(sv_t, index);
    Assert((nchildren == 1 && s.isNull()) || (nchildren > 1 && !s.isNull()));
    TypeNode solve_tn = sv_t[index].getType();
    Node x = getSolveVariable(solve_tn);
    Node ic;

    if (litk == EQUAL && (k == BITVECTOR_NOT || k == BITVECTOR_NEG))
    {
      // NOT and NEG are self-inverse.
      t = nm->mkNode(k, t);
    }
    else if (litk == EQUAL && k == BITVECTOR_PLUS)
    {
      t = nm->mkNode(BITVECTOR_SUB, t, s);
    }
    else if (litk == EQUAL && k == BITVECTOR_XOR)
    {
      t = nm->mkNode(BITVECTOR_XOR, t, s);
    }
    else if (litk == EQUAL && k == BITVECTOR_MULT && s.isConst()
             && bv::utils::getBit(s, 0))
    {
      // An odd constant has a multiplicative inverse modulo 2^w.
      unsigned w = bv::utils::getSize(s);
      Integer s_val = s.getConst<BitVector>().toInteger();
      Integer mod_val = Integer(1).multiplyByPow2(w);
      Trace("bv-invert-debug")
          << "Compute inverse : " << s_val << " " << mod_val << std::endl;
      Integer inv_val = s_val.modInverse(mod_val);
      Trace("bv-invert-debug") << "Inverse : " << inv_val << std::endl;
      Node inv = bv::utils::mkConst(w, inv_val);
      t = nm->mkNode(BITVECTOR_MULT, inv, t);
    }
    else if (k == BITVECTOR_MULT)
    {
      ic = utils::getICBvMult(pol, litk, k, index, x, s, t);
    }
    else if (k == BITVECTOR_SHL)
    {
      ic = utils::getICBvShl(pol, litk, k, index, x, s, t);
    }
    else if (k == BITVECTOR_UREM)
    {
      ic = utils::getICBvUrem(pol, litk, k, index, x, s, t);
    }
    else if (k == BITVECTOR_UDIV)
    {
      ic = utils::getICBvUdiv(pol, litk, k, index, x, s, t);
    }
    else if (k == BITVECTOR_AND || k == BITVECTOR_OR)
    {
      ic = utils::getICBvAndOr(pol, litk, k, index, x, s, t);
    }
    else if (k == BITVECTOR_LSHR)
    {
      ic = utils::getICBvLshr(pol, litk, k, index, x, s, t);
    }
    else if (k == BITVECTOR_ASHR)
    {
      ic = utils::getICBvAshr(pol, litk, k, index, x, s, t);
    }
    else if (k == BITVECTOR_CONCAT)
    {
      if (litk == EQUAL && options::cegqiBvConcInv())
      {
        /* Compute inverse for s1 o x, x o s2, s1 o x o s2
         * (while disregarding that invertibility depends on si)
         * rather than an invertibility condition (the proper handling).
         * This improves performance on a considerable number of benchmarks.
         *
         * x = t[upper:lower]
         * where
         * upper = getSize(t) - 1 - sum(getSize(sv_t[i])) for i < index
         * lower = sum(getSize(sv_t[i])) for i > index  */
        unsigned upper = bv::utils::getSize(t) - 1;
        unsigned lower = 0;
        for (unsigned i = 0; i < nchildren; i++)
        {
          if (i < index)
          {
            upper -= bv::utils::getSize(sv_t[i]);
          }
          else if (i > index)
          {
            lower += bv::utils::getSize(sv_t[i]);
          }
        }
        t = bv::utils::mkExtract(t, upper, lower);
      }
      else
      {
        ic = utils::getICBvConcat(pol, litk, index, x, sv_t, t);
      }
    }
    else if (k == BITVECTOR_SIGN_EXTEND)
    {
      ic = utils::getICBvSext(pol, litk, index, x, sv_t, t);
    }
    // NOTE(review): the two inequality fallbacks below build a condition
    // for (x <litk> t) alone, i.e. they ignore operator k between sv_t and
    // its child sv_t[index]. Presumably an intentional approximation for
    // instantiation heuristics; confirm.
    else if (litk == BITVECTOR_ULT || litk == BITVECTOR_UGT)
    {
      ic = utils::getICBvUltUgt(pol, litk, x, t);
    }
    else if (litk == BITVECTOR_SLT || litk == BITVECTOR_SGT)
    {
      ic = utils::getICBvSltSgt(pol, litk, x, t);
    }
    else if (pol == false)
    {
      Assert(litk == EQUAL);
      ic = nm->mkNode(DISTINCT, x, t);
      Trace("bv-invert") << "Add SC_" << litk << "(" << x << "): " << ic
                         << std::endl;
    }
    else
    {
      Trace("bv-invert") << "bv-invert : Unknown kind " << k
                         << " for bit-vector term " << sv_t << std::endl;
      return Node::null();
    }

    if (!ic.isNull())
    {
      /* We generate a witness term (witness x0. ic => x0 <k> s <litk> t) for
       * x <k> s <litk> t. When traversing down, this witness term determines
       * the value for x <k> s = (witness x0. ic => x0 <k> s <litk> t), i.e.,
       * from here on, the propagated literal is a positive equality. */
      litk = EQUAL;
      pol = true;
      /* t = fresh skolem constant */
      t = getInversionNode(ic, solve_tn, m);
      if (t.isNull())
      {
        return t;
      }
    }

    sv_t = sv_t[index];
  }

  /* Base case  */
  Assert(sv_t == sv);
  TypeNode solve_tn = sv.getType();
  Node x = getSolveVariable(solve_tn);
  Node ic;
  if (litk == BITVECTOR_ULT || litk == BITVECTOR_UGT)
  {
    ic = utils::getICBvUltUgt(pol, litk, x, t);
  }
  else if (litk == BITVECTOR_SLT || litk == BITVECTOR_SGT)
  {
    ic = utils::getICBvSltSgt(pol, litk, x, t);
  }
  else if (pol == false)
  {
    Assert(litk == EQUAL);
    ic = nm->mkNode(DISTINCT, x, t);
    Trace("bv-invert") << "Add SC_" << litk << "(" << x << "): " << ic
                       << std::endl;
  }

  // A positive equality sv = t needs no witness term.
  return ic.isNull() ? t : getInversionNode(ic, solve_tn, m);
}
440
441
/*---------------------------------------------------------------------------*/
442
443
}  // namespace quantifiers
444
}  // namespace theory
445
26951
}  // namespace CVC4