GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/preprocessing/passes/bv_gauss.cpp Lines: 311 342 90.9 %
Date: 2021-09-29 Branches: 676 1686 40.1 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Aina Niemetz, Mathias Preiner, Andrew Reynolds
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Gaussian Elimination preprocessing pass.
14
 *
15
 * Simplify a given equation system modulo a (prime) number via Gaussian
16
 * Elimination if possible.
17
 */
18
19
#include "preprocessing/passes/bv_gauss.h"
20
21
#include <unordered_map>
22
#include <vector>
23
24
#include "expr/node.h"
25
#include "preprocessing/assertion_pipeline.h"
26
#include "preprocessing/preprocessing_pass_context.h"
27
#include "theory/bv/theory_bv_rewrite_rules_normalization.h"
28
#include "theory/bv/theory_bv_utils.h"
29
#include "theory/rewriter.h"
30
#include "util/bitvector.h"
31
32
using namespace cvc5;
33
using namespace cvc5::theory;
34
using namespace cvc5::theory::bv;
35
36
namespace cvc5 {
37
namespace preprocessing {
38
namespace passes {
39
40
5802
bool BVGauss::is_bv_const(Node n)
{
  /* Literal constants are accepted immediately; any other term qualifies
   * only if the rewriter folds it into a bit-vector constant. */
  return n.isConst() || rewrite(n).getKind() == kind::CONST_BITVECTOR;
}
45
46
902
Node BVGauss::get_bv_const(Node n)
{
  /* Precondition: 'n' is (or rewrites to) a bit-vector constant. */
  Assert(is_bv_const(n));
  Node folded = rewrite(n);
  return folded;
}
51
52
482
Integer BVGauss::get_bv_const_value(Node n)
{
  /* Precondition: 'n' is (or rewrites to) a bit-vector constant.
   * Returns the unsigned integer value of that constant. */
  Assert(is_bv_const(n));
  const Node folded = get_bv_const(n);
  return folded.getConst<BitVector>().getValue();
}
57
58
/**
59
 * Determines if an overflow may occur in given 'expr'.
60
 *
61
 * Returns 0 if an overflow may occur, and the minimum required
62
 * bit-width such that no overflow occurs, otherwise.
63
 *
64
 * Note that it would suffice for this function to be Boolean.
65
 * However, it is handy to determine the minimum required bit-width for
66
 * debugging purposes.
67
 *
68
 * Note: getMinBwExpr assumes that 'expr' is rewritten.
69
 *
70
 * If not, all operators that are removed via rewriting (e.g., ror, rol, ...)
71
 * will be handled via the default case, which is not incorrect but also not
72
 * necessarily the minimum.
73
 */
74
202
uint32_t BVGauss::getMinBwExpr(Node expr)
75
{
76
404
  std::vector<Node> visit;
77
  /* Maps visited nodes to the determined minimum bit-width required. */
78
404
  std::unordered_map<Node, unsigned> visited;
79
202
  std::unordered_map<Node, unsigned>::iterator it;
80
81
202
  visit.push_back(expr);
82
5634
  while (!visit.empty())
83
  {
84
5438
    Node n = visit.back();
85
2722
    visit.pop_back();
86
2722
    it = visited.find(n);
87
2722
    if (it == visited.end())
88
    {
89
1490
      if (is_bv_const(n))
90
      {
91
        /* Rewrite const expr, overflows in consts are irrelevant. */
92
420
        visited[n] = get_bv_const(n).getConst<BitVector>().getValue().length();
93
      }
94
      else
95
      {
96
1070
        visited[n] = 0;
97
1070
        visit.push_back(n);
98
1070
        for (const Node &nn : n) { visit.push_back(nn); }
99
      }
100
    }
101
1232
    else if (it->second == 0)
102
    {
103
1066
      Kind k = n.getKind();
104
1066
      Assert(k != kind::CONST_BITVECTOR);
105
1066
      Assert(!is_bv_const(n));
106
1066
      switch (k)
107
      {
108
56
        case kind::BITVECTOR_EXTRACT:
109
        {
110
56
          const unsigned size = bv::utils::getSize(n);
111
56
          const unsigned low = bv::utils::getExtractLow(n);
112
56
          const unsigned child_min_width = visited[n[0]];
113
56
          visited[n] = std::min(
114
112
              size, child_min_width >= low ? child_min_width - low : 0u);
115
56
          Assert(visited[n] <= visited[n[0]]);
116
56
          break;
117
        }
118
119
72
        case kind::BITVECTOR_ZERO_EXTEND:
120
        {
121
72
          visited[n] = visited[n[0]];
122
72
          break;
123
        }
124
125
138
        case kind::BITVECTOR_MULT:
126
        {
127
138
          Integer maxval = Integer(1);
128
426
          for (const Node& nn : n)
129
          {
130
288
            if (is_bv_const(nn))
131
            {
132
114
              maxval *= get_bv_const_value(nn);
133
            }
134
            else
135
            {
136
174
              maxval *= BitVector::mkOnes(visited[nn]).getValue();
137
            }
138
          }
139
138
          unsigned w = maxval.length();
140
138
          if (w > bv::utils::getSize(n)) { return 0; } /* overflow */
141
134
          visited[n] = w;
142
134
          break;
143
        }
144
145
328
        case kind::BITVECTOR_CONCAT:
146
        {
147
          unsigned i, wnz, nc;
148
650
          for (i = 0, wnz = 0, nc = n.getNumChildren() - 1; i < nc; ++i)
149
          {
150
378
            unsigned wni = bv::utils::getSize(n[i]);
151
378
            if (n[i] != bv::utils::mkZero(wni)) { break; }
152
            /* sum of all bit-widths of leading zero concats */
153
322
            wnz += wni;
154
          }
155
          /* Do not consider leading zero concats, i.e.,
156
           * min bw of current concat is determined as
157
           *   min bw of first non-zero term
158
           *   plus actual bw of all subsequent terms */
159
984
          visited[n] = bv::utils::getSize(n) + visited[n[i]]
160
656
                       - bv::utils::getSize(n[i]) - wnz;
161
328
          break;
162
        }
163
164
6
        case kind::BITVECTOR_UREM:
165
        case kind::BITVECTOR_LSHR:
166
        case kind::BITVECTOR_ASHR:
167
        {
168
6
          visited[n] = visited[n[0]];
169
6
          break;
170
        }
171
172
        case kind::BITVECTOR_OR:
173
        case kind::BITVECTOR_NOR:
174
        case kind::BITVECTOR_XOR:
175
        case kind::BITVECTOR_XNOR:
176
        case kind::BITVECTOR_AND:
177
        case kind::BITVECTOR_NAND:
178
        {
179
          unsigned wmax = 0;
180
          for (const Node &nn : n)
181
          {
182
            if (visited[nn] > wmax)
183
            {
184
              wmax = visited[nn];
185
            }
186
          }
187
          visited[n] = wmax;
188
          break;
189
        }
190
191
104
        case kind::BITVECTOR_ADD:
192
        {
193
104
          Integer maxval = Integer(0);
194
364
          for (const Node& nn : n)
195
          {
196
260
            if (is_bv_const(nn))
197
            {
198
              maxval += get_bv_const_value(nn);
199
            }
200
            else
201
            {
202
260
              maxval += BitVector::mkOnes(visited[nn]).getValue();
203
            }
204
          }
205
104
          unsigned w = maxval.length();
206
104
          if (w > bv::utils::getSize(n)) { return 0; } /* overflow */
207
102
          visited[n] = w;
208
102
          break;
209
        }
210
211
362
        default:
212
        {
213
          /* BITVECTOR_UDIV (since x / 0 = -1)
214
           * BITVECTOR_NOT
215
           * BITVECTOR_NEG
216
           * BITVECTOR_SHL */
217
362
          visited[n] = bv::utils::getSize(n);
218
        }
219
      }
220
    }
221
  }
222
196
  Assert(visited.find(expr) != visited.end());
223
196
  return visited[expr];
224
}
225
226
/**
227
 * Apply Gaussian Elimination modulo a (prime) number.
228
 * The given equation system is represented as a matrix of Integers.
229
 *
230
 * Note that given 'prime' does not have to be prime but can be any
231
 * arbitrary number. However, if 'prime' is indeed prime, GE is guaranteed
232
 * to succeed, which is not the case, otherwise.
233
 *
234
 * Returns INVALID if GE can not be applied, UNIQUE and PARTIAL if GE was
235
 * successful, and NONE, otherwise.
236
 *
237
 * Vectors 'rhs' and 'lhs' represent the right hand side and left hand side
238
 * of the given matrix, respectively. The resulting matrix (in row echelon
239
 * form) is stored in 'rhs' and 'lhs', i.e., the given matrix is overwritten
240
 * with the resulting matrix.
241
 */
242
140
BVGauss::Result BVGauss::gaussElim(Integer prime,
243
                                   std::vector<Integer>& rhs,
244
                                   std::vector<std::vector<Integer>>& lhs)
245
{
246
140
  Assert(prime > 0);
247
140
  Assert(lhs.size());
248
140
  Assert(lhs.size() == rhs.size());
249
140
  Assert(lhs.size() <= lhs[0].size());
250
251
  /* special case: zero ring */
252
140
  if (prime == 1)
253
  {
254
2
    rhs = std::vector<Integer>(rhs.size(), Integer(0));
255
6
    lhs = std::vector<std::vector<Integer>>(
256
4
        lhs.size(), std::vector<Integer>(lhs[0].size(), Integer(0)));
257
2
    return BVGauss::Result::UNIQUE;
258
  }
259
260
138
  size_t nrows = lhs.size();
261
138
  size_t ncols = lhs[0].size();
262
263
#ifdef CVC5_ASSERTIONS
264
138
  for (size_t i = 1; i < nrows; ++i) Assert(lhs[i].size() == ncols);
265
#endif
266
  /* (1) if element in pivot column is non-zero and != 1, divide row elements
267
   *     by element in pivot column modulo prime, i.e., multiply row with
268
   *     multiplicative inverse of element in pivot column modulo prime
269
   *
270
   * (2) subtract pivot row from all rows below pivot row
271
   *
272
   * (3) subtract (multiple of) current row from all rows above s.t. all
273
   *     elements in current pivot column above current row become equal to one
274
   *
275
   * Note: we do not normalize the given matrix to values modulo prime
276
   *       beforehand but on-the-fly. */
277
278
  /* pivot = lhs[pcol][pcol] */
279
468
  for (size_t pcol = 0, prow = 0; pcol < ncols && prow < nrows; ++pcol, ++prow)
280
  {
281
    /* lhs[j][pcol]: element in pivot column */
282
1020
    for (size_t j = prow; j < nrows; ++j)
283
    {
284
#ifdef CVC5_ASSERTIONS
285
1140
      for (size_t k = 0; k < pcol; ++k)
286
      {
287
450
        Assert(lhs[j][k] == 0);
288
      }
289
#endif
290
      /* normalize element in pivot column to modulo prime */
291
690
      lhs[j][pcol] = lhs[j][pcol].euclidianDivideRemainder(prime);
292
      /* exchange rows if pivot elem is 0 */
293
690
      if (j == prow)
294
      {
295
500
        while (lhs[j][pcol] == 0)
296
        {
297
172
          for (size_t k = prow + 1; k < nrows; ++k)
298
          {
299
106
            lhs[k][pcol] = lhs[k][pcol].euclidianDivideRemainder(prime);
300
106
            if (lhs[k][pcol] != 0)
301
            {
302
55
              std::swap(rhs[j], rhs[k]);
303
55
              std::swap(lhs[j], lhs[k]);
304
55
              break;
305
            }
306
          }
307
121
          if (pcol >= ncols - 1) break;
308
79
          if (lhs[j][pcol] == 0)
309
          {
310
34
            pcol += 1;
311
34
            if (lhs[j][pcol] != 0)
312
20
              lhs[j][pcol] = lhs[j][pcol].euclidianDivideRemainder(prime);
313
          }
314
        }
315
      }
316
317
690
      if (lhs[j][pcol] != 0)
318
      {
319
        /* (1) */
320
522
        if (lhs[j][pcol] != 1)
321
        {
322
740
          Integer inv = lhs[j][pcol].modInverse(prime);
323
376
          if (inv == -1)
324
          {
325
12
            return BVGauss::Result::INVALID; /* not coprime */
326
          }
327
1203
          for (size_t k = pcol; k < ncols; ++k)
328
          {
329
839
            lhs[j][k] = lhs[j][k].modMultiply(inv, prime);
330
839
            if (j <= prow) continue; /* pivot */
331
488
            lhs[j][k] = lhs[j][k].modAdd(-lhs[prow][k], prime);
332
          }
333
364
          rhs[j] = rhs[j].modMultiply(inv, prime);
334
364
          if (j > prow) { rhs[j] = rhs[j].modAdd(-rhs[prow], prime); }
335
        }
336
        /* (2) */
337
146
        else if (j != prow)
338
        {
339
78
          for (size_t k = pcol; k < ncols; ++k)
340
          {
341
58
            lhs[j][k] = lhs[j][k].modAdd(-lhs[prow][k], prime);
342
          }
343
20
          rhs[j] = rhs[j].modAdd(-rhs[prow], prime);
344
        }
345
      }
346
    }
347
    /* (3) */
348
610
    for (size_t j = 0; j < prow; ++j)
349
    {
350
560
      Integer mul = lhs[j][pcol];
351
280
      if (mul != 0)
352
      {
353
512
        for (size_t k = pcol; k < ncols; ++k)
354
        {
355
296
          lhs[j][k] = lhs[j][k].modAdd(-lhs[prow][k] * mul, prime);
356
        }
357
216
        rhs[j] = rhs[j].modAdd(-rhs[prow] * mul, prime);
358
      }
359
    }
360
  }
361
362
126
  bool ispart = false;
363
472
  for (size_t i = 0; i < nrows; ++i)
364
  {
365
354
    size_t pcol = i;
366
459
    while (pcol < ncols && lhs[i][pcol] == 0) ++pcol;
367
404
    if (pcol >= ncols)
368
    {
369
58
      rhs[i] = rhs[i].euclidianDivideRemainder(prime);
370
58
      if (rhs[i] != 0)
371
      {
372
        /* no solution */
373
8
        return BVGauss::Result::NONE;
374
      }
375
50
      continue;
376
    }
377
976
    for (size_t j = i; j < ncols; ++j)
378
    {
379
680
      if (lhs[i][j] >= prime || lhs[i][j] <= -prime)
380
      {
381
2
        lhs[i][j] = lhs[i][j].euclidianDivideRemainder(prime);
382
      }
383
680
      if (j > pcol && lhs[i][j] != 0)
384
      {
385
72
        ispart = true;
386
      }
387
    }
388
  }
389
390
118
  if (ispart)
391
  {
392
42
    return BVGauss::Result::PARTIAL;
393
  }
394
395
76
  return BVGauss::Result::UNIQUE;
396
}
397
398
/**
399
 * Apply Gaussian Elimination on a set of equations modulo some (prime)
400
 * number given as bit-vector equations.
401
 *
402
 * IMPORTANT: Applying GE modulo some number (rather than modulo 2^bw)
403
 * on a set of bit-vector equations is only sound if this set of equations
404
 * has a solution that does not produce overflows. Consequently, we only
405
 * apply GE if the given bit-width guarantees that no overflows can occur
406
 * in the given set of equations.
407
 *
408
 * Note that the given set of equations does not have to be modulo a prime
409
 * but can be modulo any arbitrary number. However, if it is indeed modulo
410
 * prime, GE is guaranteed to succeed, which is not the case, otherwise.
411
 *
412
 * Returns INVALID if GE can not be applied, UNIQUE and PARTIAL if GE was
413
 * successful, and NONE, otherwise.
414
 *
415
 * The resulting constraints are stored in 'res' as a mapping of unknown
416
 * to result (modulo prime). These mapped results are added as constraints
417
 * of the form 'unknown = mapped result' in applyInternal.
418
 */
419
36
BVGauss::Result BVGauss::gaussElimRewriteForUrem(
420
    const std::vector<Node>& equations, std::unordered_map<Node, Node>& res)
421
{
422
36
  Assert(res.empty());
423
424
72
  Node prime;
425
72
  Integer iprime;
426
72
  std::unordered_map<Node, std::vector<Integer>> vars;
427
36
  size_t neqs = equations.size();
428
72
  std::vector<Integer> rhs;
429
  std::vector<std::vector<Integer>> lhs =
430
72
      std::vector<std::vector<Integer>>(neqs, std::vector<Integer>());
431
432
36
  res = std::unordered_map<Node, Node>();
433
434
128
  for (size_t i = 0; i < neqs; ++i)
435
  {
436
184
    Node eq = equations[i];
437
92
    Assert(eq.getKind() == kind::EQUAL);
438
184
    Node urem, eqrhs;
439
440
92
    if (eq[0].getKind() == kind::BITVECTOR_UREM)
441
    {
442
92
      urem = eq[0];
443
92
      Assert(is_bv_const(eq[1]));
444
92
      eqrhs = eq[1];
445
    }
446
    else
447
    {
448
      Assert(eq[1].getKind() == kind::BITVECTOR_UREM);
449
      urem = eq[1];
450
      Assert(is_bv_const(eq[0]));
451
      eqrhs = eq[0];
452
    }
453
92
    if (getMinBwExpr(rewrite(urem[0])) == 0)
454
    {
455
      Trace("bv-gauss-elim")
456
          << "Minimum required bit-width exceeds given bit-width, "
457
             "will not apply Gaussian Elimination."
458
          << std::endl;
459
      return BVGauss::Result::INVALID;
460
    }
461
92
    rhs.push_back(get_bv_const_value(eqrhs));
462
463
92
    Assert(is_bv_const(urem[1]));
464
92
    Assert(i == 0 || get_bv_const_value(urem[1]) == iprime);
465
92
    if (i == 0)
466
    {
467
36
      prime = urem[1];
468
36
      iprime = get_bv_const_value(prime);
469
    }
470
471
184
    std::unordered_map<Node, Integer> tmp;
472
184
    std::vector<Node> stack;
473
92
    stack.push_back(urem[0]);
474
700
    while (!stack.empty())
475
    {
476
608
      Node n = stack.back();
477
304
      stack.pop_back();
478
479
      /* Subtract from rhs if const */
480
304
      if (is_bv_const(n))
481
      {
482
        Integer val = get_bv_const_value(n);
483
        if (val > 0) rhs.back() -= val;
484
        continue;
485
      }
486
487
      /* Split into matrix columns */
488
304
      Kind k = n.getKind();
489
304
      if (k == kind::BITVECTOR_ADD)
490
      {
491
106
        for (const Node& nn : n) { stack.push_back(nn); }
492
      }
493
198
      else if (k == kind::BITVECTOR_MULT)
494
      {
495
368
        Node n0, n1;
496
        /* Flatten mult expression. */
497
184
        n = RewriteRule<FlattenAssocCommut>::run<true>(n);
498
        /* Split operands into consts and non-consts */
499
368
        NodeBuilder nb_consts(NodeManager::currentNM(), k);
500
368
        NodeBuilder nb_nonconsts(NodeManager::currentNM(), k);
501
568
        for (const Node& nn : n)
502
        {
503
768
          Node nnrw = rewrite(nn);
504
384
          if (is_bv_const(nnrw))
505
          {
506
176
            nb_consts << nnrw;
507
          }
508
          else
509
          {
510
208
            nb_nonconsts << nnrw;
511
          }
512
        }
513
184
        Assert(nb_nonconsts.getNumChildren() > 0);
514
        /* n0 is const */
515
184
        unsigned nc = nb_consts.getNumChildren();
516
184
        if (nc > 1)
517
        {
518
          n0 = rewrite(nb_consts.constructNode());
519
        }
520
184
        else if (nc == 1)
521
        {
522
176
          n0 = nb_consts[0];
523
        }
524
        else
525
        {
526
8
          n0 = bv::utils::mkOne(bv::utils::getSize(n));
527
        }
528
        /* n1 is a mult with non-const operands */
529
184
        if (nb_nonconsts.getNumChildren() > 1)
530
        {
531
20
          n1 = rewrite(nb_nonconsts.constructNode());
532
        }
533
        else
534
        {
535
164
          n1 = nb_nonconsts[0];
536
        }
537
184
        Assert(is_bv_const(n0));
538
184
        Assert(!is_bv_const(n1));
539
184
        tmp[n1] += get_bv_const_value(n0);
540
      }
541
      else
542
      {
543
14
        tmp[n] += Integer(1);
544
      }
545
    }
546
547
    /* Note: "var" is not necessarily a VARIABLE but can be an arbitrary expr */
548
549
290
    for (const auto& p : tmp)
550
    {
551
396
      Node var = p.first;
552
396
      Integer val = p.second;
553
198
      if (i > 0 && vars.find(var) == vars.end())
554
      {
555
        /* Add column and fill column elements of rows above with 0. */
556
28
        vars[var].insert(vars[var].end(), i, Integer(0));
557
      }
558
198
      vars[var].push_back(val);
559
    }
560
561
332
    for (const auto& p : vars)
562
    {
563
240
      if (tmp.find(p.first) == tmp.end())
564
      {
565
42
        vars[p.first].push_back(Integer(0));
566
      }
567
    }
568
  }
569
570
36
  size_t nvars = vars.size();
571
36
  if (nvars == 0)
572
  {
573
    return BVGauss::Result::INVALID;
574
  }
575
36
  size_t nrows = vars.begin()->second.size();
576
#ifdef CVC5_ASSERTIONS
577
142
  for (const auto& p : vars)
578
  {
579
106
    Assert(p.second.size() == nrows);
580
  }
581
#endif
582
583
36
  if (nrows < 1)
584
  {
585
    return BVGauss::Result::INVALID;
586
  }
587
588
128
  for (size_t i = 0; i < nrows; ++i)
589
  {
590
364
    for (const auto& p : vars)
591
    {
592
272
      lhs[i].push_back(p.second[i]);
593
    }
594
  }
595
596
#ifdef CVC5_ASSERTIONS
597
128
  for (const auto& row : lhs)
598
  {
599
92
    Assert(row.size() == nvars);
600
  }
601
36
  Assert(lhs.size() == rhs.size());
602
#endif
603
604
36
  if (lhs.size() > lhs[0].size())
605
  {
606
2
    return BVGauss::Result::INVALID;
607
  }
608
609
34
  Trace("bv-gauss-elim") << "Applying Gaussian Elimination..." << std::endl;
610
34
  BVGauss::Result ret = gaussElim(iprime, rhs, lhs);
611
612
34
  if (ret != BVGauss::Result::NONE && ret != BVGauss::Result::INVALID)
613
  {
614
68
    std::vector<Node> vvars;
615
34
    for (const auto& p : vars) { vvars.push_back(p.first); }
616
34
    Assert(nvars == vvars.size());
617
34
    Assert(nrows == lhs.size());
618
34
    Assert(nrows == rhs.size());
619
34
    NodeManager *nm = NodeManager::currentNM();
620
34
    if (ret == BVGauss::Result::UNIQUE)
621
    {
622
54
      for (size_t i = 0; i < nvars; ++i)
623
      {
624
40
        res[vvars[i]] = nm->mkConst<BitVector>(
625
80
            BitVector(bv::utils::getSize(vvars[i]), rhs[i]));
626
      }
627
    }
628
    else
629
    {
630
20
      Assert(ret == BVGauss::Result::PARTIAL);
631
632
60
      for (size_t pcol = 0, prow = 0; pcol < nvars && prow < nrows;
633
           ++pcol, ++prow)
634
      {
635
44
        Assert(lhs[prow][pcol] == 0 || lhs[prow][pcol] == 1);
636
52
        while (pcol < nvars && lhs[prow][pcol] == 0) pcol += 1;
637
44
        if (pcol >= nvars)
638
        {
639
4
          Assert(rhs[prow] == 0);
640
4
          break;
641
        }
642
40
        if (lhs[prow][pcol] == 0)
643
        {
644
          Assert(rhs[prow] == 0);
645
          continue;
646
        }
647
40
        Assert(lhs[prow][pcol] == 1);
648
80
        std::vector<Node> stack;
649
101
        for (size_t i = pcol + 1; i < nvars; ++i)
650
        {
651
61
          if (lhs[prow][i] == 0) continue;
652
          /* Normalize (no negative numbers, hence no subtraction)
653
           * e.g., x = 4 - 2y  --> x = 4 + 9y (modulo 11) */
654
68
          Integer m = iprime - lhs[prow][i];
655
68
          Node bv = bv::utils::mkConst(bv::utils::getSize(vvars[i]), m);
656
68
          Node mult = nm->mkNode(kind::BITVECTOR_MULT, vvars[i], bv);
657
34
          stack.push_back(mult);
658
        }
659
660
40
        if (stack.empty())
661
        {
662
6
          res[vvars[pcol]] = nm->mkConst<BitVector>(
663
12
              BitVector(bv::utils::getSize(vvars[pcol]), rhs[prow]));
664
        }
665
        else
666
        {
667
68
          Node tmp = stack.size() == 1 ? stack[0]
668
102
                                       : nm->mkNode(kind::BITVECTOR_ADD, stack);
669
670
34
          if (rhs[prow] != 0)
671
          {
672
96
            tmp = nm->mkNode(
673
                kind::BITVECTOR_ADD,
674
64
                bv::utils::mkConst(bv::utils::getSize(vvars[pcol]), rhs[prow]),
675
                tmp);
676
          }
677
34
          Assert(!is_bv_const(tmp));
678
34
          res[vvars[pcol]] = nm->mkNode(kind::BITVECTOR_UREM, tmp, prime);
679
        }
680
      }
681
    }
682
  }
683
34
  return ret;
684
}
685
686
6351
BVGauss::BVGauss(PreprocessingPassContext* preprocContext,
                 const std::string& name)
    : PreprocessingPass(preprocContext, name)
{
  /* Only forwards the context and pass name to the PreprocessingPass base;
   * no further setup is done here. */
}
691
692
6
PreprocessingPassResult BVGauss::applyInternal(
    AssertionPipeline* assertionsToPreprocess)
{
  NodeManager* nm = NodeManager::currentNM();

  /* Collect equations (t urem c) = d or d = (t urem c), where t is a
   * bit-vector addition and c, d are constants, grouped by the modulus c.
   * Conjunctions (including nested ones) are flattened first. */
  std::unordered_map<Node, std::vector<Node>> equations;
  std::vector<Node> worklist(assertionsToPreprocess->ref());
  while (!worklist.empty())
  {
    Node cur = worklist.back();
    worklist.pop_back();
    const cvc5::Kind ck = cur.getKind();

    if (ck == kind::AND)
    {
      for (const Node& child : cur)
      {
        worklist.push_back(child);
      }
      continue;
    }
    if (ck != kind::EQUAL)
    {
      continue;
    }

    /* One side must be a urem and the other a constant. The kind test is
     * evaluated first so that is_bv_const only rewrites candidates. */
    Node urem;
    if (cur[0].getKind() == kind::BITVECTOR_UREM && is_bv_const(cur[1]))
    {
      urem = cur[0];
    }
    else if (cur[1].getKind() == kind::BITVECTOR_UREM && is_bv_const(cur[0]))
    {
      urem = cur[1];
    }
    else
    {
      continue;
    }

    if (urem[0].getKind() != kind::BITVECTOR_ADD || !is_bv_const(urem[1]))
    {
      continue;
    }
    equations[urem[1]].push_back(cur);
  }

  auto resultName = [](BVGauss::Result r) -> const char* {
    if (r == BVGauss::Result::INVALID) return "INVALID";
    if (r == BVGauss::Result::UNIQUE) return "UNIQUE";
    if (r == BVGauss::Result::PARTIAL) return "PARTIAL";
    return "NONE";
  };

  /* Equations to be dropped (substituted by true) after processing. */
  std::unordered_map<Node, Node> subst;
  for (const auto& group : equations)
  {
    const std::vector<Node>& eqs = group.second;
    if (eqs.size() <= 1)
    {
      continue;
    }

    std::unordered_map<Node, Node> res;
    BVGauss::Result ret = gaussElimRewriteForUrem(eqs, res);
    Trace("bv-gauss-elim") << "result: " << resultName(ret) << std::endl;

    if (ret == BVGauss::Result::INVALID)
    {
      continue;
    }

    if (ret == BVGauss::Result::NONE)
    {
      /* The equation system has no solution: replace everything by false. */
      assertionsToPreprocess->clear();
      Node falseNode = nm->mkConst<bool>(false);
      assertionsToPreprocess->push_back(falseNode);
      return PreprocessingPassResult::CONFLICT;
    }

    for (const Node& e : eqs)
    {
      subst[e] = nm->mkConst<bool>(true);
    }
    /* add resulting constraints */
    for (const auto& mapping : res)
    {
      Node added = nm->mkNode(kind::EQUAL, mapping.first, mapping.second);
      Trace("bv-gauss-elim") << "added assertion: " << added << std::endl;
      // add new assertion
      assertionsToPreprocess->push_back(added);
    }
  }

  if (subst.empty())
  {
    return PreprocessingPassResult::NO_CONFLICT;
  }

  /* delete (= substitute with true) obsolete assertions */
  const std::vector<Node>& current = assertionsToPreprocess->ref();
  const size_t count = current.size();
  for (size_t idx = 0; idx < count; ++idx)
  {
    Node original = current[idx];
    Node rewritten = original.substitute(subst.begin(), subst.end());
    // replace the assertion
    assertionsToPreprocess->replace(idx, rewritten);
  }
  return PreprocessingPassResult::NO_CONFLICT;
}
794
795
796
}  // namespace passes
797
}  // namespace preprocessing
798
22746
}  // namespace cvc5