1 |
|
/****************************************************************************** |
2 |
|
* Top contributors (to current version): |
3 |
|
* Yancheng Ou |
4 |
|
* |
5 |
|
* This file is part of the cvc5 project. |
6 |
|
* |
7 |
|
* Copyright (c) 2009-2021 by the authors listed in the file AUTHORS |
8 |
|
* in the top-level source directory and their institutional affiliations. |
9 |
|
* All rights reserved. See the file COPYING in the top-level source |
10 |
|
* directory for licensing information. |
11 |
|
* **************************************************************************** |
12 |
|
* |
13 |
|
* Optimizer for BitVector type. |
14 |
|
*/ |
15 |
|
|
16 |
|
#include "omt/bitvector_optimizer.h" |
17 |
|
|
18 |
|
#include "options/smt_options.h" |
19 |
|
#include "smt/smt_engine.h" |
20 |
|
#include "util/bitvector.h" |
21 |
|
|
22 |
|
using namespace cvc5::smt; |
23 |
|
namespace cvc5::omt { |
24 |
|
|
25 |
30 |
OMTOptimizerBitVector::OMTOptimizerBitVector(bool isSigned) |
26 |
30 |
: d_isSigned(isSigned) |
27 |
|
{ |
28 |
30 |
} |
29 |
|
|
30 |
632 |
BitVector OMTOptimizerBitVector::computeAverage(const BitVector& a, |
31 |
|
const BitVector& b, |
32 |
|
bool isSigned) |
33 |
|
{ |
34 |
|
// computes (a + b) / 2 without overflow |
35 |
|
// rounding towards -infinity: -1.5 --> -2, 1.5 --> 1 |
36 |
|
// average = (a / 2) + (b / 2) + (((a % 2) + (b % 2)) / 2) |
37 |
632 |
uint32_t aMod2 = static_cast<uint32_t>(a.isBitSet(0)); |
38 |
632 |
uint32_t bMod2 = static_cast<uint32_t>(b.isBitSet(0)); |
39 |
1264 |
BitVector aMod2PlusbMod2Div2(a.getSize(), (aMod2 + bMod2) / 2); |
40 |
1264 |
BitVector bv1 = BitVector::mkOne(a.getSize()); |
41 |
882 |
return (isSigned) ? ((a.arithRightShift(bv1) + b.arithRightShift(bv1) |
42 |
|
+ aMod2PlusbMod2Div2)) |
43 |
1014 |
: ((a.logicalRightShift(bv1) + b.logicalRightShift(bv1) |
44 |
1896 |
+ aMod2PlusbMod2Div2)); |
45 |
|
} |
46 |
|
|
47 |
14 |
OptimizationResult OMTOptimizerBitVector::minimize(SmtEngine* optChecker, |
48 |
|
TNode target) |
49 |
|
{ |
50 |
|
// the smt engine to which we send intermediate queries |
51 |
|
// for the binary search. |
52 |
14 |
NodeManager* nm = optChecker->getNodeManager(); |
53 |
28 |
Result intermediateSatResult = optChecker->checkSat(); |
54 |
|
// Model-value of objective (used in optimization loop) |
55 |
28 |
Node value; |
56 |
28 |
if (intermediateSatResult.isUnknown() |
57 |
14 |
|| intermediateSatResult.isSat() == Result::UNSAT) |
58 |
|
{ |
59 |
|
return OptimizationResult(intermediateSatResult, value); |
60 |
|
} |
61 |
|
// the last result that is SAT |
62 |
28 |
Result lastSatResult = intermediateSatResult; |
63 |
|
// value equals to upperBound |
64 |
14 |
value = optChecker->getValue(target); |
65 |
|
|
66 |
|
// this gets the bitvector! |
67 |
28 |
BitVector bvValue = value.getConst<BitVector>(); |
68 |
14 |
unsigned int bvSize = bvValue.getSize(); |
69 |
|
|
70 |
|
// lowerbound |
71 |
14 |
BitVector lowerBound = ((this->d_isSigned) ? (BitVector::mkMinSigned(bvSize)) |
72 |
28 |
: (BitVector::mkZero(bvSize))); |
73 |
|
// upperbound must be a satisfying value |
74 |
|
// and value == upperbound |
75 |
28 |
BitVector upperBound = bvValue; |
76 |
|
|
77 |
14 |
Kind LTOperator = |
78 |
14 |
((d_isSigned) ? (kind::BITVECTOR_SLT) : (kind::BITVECTOR_ULT)); |
79 |
14 |
Kind GEOperator = |
80 |
14 |
((d_isSigned) ? (kind::BITVECTOR_SGE) : (kind::BITVECTOR_UGE)); |
81 |
|
|
82 |
|
// the pivot value for binary search, |
83 |
|
// pivot = (lowerBound + upperBound) / 2 |
84 |
|
// rounded towards -infinity |
85 |
28 |
BitVector pivot; |
86 |
250 |
while ((d_isSigned && lowerBound.signedLessThan(upperBound)) |
87 |
258 |
|| (!d_isSigned && lowerBound.unsignedLessThan(upperBound))) |
88 |
|
{ |
89 |
128 |
pivot = computeAverage(lowerBound, upperBound, d_isSigned); |
90 |
128 |
optChecker->push(); |
91 |
128 |
if (lowerBound == pivot) |
92 |
|
{ |
93 |
12 |
optChecker->assertFormula( |
94 |
24 |
nm->mkNode(kind::EQUAL, target, nm->mkConst(lowerBound))); |
95 |
|
} |
96 |
|
else |
97 |
|
{ |
98 |
|
// lowerBound <= target < pivot |
99 |
116 |
optChecker->assertFormula( |
100 |
464 |
nm->mkNode(kind::AND, |
101 |
232 |
nm->mkNode(GEOperator, target, nm->mkConst(lowerBound)), |
102 |
232 |
nm->mkNode(LTOperator, target, nm->mkConst(pivot)))); |
103 |
|
} |
104 |
128 |
intermediateSatResult = optChecker->checkSat(); |
105 |
128 |
switch (intermediateSatResult.isSat()) |
106 |
|
{ |
107 |
|
case Result::SAT_UNKNOWN: |
108 |
|
optChecker->pop(); |
109 |
|
return OptimizationResult(intermediateSatResult, value); |
110 |
8 |
case Result::SAT: |
111 |
8 |
lastSatResult = intermediateSatResult; |
112 |
8 |
value = optChecker->getValue(target); |
113 |
8 |
upperBound = value.getConst<BitVector>(); |
114 |
8 |
break; |
115 |
120 |
case Result::UNSAT: |
116 |
120 |
if (lowerBound == pivot) |
117 |
|
{ |
118 |
|
// lowerBound == pivot ==> upperbound = lowerbound + 1 |
119 |
|
// and lowerbound <= target < upperbound is UNSAT |
120 |
|
// return the upperbound |
121 |
12 |
optChecker->pop(); |
122 |
12 |
return OptimizationResult(lastSatResult, value); |
123 |
|
} |
124 |
|
else |
125 |
|
{ |
126 |
108 |
lowerBound = pivot; |
127 |
|
} |
128 |
108 |
break; |
129 |
|
default: Unreachable(); |
130 |
|
} |
131 |
116 |
optChecker->pop(); |
132 |
|
} |
133 |
2 |
return OptimizationResult(lastSatResult, value); |
134 |
|
} |
135 |
|
|
136 |
16 |
OptimizationResult OMTOptimizerBitVector::maximize(SmtEngine* optChecker, |
137 |
|
TNode target) |
138 |
|
{ |
139 |
|
// the smt engine to which we send intermediate queries |
140 |
|
// for the binary search. |
141 |
16 |
NodeManager* nm = optChecker->getNodeManager(); |
142 |
32 |
Result intermediateSatResult = optChecker->checkSat(); |
143 |
|
// Model-value of objective (used in optimization loop) |
144 |
32 |
Node value; |
145 |
32 |
if (intermediateSatResult.isUnknown() |
146 |
16 |
|| intermediateSatResult.isSat() == Result::UNSAT) |
147 |
|
{ |
148 |
|
return OptimizationResult(intermediateSatResult, value); |
149 |
|
} |
150 |
|
// the last result that is SAT |
151 |
32 |
Result lastSatResult = intermediateSatResult; |
152 |
|
// value equals to upperBound |
153 |
16 |
value = optChecker->getValue(target); |
154 |
|
|
155 |
|
// this gets the bitvector! |
156 |
32 |
BitVector bvValue = value.getConst<BitVector>(); |
157 |
16 |
unsigned int bvSize = bvValue.getSize(); |
158 |
|
|
159 |
|
// lowerbound must be a satisfying value |
160 |
|
// and value == lowerbound |
161 |
32 |
BitVector lowerBound = bvValue; |
162 |
|
|
163 |
|
// upperbound |
164 |
16 |
BitVector upperBound = ((this->d_isSigned) ? (BitVector::mkMaxSigned(bvSize)) |
165 |
32 |
: (BitVector::mkOnes(bvSize))); |
166 |
|
|
167 |
16 |
Kind LEOperator = |
168 |
16 |
((d_isSigned) ? (kind::BITVECTOR_SLE) : (kind::BITVECTOR_ULE)); |
169 |
16 |
Kind GTOperator = |
170 |
16 |
((d_isSigned) ? (kind::BITVECTOR_SGT) : (kind::BITVECTOR_UGT)); |
171 |
|
|
172 |
|
// the pivot value for binary search, |
173 |
|
// pivot = (lowerBound + upperBound) / 2 |
174 |
|
// rounded towards -infinity |
175 |
32 |
BitVector pivot; |
176 |
1280 |
while ((d_isSigned && lowerBound.signedLessThan(upperBound)) |
177 |
792 |
|| (!d_isSigned && lowerBound.unsignedLessThan(upperBound))) |
178 |
|
{ |
179 |
504 |
pivot = computeAverage(lowerBound, upperBound, d_isSigned); |
180 |
|
|
181 |
504 |
optChecker->push(); |
182 |
|
// notice that we don't have boundary condition here |
183 |
|
// because lowerBound == pivot / lowerBound == upperBound + 1 is also |
184 |
|
// covered |
185 |
|
// pivot < target <= upperBound |
186 |
504 |
optChecker->assertFormula( |
187 |
2016 |
nm->mkNode(kind::AND, |
188 |
1008 |
nm->mkNode(GTOperator, target, nm->mkConst(pivot)), |
189 |
1008 |
nm->mkNode(LEOperator, target, nm->mkConst(upperBound)))); |
190 |
504 |
intermediateSatResult = optChecker->checkSat(); |
191 |
504 |
switch (intermediateSatResult.isSat()) |
192 |
|
{ |
193 |
|
case Result::SAT_UNKNOWN: |
194 |
|
optChecker->pop(); |
195 |
|
return OptimizationResult(intermediateSatResult, value); |
196 |
268 |
case Result::SAT: |
197 |
268 |
lastSatResult = intermediateSatResult; |
198 |
268 |
value = optChecker->getValue(target); |
199 |
268 |
lowerBound = value.getConst<BitVector>(); |
200 |
268 |
break; |
201 |
236 |
case Result::UNSAT: |
202 |
236 |
if (lowerBound == pivot) |
203 |
|
{ |
204 |
|
// upperbound = lowerbound + 1 |
205 |
|
// and lowerbound < target <= upperbound is UNSAT |
206 |
|
// return the lowerbound |
207 |
|
optChecker->pop(); |
208 |
|
return OptimizationResult(lastSatResult, value); |
209 |
|
} |
210 |
|
else |
211 |
|
{ |
212 |
236 |
upperBound = pivot; |
213 |
|
} |
214 |
236 |
break; |
215 |
|
default: Unreachable(); |
216 |
|
} |
217 |
504 |
optChecker->pop(); |
218 |
|
} |
219 |
16 |
return OptimizationResult(lastSatResult, value); |
220 |
|
} |
221 |
|
|
222 |
29577 |
} // namespace cvc5::omt |