GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/omt/bitvector_optimizer.cpp Lines: 95 107 88.8 %
Date: 2021-09-29 Branches: 174 360 48.3 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Yancheng Ou
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Optimizer for BitVector type.
14
 */
15
16
#include "omt/bitvector_optimizer.h"
17
18
#include "options/smt_options.h"
19
#include "smt/smt_engine.h"
20
#include "util/bitvector.h"
21
22
using namespace cvc5::smt;
23
namespace cvc5::omt {
24
25
30
OMTOptimizerBitVector::OMTOptimizerBitVector(bool isSigned)
26
30
    : d_isSigned(isSigned)
27
{
28
30
}
29
30
632
BitVector OMTOptimizerBitVector::computeAverage(const BitVector& a,
31
                                                const BitVector& b,
32
                                                bool isSigned)
33
{
34
  // computes (a + b) / 2 without overflow
35
  // rounding towards -infinity: -1.5 --> -2,  1.5 --> 1
36
  // average = (a / 2) + (b / 2) + (((a % 2) + (b % 2)) / 2)
37
632
  uint32_t aMod2 = static_cast<uint32_t>(a.isBitSet(0));
38
632
  uint32_t bMod2 = static_cast<uint32_t>(b.isBitSet(0));
39
1264
  BitVector aMod2PlusbMod2Div2(a.getSize(), (aMod2 + bMod2) / 2);
40
1264
  BitVector bv1 = BitVector::mkOne(a.getSize());
41
882
  return (isSigned) ? ((a.arithRightShift(bv1) + b.arithRightShift(bv1)
42
                        + aMod2PlusbMod2Div2))
43
1014
                    : ((a.logicalRightShift(bv1) + b.logicalRightShift(bv1)
44
1896
                        + aMod2PlusbMod2Div2));
45
}
46
47
14
OptimizationResult OMTOptimizerBitVector::minimize(SmtEngine* optChecker,
48
                                                   TNode target)
49
{
50
  // the smt engine to which we send intermediate queries
51
  // for the binary search.
52
14
  NodeManager* nm = optChecker->getNodeManager();
53
28
  Result intermediateSatResult = optChecker->checkSat();
54
  // Model-value of objective (used in optimization loop)
55
28
  Node value;
56
28
  if (intermediateSatResult.isUnknown()
57
14
      || intermediateSatResult.isSat() == Result::UNSAT)
58
  {
59
    return OptimizationResult(intermediateSatResult, value);
60
  }
61
  // the last result that is SAT
62
28
  Result lastSatResult = intermediateSatResult;
63
  // value equals to upperBound
64
14
  value = optChecker->getValue(target);
65
66
  // this gets the bitvector!
67
28
  BitVector bvValue = value.getConst<BitVector>();
68
14
  unsigned int bvSize = bvValue.getSize();
69
70
  // lowerbound
71
14
  BitVector lowerBound = ((this->d_isSigned) ? (BitVector::mkMinSigned(bvSize))
72
28
                                             : (BitVector::mkZero(bvSize)));
73
  // upperbound must be a satisfying value
74
  // and value == upperbound
75
28
  BitVector upperBound = bvValue;
76
77
14
  Kind LTOperator =
78
14
      ((d_isSigned) ? (kind::BITVECTOR_SLT) : (kind::BITVECTOR_ULT));
79
14
  Kind GEOperator =
80
14
      ((d_isSigned) ? (kind::BITVECTOR_SGE) : (kind::BITVECTOR_UGE));
81
82
  // the pivot value for binary search,
83
  // pivot = (lowerBound + upperBound) / 2
84
  // rounded towards -infinity
85
28
  BitVector pivot;
86
250
  while ((d_isSigned && lowerBound.signedLessThan(upperBound))
87
258
         || (!d_isSigned && lowerBound.unsignedLessThan(upperBound)))
88
  {
89
128
    pivot = computeAverage(lowerBound, upperBound, d_isSigned);
90
128
    optChecker->push();
91
128
    if (lowerBound == pivot)
92
    {
93
12
      optChecker->assertFormula(
94
24
          nm->mkNode(kind::EQUAL, target, nm->mkConst(lowerBound)));
95
    }
96
    else
97
    {
98
      // lowerBound <= target < pivot
99
116
      optChecker->assertFormula(
100
464
          nm->mkNode(kind::AND,
101
232
                     nm->mkNode(GEOperator, target, nm->mkConst(lowerBound)),
102
232
                     nm->mkNode(LTOperator, target, nm->mkConst(pivot))));
103
    }
104
128
    intermediateSatResult = optChecker->checkSat();
105
128
    switch (intermediateSatResult.isSat())
106
    {
107
      case Result::SAT_UNKNOWN:
108
        optChecker->pop();
109
        return OptimizationResult(intermediateSatResult, value);
110
8
      case Result::SAT:
111
8
        lastSatResult = intermediateSatResult;
112
8
        value = optChecker->getValue(target);
113
8
        upperBound = value.getConst<BitVector>();
114
8
        break;
115
120
      case Result::UNSAT:
116
120
        if (lowerBound == pivot)
117
        {
118
          // lowerBound == pivot ==> upperbound = lowerbound + 1
119
          // and lowerbound <= target < upperbound is UNSAT
120
          // return the upperbound
121
12
          optChecker->pop();
122
12
          return OptimizationResult(lastSatResult, value);
123
        }
124
        else
125
        {
126
108
          lowerBound = pivot;
127
        }
128
108
        break;
129
      default: Unreachable();
130
    }
131
116
    optChecker->pop();
132
  }
133
2
  return OptimizationResult(lastSatResult, value);
134
}
135
136
16
OptimizationResult OMTOptimizerBitVector::maximize(SmtEngine* optChecker,
137
                                                   TNode target)
138
{
139
  // the smt engine to which we send intermediate queries
140
  // for the binary search.
141
16
  NodeManager* nm = optChecker->getNodeManager();
142
32
  Result intermediateSatResult = optChecker->checkSat();
143
  // Model-value of objective (used in optimization loop)
144
32
  Node value;
145
32
  if (intermediateSatResult.isUnknown()
146
16
      || intermediateSatResult.isSat() == Result::UNSAT)
147
  {
148
    return OptimizationResult(intermediateSatResult, value);
149
  }
150
  // the last result that is SAT
151
32
  Result lastSatResult = intermediateSatResult;
152
  // value equals to upperBound
153
16
  value = optChecker->getValue(target);
154
155
  // this gets the bitvector!
156
32
  BitVector bvValue = value.getConst<BitVector>();
157
16
  unsigned int bvSize = bvValue.getSize();
158
159
  // lowerbound must be a satisfying value
160
  // and value == lowerbound
161
32
  BitVector lowerBound = bvValue;
162
163
  // upperbound
164
16
  BitVector upperBound = ((this->d_isSigned) ? (BitVector::mkMaxSigned(bvSize))
165
32
                                             : (BitVector::mkOnes(bvSize)));
166
167
16
  Kind LEOperator =
168
16
      ((d_isSigned) ? (kind::BITVECTOR_SLE) : (kind::BITVECTOR_ULE));
169
16
  Kind GTOperator =
170
16
      ((d_isSigned) ? (kind::BITVECTOR_SGT) : (kind::BITVECTOR_UGT));
171
172
  // the pivot value for binary search,
173
  // pivot = (lowerBound + upperBound) / 2
174
  // rounded towards -infinity
175
32
  BitVector pivot;
176
1280
  while ((d_isSigned && lowerBound.signedLessThan(upperBound))
177
792
         || (!d_isSigned && lowerBound.unsignedLessThan(upperBound)))
178
  {
179
504
    pivot = computeAverage(lowerBound, upperBound, d_isSigned);
180
181
504
    optChecker->push();
182
    // notice that we don't have boundary condition here
183
    // because lowerBound == pivot / lowerBound == upperBound + 1 is also
184
    // covered
185
    // pivot < target <= upperBound
186
504
    optChecker->assertFormula(
187
2016
        nm->mkNode(kind::AND,
188
1008
                   nm->mkNode(GTOperator, target, nm->mkConst(pivot)),
189
1008
                   nm->mkNode(LEOperator, target, nm->mkConst(upperBound))));
190
504
    intermediateSatResult = optChecker->checkSat();
191
504
    switch (intermediateSatResult.isSat())
192
    {
193
      case Result::SAT_UNKNOWN:
194
        optChecker->pop();
195
        return OptimizationResult(intermediateSatResult, value);
196
268
      case Result::SAT:
197
268
        lastSatResult = intermediateSatResult;
198
268
        value = optChecker->getValue(target);
199
268
        lowerBound = value.getConst<BitVector>();
200
268
        break;
201
236
      case Result::UNSAT:
202
236
        if (lowerBound == pivot)
203
        {
204
          // upperbound = lowerbound + 1
205
          // and lowerbound < target <= upperbound is UNSAT
206
          // return the lowerbound
207
          optChecker->pop();
208
          return OptimizationResult(lastSatResult, value);
209
        }
210
        else
211
        {
212
236
          upperBound = pivot;
213
        }
214
236
        break;
215
      default: Unreachable();
216
    }
217
504
    optChecker->pop();
218
  }
219
16
  return OptimizationResult(lastSatResult, value);
220
}
221
222
22746
}  // namespace cvc5::omt