GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/fp/theory_fp_rewriter.cpp Lines: 594 763 77.9 %
Date: 2021-08-14 Branches: 1032 3979 25.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Martin Brain, Andres Noetzli, Aina Niemetz
4
 * Copyright (c) 2013  University of Oxford
5
 *
6
 * This file is part of the cvc5 project.
7
 *
8
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
9
 * in the top-level source directory and their institutional affiliations.
10
 * All rights reserved.  See the file COPYING in the top-level source
11
 * directory for licensing information.
12
 * ****************************************************************************
13
 *
14
 * Rewrite rules for floating point theories.
15
 *
16
 * \todo - Single argument constant propagate / simplify
17
 *       - Push negations through arithmetic operators (include max and min?
18
 *         maybe not due to +0/-0)
19
 *       - classifications to normal tests (maybe)
20
 *       - (= x (fp.neg x)) --> (isNaN x)
21
 *       - (fp.eq x (fp.neg x)) --> (isZero x) (previous and reorganise
22
 *             should be sufficient)
23
 *       - (fp.eq x const) --> various = depending on const
24
 *       - (fp.isPositive (fp.neg x)) --> (fp.isNegative x)
25
 *       - (fp.isNegative (fp.neg x)) --> (fp.isPositive x)
26
 *       - (fp.isPositive (fp.abs x)) --> (not (isNaN x))
27
 *       - (fp.isNegative (fp.abs x)) --> false
28
 *       - A -> castA --> A
29
 *       - A -> castB -> castC  -->  A -> castC if A <= B <= C
30
 *       - A -> castB -> castA  -->  A if A <= B
31
 *       - promotion converts can ignore rounding mode
32
 *       - Samuel Figueroa's results (on when double rounding is innocuous)
33
 */
34
35
#include "theory/fp/theory_fp_rewriter.h"
36
37
#include <algorithm>
38
39
#include "base/check.h"
40
#include "theory/bv/theory_bv_utils.h"
41
#include "theory/fp/fp_converter.h"
42
#include "util/floatingpoint.h"
43
44
namespace cvc5 {
45
namespace theory {
46
namespace fp {
47
48
namespace rewrite {
49
  /** Rewrite rules **/
50
  template <RewriteFunction first, RewriteFunction second>
51
208
  RewriteResponse then (TNode node, bool isPreRewrite) {
52
416
    RewriteResponse result(first(node, isPreRewrite));
53
54
208
    if (result.d_status == REWRITE_DONE)
55
    {
56
208
      return second(result.d_node, isPreRewrite);
57
    }
58
    else
59
    {
60
      return result;
61
    }
62
  }
63
64
  RewriteResponse notFP(TNode node, bool isPreRewrite)
65
  {
66
    Unreachable() << "non floating-point kind (" << node.getKind()
67
                  << ") in floating point rewrite?";
68
  }
69
70
11149
  RewriteResponse identity(TNode node, bool isPreRewrite)
71
  {
72
11149
    return RewriteResponse(REWRITE_DONE, node);
73
  }
74
75
  RewriteResponse type(TNode node, bool isPreRewrite)
76
  {
77
    Unreachable() << "sort kind (" << node.getKind()
78
                  << ") found in expression?";
79
  }
80
81
70
  RewriteResponse removeToFPGeneric(TNode node, bool isPreRewrite)
82
  {
83
70
    Assert(!isPreRewrite);
84
70
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_GENERIC);
85
86
    FloatingPointToFPGeneric info =
87
70
        node.getOperator().getConst<FloatingPointToFPGeneric>();
88
89
70
    uint32_t children = node.getNumChildren();
90
91
140
    Node op;
92
70
    NodeManager* nm = NodeManager::currentNM();
93
94
70
    if (children == 1)
95
    {
96
30
      op = nm->mkConst(FloatingPointToFPIEEEBitVector(info));
97
30
      return RewriteResponse(REWRITE_AGAIN, nm->mkNode(op, node[0]));
98
    }
99
40
    Assert(children == 2);
100
40
    Assert(node[0].getType().isRoundingMode());
101
102
80
    TypeNode t = node[1].getType();
103
104
40
    if (t.isFloatingPoint())
105
    {
106
8
      op = nm->mkConst(FloatingPointToFPFloatingPoint(info));
107
    }
108
32
    else if (t.isReal())
109
    {
110
16
      op = nm->mkConst(FloatingPointToFPReal(info));
111
    }
112
    else
113
    {
114
16
      Assert(t.isBitVector());
115
16
      op = nm->mkConst(FloatingPointToFPSignedBitVector(info));
116
    }
117
118
40
    return RewriteResponse(REWRITE_AGAIN, nm->mkNode(op, node[0], node[1]));
119
  }
120
121
433
  RewriteResponse removeDoubleNegation(TNode node, bool isPreRewrite)
122
  {
123
433
    Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
124
433
    if (node[0].getKind() == kind::FLOATINGPOINT_NEG) {
125
4
      return RewriteResponse(REWRITE_AGAIN, node[0][0]);
126
    }
127
128
429
    return RewriteResponse(REWRITE_DONE, node);
129
  }
130
131
110
  RewriteResponse compactAbs(TNode node, bool isPreRewrite)
132
  {
133
110
    Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
134
330
    if (node[0].getKind() == kind::FLOATINGPOINT_NEG
135
330
        || node[0].getKind() == kind::FLOATINGPOINT_ABS)
136
    {
137
      Node ret =
138
16
          NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_ABS, node[0][0]);
139
8
      return RewriteResponse(REWRITE_AGAIN, ret);
140
    }
141
142
102
    return RewriteResponse(REWRITE_DONE, node);
143
  }
144
145
33
  RewriteResponse convertSubtractionToAddition(TNode node, bool isPreRewrite)
146
  {
147
33
    Assert(node.getKind() == kind::FLOATINGPOINT_SUB);
148
66
    Node negation = NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_NEG,node[2]);
149
    Node addition = NodeManager::currentNM()->mkNode(
150
66
        kind::FLOATINGPOINT_ADD, node[0], node[1], negation);
151
66
    return RewriteResponse(REWRITE_DONE, addition);
152
  }
153
154
208
  RewriteResponse breakChain (TNode node, bool isPreRewrite) {
155
208
    Assert(isPreRewrite);  // Should be run first
156
157
208
    Kind k = node.getKind();
158
208
    Assert(k == kind::FLOATINGPOINT_EQ || k == kind::FLOATINGPOINT_GEQ
159
           || k == kind::FLOATINGPOINT_LEQ || k == kind::FLOATINGPOINT_GT
160
           || k == kind::FLOATINGPOINT_LT);
161
162
208
    size_t children = node.getNumChildren();
163
208
    if (children > 2) {
164
      NodeBuilder conjunction(kind::AND);
165
166
      for (size_t i = 0; i < children - 1; ++i) {
167
	for (size_t j = i + 1; j < children; ++j) {
168
	  conjunction << NodeManager::currentNM()->mkNode(k, node[i], node[j]);
169
	}
170
      }
171
      return RewriteResponse(REWRITE_AGAIN_FULL, conjunction);
172
173
    } else {
174
208
      return RewriteResponse(REWRITE_DONE, node);
175
    }
176
  }
177
178
179
  /* Implies (fp.eq x x) --> (not (isNaN x))
180
   */
181
182
12
  RewriteResponse ieeeEqToEq(TNode node, bool isPreRewrite)
183
  {
184
12
    Assert(node.getKind() == kind::FLOATINGPOINT_EQ);
185
12
    NodeManager *nm = NodeManager::currentNM();
186
187
    return RewriteResponse(REWRITE_DONE,
188
48
			   nm->mkNode(kind::AND,
189
48
				      nm->mkNode(kind::AND,
190
24
						 nm->mkNode(kind::NOT, nm->mkNode(kind::FLOATINGPOINT_ISNAN, node[0])),
191
24
						 nm->mkNode(kind::NOT, nm->mkNode(kind::FLOATINGPOINT_ISNAN, node[1]))),
192
48
				      nm->mkNode(kind::OR,
193
24
						 nm->mkNode(kind::EQUAL, node[0], node[1]),
194
48
						 nm->mkNode(kind::AND,
195
24
							    nm->mkNode(kind::FLOATINGPOINT_ISZ, node[0]),
196
36
							    nm->mkNode(kind::FLOATINGPOINT_ISZ, node[1])))));
197
  }
198
199
17
  RewriteResponse geqToleq(TNode node, bool isPreRewrite)
200
  {
201
17
    Assert(node.getKind() == kind::FLOATINGPOINT_GEQ);
202
17
    return RewriteResponse(REWRITE_DONE,NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_LEQ,node[1],node[0]));
203
  }
204
205
4
  RewriteResponse gtTolt(TNode node, bool isPreRewrite)
206
  {
207
4
    Assert(node.getKind() == kind::FLOATINGPOINT_GT);
208
4
    return RewriteResponse(REWRITE_DONE,NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_LT,node[1],node[0]));
209
  }
210
211
  RewriteResponse removed(TNode node, bool isPreRewrite)
212
  {
213
    Unreachable() << "kind (" << node.getKind()
214
                  << ") should have been removed?";
215
  }
216
217
  RewriteResponse variable(TNode node, bool isPreRewrite)
218
  {
219
    // We should only get floating point and rounding mode variables to rewrite.
220
    TypeNode tn = node.getType(true);
221
    Assert(tn.isFloatingPoint() || tn.isRoundingMode());
222
223
    // Not that we do anything with them...
224
    return RewriteResponse(REWRITE_DONE, node);
225
  }
226
227
2315
  RewriteResponse equal (TNode node, bool isPreRewrite) {
228
2315
    Assert(node.getKind() == kind::EQUAL);
229
230
    // We should only get equalities of floating point or rounding mode types.
231
4630
    TypeNode tn = node[0].getType(true);
232
233
2315
    Assert(tn.isFloatingPoint() || tn.isRoundingMode());
234
2315
    Assert(tn
235
           == node[1].getType(true));  // Should be ensured by the typing rules
236
237
2315
    if (node[0] == node[1]) {
238
65
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
239
2250
    } else if (!isPreRewrite && (node[0] > node[1])) {
240
      Node normal =
241
390
          NodeManager::currentNM()->mkNode(kind::EQUAL, node[1], node[0]);
242
195
      return RewriteResponse(REWRITE_DONE, normal);
243
    } else {
244
2055
      return RewriteResponse(REWRITE_DONE, node);
245
    }
246
  }
247
248
249
  // Note these cannot be assumed to be symmetric for +0/-0, thus no symmetry reorder
250
360
  RewriteResponse compactMinMax (TNode node, bool isPreRewrite) {
251
#ifdef CVC5_ASSERTIONS
252
360
    Kind k = node.getKind();
253
360
    Assert((k == kind::FLOATINGPOINT_MIN) || (k == kind::FLOATINGPOINT_MAX)
254
           || (k == kind::FLOATINGPOINT_MIN_TOTAL)
255
           || (k == kind::FLOATINGPOINT_MAX_TOTAL));
256
#endif
257
360
    if (node[0] == node[1]) {
258
23
      return RewriteResponse(REWRITE_AGAIN, node[0]);
259
    } else {
260
337
      return RewriteResponse(REWRITE_DONE, node);
261
    }
262
  }
263
264
265
  RewriteResponse reorderFPEquality (TNode node, bool isPreRewrite) {
266
    Assert(node.getKind() == kind::FLOATINGPOINT_EQ);
267
    Assert(!isPreRewrite);  // Likely redundant in pre-rewrite
268
269
    if (node[0] > node[1]) {
270
      Node normal = NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_EQ,node[1],node[0]);
271
      return RewriteResponse(REWRITE_DONE, normal);
272
    } else {
273
      return RewriteResponse(REWRITE_DONE, node);
274
    }
275
  }
276
277
1178
  RewriteResponse reorderBinaryOperation (TNode node, bool isPreRewrite) {
278
1178
    Kind k = node.getKind();
279
1178
    Assert((k == kind::FLOATINGPOINT_ADD) || (k == kind::FLOATINGPOINT_MULT));
280
1178
    Assert(!isPreRewrite);  // Likely redundant in pre-rewrite
281
282
1178
    if (node[1] > node[2]) {
283
488
      Node normal = NodeManager::currentNM()->mkNode(k,node[0],node[2],node[1]);
284
244
      return RewriteResponse(REWRITE_DONE, normal);
285
    } else {
286
934
      return RewriteResponse(REWRITE_DONE, node);
287
    }
288
  }
289
290
  RewriteResponse reorderFMA (TNode node, bool isPreRewrite) {
291
    Assert(node.getKind() == kind::FLOATINGPOINT_FMA);
292
    Assert(!isPreRewrite);  // Likely redundant in pre-rewrite
293
294
    if (node[1] > node[2]) {
295
      Node normal = NodeManager::currentNM()->mkNode(
296
          kind::FLOATINGPOINT_FMA, {node[0], node[2], node[1], node[3]});
297
      return RewriteResponse(REWRITE_DONE, normal);
298
    } else {
299
      return RewriteResponse(REWRITE_DONE, node);
300
    }
301
  }
302
303
496
  RewriteResponse removeSignOperations (TNode node, bool isPreRewrite) {
304
496
    Assert(node.getKind() == kind::FLOATINGPOINT_ISN
305
           || node.getKind() == kind::FLOATINGPOINT_ISSN
306
           || node.getKind() == kind::FLOATINGPOINT_ISZ
307
           || node.getKind() == kind::FLOATINGPOINT_ISINF
308
           || node.getKind() == kind::FLOATINGPOINT_ISNAN);
309
496
    Assert(node.getNumChildren() == 1);
310
311
496
    Kind childKind(node[0].getKind());
312
313
496
    if ((childKind == kind::FLOATINGPOINT_NEG) ||
314
	(childKind == kind::FLOATINGPOINT_ABS)) {
315
316
      Node rewritten = NodeManager::currentNM()->mkNode(node.getKind(),node[0][0]);
317
      return RewriteResponse(REWRITE_AGAIN_FULL, rewritten);
318
    } else {
319
496
      return RewriteResponse(REWRITE_DONE, node);
320
    }
321
  }
322
323
189
  RewriteResponse compactRemainder (TNode node, bool isPreRewrite) {
324
189
    Assert(node.getKind() == kind::FLOATINGPOINT_REM);
325
189
    Assert(!isPreRewrite);  // status assumes parts have been rewritten
326
327
378
    Node working = node;
328
329
    // (fp.rem (fp.rem X Y) Y) == (fp.rem X Y)
330
378
    if (working[0].getKind() == kind::FLOATINGPOINT_REM && // short-cut matters!
331
189
	working[0][1] == working[1]) {
332
      working = working[0];
333
    }
334
335
    // Sign of the RHS does not matter
336
756
    if (working[1].getKind() == kind::FLOATINGPOINT_NEG ||
337
567
	working[1].getKind() == kind::FLOATINGPOINT_ABS) {
338
      working[1] = working[1][0];
339
    }
340
341
    // Lift negation out of the LHS so it can be cancelled out
342
189
    if (working[0].getKind() == kind::FLOATINGPOINT_NEG) {
343
      NodeManager * nm = NodeManager::currentNM();
344
      working = nm->mkNode(
345
          kind::FLOATINGPOINT_NEG,
346
          nm->mkNode(kind::FLOATINGPOINT_REM, working[0][0], working[1]));
347
      // in contrast to other rewrites here, this requires rewrite again full
348
      return RewriteResponse(REWRITE_AGAIN_FULL, working);
349
    }
350
351
189
    return RewriteResponse(REWRITE_DONE, working);
352
  }
353
354
410
  RewriteResponse leqId(TNode node, bool isPreRewrite)
355
  {
356
410
    Assert(node.getKind() == kind::FLOATINGPOINT_LEQ);
357
358
410
    if (node[0] == node[1])
359
    {
360
      NodeManager *nm = NodeManager::currentNM();
361
      return RewriteResponse(
362
          isPreRewrite ? REWRITE_DONE : REWRITE_AGAIN_FULL,
363
          nm->mkNode(kind::NOT,
364
                     nm->mkNode(kind::FLOATINGPOINT_ISNAN, node[0])));
365
    }
366
410
    return RewriteResponse(REWRITE_DONE, node);
367
  }
368
369
135
  RewriteResponse ltId(TNode node, bool isPreRewrite)
370
  {
371
135
    Assert(node.getKind() == kind::FLOATINGPOINT_LT);
372
373
135
    if (node[0] == node[1])
374
    {
375
      return RewriteResponse(REWRITE_DONE,
376
                             NodeManager::currentNM()->mkConst(false));
377
    }
378
135
    return RewriteResponse(REWRITE_DONE, node);
379
  }
380
381
32
  RewriteResponse toFPSignedBV(TNode node, bool isPreRewrite)
382
  {
383
32
    Assert(!isPreRewrite);
384
32
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR);
385
386
    /* symFPU does not allow conversions from signed bit-vector of size 1 */
387
32
    if (node[1].getType().getBitVectorSize() == 1)
388
    {
389
24
      NodeManager* nm = NodeManager::currentNM();
390
48
      Node op = nm->mkConst(FloatingPointToFPUnsignedBitVector(
391
96
          node.getOperator().getConst<FloatingPointToFPSignedBitVector>()));
392
48
      Node fromubv = nm->mkNode(op, node[0], node[1]);
393
      return RewriteResponse(
394
          REWRITE_AGAIN_FULL,
395
144
          nm->mkNode(kind::ITE,
396
48
                     node[1].eqNode(bv::utils::mkOne(1)),
397
48
                     nm->mkNode(kind::FLOATINGPOINT_NEG, fromubv),
398
24
                     fromubv));
399
    }
400
8
    return RewriteResponse(REWRITE_DONE, node);
401
  }
402
403
  };  // namespace rewrite
404
405
namespace constantFold {
406
407
26
/**
 * Constant-folds (fp sign exponent significand), whose three children are
 * bit-vector constants, into a single floating-point literal.
 */
RewriteResponse fpLiteral(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_FP);

  const BitVector& exponent = node[1].getConst<BitVector>();
  const BitVector& significand = node[2].getConst<BitVector>();

  // Bit pattern = sign ++ exponent ++ significand.
  BitVector ieeeBits(node[0].getConst<BitVector>());
  ieeeBits = ieeeBits.concat(exponent).concat(significand);

  // The stored significand excludes the hidden bit, hence the +1 on its width.
  const uint32_t exponentWidth = exponent.getSize();
  const uint32_t significandWidth = significand.getSize() + 1;

  NodeManager* nm = NodeManager::currentNM();
  return RewriteResponse(
      REWRITE_DONE,
      nm->mkConst(FloatingPoint(exponentWidth, significandWidth, ieeeBits)));
}
423
424
13
/** Constant-folds fp.abs applied to a floating-point literal. */
RewriteResponse abs(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
  Assert(node.getNumChildren() == 1);

  const FloatingPoint& operand = node[0].getConst<FloatingPoint>();
  NodeManager* nm = NodeManager::currentNM();
  return RewriteResponse(REWRITE_DONE, nm->mkConst(operand.absolute()));
}
433
434
129
/** Constant-folds fp.neg applied to a floating-point literal. */
RewriteResponse neg(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
  Assert(node.getNumChildren() == 1);

  const FloatingPoint& operand = node[0].getConst<FloatingPoint>();
  NodeManager* nm = NodeManager::currentNM();
  return RewriteResponse(REWRITE_DONE, nm->mkConst(operand.negate()));
}
443
444
484
/** Constant-folds (fp.add rm x y) where all three children are literals. */
RewriteResponse add(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_ADD);
  Assert(node.getNumChildren() == 3);

  RoundingMode mode(node[0].getConst<RoundingMode>());
  FloatingPoint lhs(node[1].getConst<FloatingPoint>());
  FloatingPoint rhs(node[2].getConst<FloatingPoint>());
  Assert(lhs.getSize() == rhs.getSize());

  FloatingPoint sum(lhs.add(mode, rhs));
  NodeManager* nm = NodeManager::currentNM();
  return RewriteResponse(REWRITE_DONE, nm->mkConst(sum));
}
458
459
145
/** Constant-folds (fp.mul rm x y) where all three children are literals. */
RewriteResponse mult(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_MULT);
  Assert(node.getNumChildren() == 3);

  RoundingMode mode(node[0].getConst<RoundingMode>());
  FloatingPoint lhs(node[1].getConst<FloatingPoint>());
  FloatingPoint rhs(node[2].getConst<FloatingPoint>());
  Assert(lhs.getSize() == rhs.getSize());

  FloatingPoint product(lhs.mult(mode, rhs));
  NodeManager* nm = NodeManager::currentNM();
  return RewriteResponse(REWRITE_DONE, nm->mkConst(product));
}
473
474
/** Constant-folds (fp.fma rm x y z) where all four children are literals. */
RewriteResponse fma(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_FMA);
  Assert(node.getNumChildren() == 4);

  RoundingMode mode(node[0].getConst<RoundingMode>());
  FloatingPoint first(node[1].getConst<FloatingPoint>());
  FloatingPoint second(node[2].getConst<FloatingPoint>());
  FloatingPoint addend(node[3].getConst<FloatingPoint>());

  // All operands must share one format.
  Assert(first.getSize() == second.getSize());
  Assert(first.getSize() == addend.getSize());

  FloatingPoint fused(first.fma(mode, second, addend));
  NodeManager* nm = NodeManager::currentNM();
  return RewriteResponse(REWRITE_DONE, nm->mkConst(fused));
}
491
492
196
/** Constant-folds (fp.div rm x y) where all three children are literals. */
RewriteResponse div(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_DIV);
  Assert(node.getNumChildren() == 3);

  RoundingMode mode(node[0].getConst<RoundingMode>());
  FloatingPoint dividend(node[1].getConst<FloatingPoint>());
  FloatingPoint divisor(node[2].getConst<FloatingPoint>());
  Assert(dividend.getSize() == divisor.getSize());

  FloatingPoint quotient(dividend.div(mode, divisor));
  NodeManager* nm = NodeManager::currentNM();
  return RewriteResponse(REWRITE_DONE, nm->mkConst(quotient));
}
506
507
30
/** Constant-folds (fp.sqrt rm x) where both children are literals. */
RewriteResponse sqrt(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_SQRT);
  Assert(node.getNumChildren() == 2);

  RoundingMode mode(node[0].getConst<RoundingMode>());
  FloatingPoint radicand(node[1].getConst<FloatingPoint>());

  FloatingPoint root(radicand.sqrt(mode));
  NodeManager* nm = NodeManager::currentNM();
  return RewriteResponse(REWRITE_DONE, nm->mkConst(root));
}
518
519
140
/** Constant-folds (fp.roundToIntegral rm x) where both children are literals. */
RewriteResponse rti(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_RTI);
  Assert(node.getNumChildren() == 2);

  RoundingMode mode(node[0].getConst<RoundingMode>());
  FloatingPoint operand(node[1].getConst<FloatingPoint>());

  FloatingPoint rounded(operand.rti(mode));
  NodeManager* nm = NodeManager::currentNM();
  return RewriteResponse(REWRITE_DONE, nm->mkConst(rounded));
}
530
531
109
/** Constant-folds (fp.rem x y) where both children are literals. */
RewriteResponse rem(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_REM);
  Assert(node.getNumChildren() == 2);

  FloatingPoint dividend(node[0].getConst<FloatingPoint>());
  FloatingPoint divisor(node[1].getConst<FloatingPoint>());
  Assert(dividend.getSize() == divisor.getSize());

  FloatingPoint result(dividend.rem(divisor));
  NodeManager* nm = NodeManager::currentNM();
  return RewriteResponse(REWRITE_DONE, nm->mkConst(result));
}
544
545
/**
 * Constant-folds (fp.min x y). When the partial result is not defined
 * (PartialFloatingPoint::second is false) the node is returned unchanged.
 */
RewriteResponse min(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_MIN);
  Assert(node.getNumChildren() == 2);

  FloatingPoint lhs(node[0].getConst<FloatingPoint>());
  FloatingPoint rhs(node[1].getConst<FloatingPoint>());
  Assert(lhs.getSize() == rhs.getSize());

  FloatingPoint::PartialFloatingPoint partial(lhs.min(rhs));
  if (!partial.second)
  {
    // Can't constant fold the underspecified case
    return RewriteResponse(REWRITE_DONE, node);
  }
  return RewriteResponse(REWRITE_DONE,
                         NodeManager::currentNM()->mkConst(partial.first));
}
568
569
4
/**
 * Constant-folds (fp.max x y). When the partial result is not defined
 * (PartialFloatingPoint::second is false) the node is returned unchanged.
 */
RewriteResponse max(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_MAX);
  Assert(node.getNumChildren() == 2);

  FloatingPoint lhs(node[0].getConst<FloatingPoint>());
  FloatingPoint rhs(node[1].getConst<FloatingPoint>());
  Assert(lhs.getSize() == rhs.getSize());

  FloatingPoint::PartialFloatingPoint partial(lhs.max(rhs));
  if (!partial.second)
  {
    // Can't constant fold the underspecified case
    return RewriteResponse(REWRITE_DONE, node);
  }
  return RewriteResponse(REWRITE_DONE,
                         NodeManager::currentNM()->mkConst(partial.first));
}
592
593
/**
 * Constant-folds FLOATINGPOINT_MIN_TOTAL (x, y, b).
 * If b is a constant bit-vector, its bit 0 is passed to
 * FloatingPoint::minTotal. Otherwise the partial min is used, and the node is
 * left unchanged when that result is undefined.
 */
RewriteResponse minTotal(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_MIN_TOTAL);
  Assert(node.getNumChildren() == 3);

  FloatingPoint lhs(node[0].getConst<FloatingPoint>());
  FloatingPoint rhs(node[1].getConst<FloatingPoint>());
  Assert(lhs.getSize() == rhs.getSize());

  NodeManager* nm = NodeManager::currentNM();

  // Can be called with the third argument non-constant
  if (node[2].getMetaKind() != kind::metakind::CONSTANT)
  {
    FloatingPoint::PartialFloatingPoint partial(lhs.min(rhs));
    if (!partial.second)
    {
      // Can't constant fold the underspecified case
      return RewriteResponse(REWRITE_DONE, node);
    }
    return RewriteResponse(REWRITE_DONE, nm->mkConst(partial.first));
  }

  BitVector selector(node[2].getConst<BitVector>());
  FloatingPoint total(lhs.minTotal(rhs, selector.isBitSet(0)));
  return RewriteResponse(REWRITE_DONE, nm->mkConst(total));
}
628
629
104
/**
 * Constant-folds FLOATINGPOINT_MAX_TOTAL (x, y, b).
 * If b is a constant bit-vector, its bit 0 is passed to
 * FloatingPoint::maxTotal. Otherwise the partial max is used, and the node is
 * left unchanged when that result is undefined.
 */
RewriteResponse maxTotal(TNode node, bool isPreRewrite)
{
  Assert(node.getKind() == kind::FLOATINGPOINT_MAX_TOTAL);
  Assert(node.getNumChildren() == 3);

  FloatingPoint lhs(node[0].getConst<FloatingPoint>());
  FloatingPoint rhs(node[1].getConst<FloatingPoint>());
  Assert(lhs.getSize() == rhs.getSize());

  NodeManager* nm = NodeManager::currentNM();

  // Can be called with the third argument non-constant
  if (node[2].getMetaKind() != kind::metakind::CONSTANT)
  {
    FloatingPoint::PartialFloatingPoint partial(lhs.max(rhs));
    if (!partial.second)
    {
      // Can't constant fold the underspecified case
      return RewriteResponse(REWRITE_DONE, node);
    }
    return RewriteResponse(REWRITE_DONE, nm->mkConst(partial.first));
  }

  BitVector selector(node[2].getConst<BitVector>());
  FloatingPoint total(lhs.maxTotal(rhs, selector.isBitSet(0)));
  return RewriteResponse(REWRITE_DONE, nm->mkConst(total));
}
664
665
26
  /**
   * Constant-folds SMT-LIB equality between two floating-point literals or
   * between two rounding-mode literals; any other type is unreachable.
   */
  RewriteResponse equal(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::EQUAL);

    // We should only get equalities of floating point or rounding mode types.
    TypeNode tn = node[0].getType(true);
    NodeManager* nm = NodeManager::currentNM();

    if (tn.isRoundingMode())
    {
      RoundingMode lhsMode(node[0].getConst<RoundingMode>());
      RoundingMode rhsMode(node[1].getConst<RoundingMode>());
      const bool same = (lhsMode == rhsMode);
      return RewriteResponse(REWRITE_DONE, nm->mkConst(same));
    }

    if (tn.isFloatingPoint())
    {
      FloatingPoint lhsValue(node[0].getConst<FloatingPoint>());
      FloatingPoint rhsValue(node[1].getConst<FloatingPoint>());
      Assert(lhsValue.getSize() == rhsValue.getSize());
      const bool same = (lhsValue == rhsValue);
      return RewriteResponse(REWRITE_DONE, nm->mkConst(same));
    }

    Unreachable() << "Equality of unknown type";
  }
688
689
10
  /** Constant-folds (fp.leq x y) for two floating-point literals. */
  RewriteResponse leq(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_LEQ);
    Assert(node.getNumChildren() == 2);

    FloatingPoint lhs(node[0].getConst<FloatingPoint>());
    FloatingPoint rhs(node[1].getConst<FloatingPoint>());
    Assert(lhs.getSize() == rhs.getSize());

    const bool holds = (lhs <= rhs);
    NodeManager* nm = NodeManager::currentNM();
    return RewriteResponse(REWRITE_DONE, nm->mkConst(holds));
  }
701
702
2
  /** Constant-folds (fp.lt x y) for two floating-point literals. */
  RewriteResponse lt(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_LT);
    Assert(node.getNumChildren() == 2);

    FloatingPoint lhs(node[0].getConst<FloatingPoint>());
    FloatingPoint rhs(node[1].getConst<FloatingPoint>());
    Assert(lhs.getSize() == rhs.getSize());

    const bool holds = (lhs < rhs);
    NodeManager* nm = NodeManager::currentNM();
    return RewriteResponse(REWRITE_DONE, nm->mkConst(holds));
  }
714
715
7
  /** Constant-folds fp.isNormal on a floating-point literal. */
  RewriteResponse isNormal(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_ISN);
    Assert(node.getNumChildren() == 1);

    const FloatingPoint& value = node[0].getConst<FloatingPoint>();
    const bool result = value.isNormal();
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(result));
  }
722
723
7
  /** Constant-folds fp.isSubnormal on a floating-point literal. */
  RewriteResponse isSubnormal(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_ISSN);
    Assert(node.getNumChildren() == 1);

    const FloatingPoint& value = node[0].getConst<FloatingPoint>();
    const bool result = value.isSubnormal();
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(result));
  }
730
731
136
  /** Constant-folds fp.isZero on a floating-point literal. */
  RewriteResponse isZero(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_ISZ);
    Assert(node.getNumChildren() == 1);

    const FloatingPoint& value = node[0].getConst<FloatingPoint>();
    const bool result = value.isZero();
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(result));
  }
738
739
5
  /** Constant-folds fp.isInfinite on a floating-point literal. */
  RewriteResponse isInfinite(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_ISINF);
    Assert(node.getNumChildren() == 1);

    const FloatingPoint& value = node[0].getConst<FloatingPoint>();
    const bool result = value.isInfinite();
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(result));
  }
746
747
23
  /** Constant-folds fp.isNaN on a floating-point literal. */
  RewriteResponse isNaN(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_ISNAN);
    Assert(node.getNumChildren() == 1);

    const FloatingPoint& value = node[0].getConst<FloatingPoint>();
    const bool result = value.isNaN();
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(result));
  }
754
755
8
  /** Constant-folds fp.isNegative on a floating-point literal. */
  RewriteResponse isNegative(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_ISNEG);
    Assert(node.getNumChildren() == 1);

    const FloatingPoint& value = node[0].getConst<FloatingPoint>();
    const bool result = value.isNegative();
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(result));
  }
762
763
5
  /** Constant-folds fp.isPositive on a floating-point literal. */
  RewriteResponse isPositive(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_ISPOS);
    Assert(node.getNumChildren() == 1);

    const FloatingPoint& value = node[0].getConst<FloatingPoint>();
    const bool result = value.isPositive();
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(result));
  }
770
771
201
  /**
   * Constant-folds a reinterpretation of a bit-vector literal as a
   * floating-point value, using the exponent/significand widths stored in
   * the FloatingPointToFPIEEEBitVector operator.
   */
  RewriteResponse convertFromIEEEBitVectorLiteral(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR);

    TNode op = node.getOperator();
    const FloatingPointToFPIEEEBitVector& conversion =
        op.getConst<FloatingPointToFPIEEEBitVector>();
    const BitVector& bits = node[0].getConst<BitVector>();

    const uint32_t exponentWidth = conversion.getSize().exponentWidth();
    const uint32_t significandWidth = conversion.getSize().significandWidth();

    return RewriteResponse(
        REWRITE_DONE,
        NodeManager::currentNM()->mkConst(
            FloatingPoint(exponentWidth, significandWidth, bits)));
  }
786
787
19
  /**
   * Constant-folds a floating-point to floating-point conversion of a literal
   * into the format stored in the FloatingPointToFPFloatingPoint operator.
   */
  RewriteResponse constantConvert(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT);
    Assert(node.getNumChildren() == 2);

    FloatingPointToFPFloatingPoint target =
        node.getOperator().getConst<FloatingPointToFPFloatingPoint>();
    RoundingMode mode(node[0].getConst<RoundingMode>());
    FloatingPoint source(node[1].getConst<FloatingPoint>());

    FloatingPoint converted(source.convert(target.getSize(), mode));
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(converted));
  }
800
801
38
  /**
   * Constant-folds a conversion of a rational literal under a rounding mode
   * into the format stored in the FloatingPointToFPReal operator.
   */
  RewriteResponse convertFromRealLiteral(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL);

    TNode op = node.getOperator();
    const FloatingPointSize& format =
        op.getConst<FloatingPointToFPReal>().getSize();

    RoundingMode mode(node[0].getConst<RoundingMode>());
    Rational value(node[1].getConst<Rational>());

    return RewriteResponse(
        REWRITE_DONE,
        NodeManager::currentNM()->mkConst(FloatingPoint(format, mode, value)));
  }
818
819
8
  /**
   * Constant-folds a conversion from a signed bit-vector literal.
   * The 1-bit case is converted as unsigned and negated when its bit is set.
   */
  RewriteResponse convertFromSBV(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR);

    TNode op = node.getOperator();
    const FloatingPointSize& format =
        op.getConst<FloatingPointToFPSignedBitVector>().getSize();

    RoundingMode mode(node[0].getConst<RoundingMode>());
    BitVector bits(node[1].getConst<BitVector>());

    NodeManager* nm = NodeManager::currentNM();

    if (bits.getSize() != 1)
    {
      return RewriteResponse(
          REWRITE_DONE, nm->mkConst(FloatingPoint(format, mode, bits, true)));
    }

    /* symFPU does not allow conversions from signed bit-vector of size 1 */
    FloatingPoint asUnsigned(format, mode, bits, false);
    FloatingPoint folded = bits.isBitSet(0) ? asUnsigned.negate() : asUnsigned;
    return RewriteResponse(REWRITE_DONE, nm->mkConst(folded));
  }
846
847
60
  /**
   * Constant-folds a conversion from an unsigned bit-vector literal into the
   * format stored in the FloatingPointToFPUnsignedBitVector operator.
   */
  RewriteResponse convertFromUBV(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR);

    TNode op = node.getOperator();
    const FloatingPointSize& format =
        op.getConst<FloatingPointToFPUnsignedBitVector>().getSize();

    RoundingMode mode(node[0].getConst<RoundingMode>());
    BitVector bits(node[1].getConst<BitVector>());

    return RewriteResponse(
        REWRITE_DONE,
        NodeManager::currentNM()->mkConst(
            FloatingPoint(format, mode, bits, false)));
  }
864
865
  /**
   * Constant fold a FLOATINGPOINT_TO_UBV node whose children are a constant
   * rounding mode and a constant floating-point value.
   *
   * If the conversion is underspecified for this value, the node is returned
   * unchanged; otherwise the resulting bit-vector constant is returned.
   */
  RewriteResponse convertToUBV(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV);

    TNode conversion = node.getOperator();
    const BitVectorSize& width =
        conversion.getConst<FloatingPointToUBV>().d_bv_size;
    RoundingMode mode(node[0].getConst<RoundingMode>());
    FloatingPoint value(node[1].getConst<FloatingPoint>());

    // 'second' is false when the result is underspecified for this input.
    FloatingPoint::PartialBitVector converted(
        value.convertToBV(width, mode, false));
    if (!converted.second)
    {
      // Can't constant fold the underspecified case
      return RewriteResponse(REWRITE_DONE, node);
    }
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(converted.first));
  }
885
886
  /**
   * Constant fold a FLOATINGPOINT_TO_SBV node whose children are a constant
   * rounding mode and a constant floating-point value.
   *
   * If the conversion is underspecified for this value, the node is returned
   * unchanged; otherwise the resulting bit-vector constant is returned.
   */
  RewriteResponse convertToSBV(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV);

    TNode conversion = node.getOperator();
    const BitVectorSize& width =
        conversion.getConst<FloatingPointToSBV>().d_bv_size;
    RoundingMode mode(node[0].getConst<RoundingMode>());
    FloatingPoint value(node[1].getConst<FloatingPoint>());

    // 'second' is false when the result is underspecified for this input.
    FloatingPoint::PartialBitVector converted(
        value.convertToBV(width, mode, true));
    if (!converted.second)
    {
      // Can't constant fold the underspecified case
      return RewriteResponse(REWRITE_DONE, node);
    }
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(converted.first));
  }
906
907
4
  /**
   * Constant fold a FLOATINGPOINT_TO_REAL node whose child is a constant
   * floating-point value.
   *
   * If the conversion is underspecified for this value, the node is returned
   * unchanged; otherwise the resulting rational constant is returned.
   */
  RewriteResponse convertToReal(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_REAL);

    FloatingPoint value(node[0].getConst<FloatingPoint>());

    // 'second' is false when the result is underspecified for this input.
    FloatingPoint::PartialRational converted(value.convertToRational());
    if (!converted.second)
    {
      // Can't constant fold the underspecified case
      return RewriteResponse(REWRITE_DONE, node);
    }
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(converted.first));
  }
923
924
  /**
   * Constant fold a FLOATINGPOINT_TO_UBV_TOTAL node.
   *
   * Children 0 and 1 (rounding mode and floating-point value) are constant.
   * Child 2 is the partial value handed to convertToBVTotal and may be
   * non-constant. In that case the node is only folded when the plain
   * conversion is fully specified.
   */
  RewriteResponse convertToUBVTotal(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV_TOTAL);

    TNode conversion = node.getOperator();
    const BitVectorSize& width =
        conversion.getConst<FloatingPointToUBVTotal>().d_bv_size;
    RoundingMode mode(node[0].getConst<RoundingMode>());
    FloatingPoint value(node[1].getConst<FloatingPoint>());
    NodeManager* nm = NodeManager::currentNM();

    // Can be called with the third argument non-constant
    const bool partialIsConstant =
        node[2].getMetaKind() == kind::metakind::CONSTANT;
    if (partialIsConstant)
    {
      BitVector partial(node[2].getConst<BitVector>());
      BitVector folded(value.convertToBVTotal(width, mode, false, partial));
      return RewriteResponse(REWRITE_DONE, nm->mkConst(folded));
    }

    FloatingPoint::PartialBitVector converted(
        value.convertToBV(width, mode, false));
    if (!converted.second)
    {
      // Can't constant fold the underspecified case
      return RewriteResponse(REWRITE_DONE, node);
    }
    return RewriteResponse(REWRITE_DONE, nm->mkConst(converted.first));
  }
955
956
  /**
   * Constant fold a FLOATINGPOINT_TO_SBV_TOTAL node.
   *
   * Children 0 and 1 (rounding mode and floating-point value) are constant.
   * Child 2 is the partial value handed to convertToBVTotal and may be
   * non-constant. In that case the node is only folded when the plain
   * conversion is fully specified.
   */
  RewriteResponse convertToSBVTotal(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV_TOTAL);

    TNode conversion = node.getOperator();
    const BitVectorSize& width =
        conversion.getConst<FloatingPointToSBVTotal>().d_bv_size;
    RoundingMode mode(node[0].getConst<RoundingMode>());
    FloatingPoint value(node[1].getConst<FloatingPoint>());
    NodeManager* nm = NodeManager::currentNM();

    // Can be called with the third argument non-constant
    const bool partialIsConstant =
        node[2].getMetaKind() == kind::metakind::CONSTANT;
    if (partialIsConstant)
    {
      BitVector partial(node[2].getConst<BitVector>());
      BitVector folded(value.convertToBVTotal(width, mode, true, partial));
      return RewriteResponse(REWRITE_DONE, nm->mkConst(folded));
    }

    FloatingPoint::PartialBitVector converted(
        value.convertToBV(width, mode, true));
    if (!converted.second)
    {
      // Can't constant fold the underspecified case
      return RewriteResponse(REWRITE_DONE, node);
    }
    return RewriteResponse(REWRITE_DONE, nm->mkConst(converted.first));
  }
987
988
4
  /**
   * Constant fold a FLOATINGPOINT_TO_REAL_TOTAL node.
   *
   * Child 0 (the floating-point value) is constant. Child 1 is the partial
   * value handed to convertToRationalTotal and may be non-constant. In that
   * case the node is only folded when the plain conversion is fully specified.
   */
  RewriteResponse convertToRealTotal(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_TO_REAL_TOTAL);

    FloatingPoint value(node[0].getConst<FloatingPoint>());
    NodeManager* nm = NodeManager::currentNM();

    // Can be called with the second argument (the partial value) non-constant
    const bool partialIsConstant =
        node[1].getMetaKind() == kind::metakind::CONSTANT;
    if (partialIsConstant)
    {
      Rational partial(node[1].getConst<Rational>());
      Rational folded(value.convertToRationalTotal(partial));
      return RewriteResponse(REWRITE_DONE, nm->mkConst(folded));
    }

    FloatingPoint::PartialRational converted(value.convertToRational());
    if (!converted.second)
    {
      // Can't constant fold the underspecified case
      return RewriteResponse(REWRITE_DONE, node);
    }
    return RewriteResponse(REWRITE_DONE, nm->mkConst(converted.first));
  }
1014
1015
24
  /**
   * Constant fold one of the single-bit bit-blasting components
   * (NaN / infinite / zero / sign) of a constant floating-point value.
   *
   * Returns REWRITE_DONE with a width-1 bit-vector constant: 1 when the
   * queried property holds, 0 otherwise.
   */
  RewriteResponse componentFlag(TNode node, bool isPreRewrite)
  {
    const Kind k = node.getKind();

    Assert((k == kind::FLOATINGPOINT_COMPONENT_NAN)
           || (k == kind::FLOATINGPOINT_COMPONENT_INF)
           || (k == kind::FLOATINGPOINT_COMPONENT_ZERO)
           || (k == kind::FLOATINGPOINT_COMPONENT_SIGN));

    FloatingPoint value(node[0].getConst<FloatingPoint>());

    // Initialised only to keep the compiler quiet; the default branch
    // never falls through.
    bool flag = false;
    switch (k)
    {
      case kind::FLOATINGPOINT_COMPONENT_NAN:
        flag = value.isNaN();
        break;
      case kind::FLOATINGPOINT_COMPONENT_INF:
        flag = value.isInfinite();
        break;
      case kind::FLOATINGPOINT_COMPONENT_ZERO:
        flag = value.isZero();
        break;
      case kind::FLOATINGPOINT_COMPONENT_SIGN:
        flag = value.getSign();
        break;
      default:
        Unreachable() << "Unknown kind used in componentFlag";
        break;
    }

    BitVector bit(1U, flag ? 1U : 0U);
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(bit));
  }
1041
1042
6
  /**
   * Constant fold FLOATINGPOINT_COMPONENT_EXPONENT of a constant
   * floating-point value into a bit-vector constant.
   */
  RewriteResponse componentExponent(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_COMPONENT_EXPONENT);

    FloatingPoint value(node[0].getConst<FloatingPoint>());

    // \todo Add a proper interface for this sort of thing to FloatingPoint #1915
    const BitVector exponent(value.getExponent());
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(exponent));
  }
1054
1055
6
  /**
   * Constant fold FLOATINGPOINT_COMPONENT_SIGNIFICAND of a constant
   * floating-point value into a bit-vector constant.
   */
  RewriteResponse componentSignificand(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND);

    FloatingPoint value(node[0].getConst<FloatingPoint>());

    const BitVector significand(value.getSignificand());
    return RewriteResponse(REWRITE_DONE,
                           NodeManager::currentNM()->mkConst(significand));
  }
1066
1067
1
  /**
   * Constant fold ROUNDINGMODE_BITBLAST of a constant rounding mode into the
   * bit-vector encoding used by the symFPU symbolic back end.
   */
  RewriteResponse roundingModeBitBlast(TNode node, bool isPreRewrite)
  {
    Assert(node.getKind() == kind::ROUNDINGMODE_BITBLAST);

    /* \todo fix the numbering of rounding modes so this doesn't need
     * to call symfpu at all and remove the dependency on fp_converter.h #1915 */
    const RoundingMode mode(node[0].getConst<RoundingMode>());

    BitVector encoding;
    switch (mode)
    {
      case RoundingMode::ROUND_TOWARD_ZERO:
        encoding = symfpuSymbolic::traits::RTZ().getConst<BitVector>();
        break;
      case RoundingMode::ROUND_TOWARD_NEGATIVE:
        encoding = symfpuSymbolic::traits::RTN().getConst<BitVector>();
        break;
      case RoundingMode::ROUND_TOWARD_POSITIVE:
        encoding = symfpuSymbolic::traits::RTP().getConst<BitVector>();
        break;
      case RoundingMode::ROUND_NEAREST_TIES_TO_AWAY:
        encoding = symfpuSymbolic::traits::RNA().getConst<BitVector>();
        break;
      case RoundingMode::ROUND_NEAREST_TIES_TO_EVEN:
        encoding = symfpuSymbolic::traits::RNE().getConst<BitVector>();
        break;
      default:
        Unreachable() << "Unknown rounding mode in roundingModeBitBlast";
        break;
    }

    NodeManager* nm = NodeManager::currentNM();
    return RewriteResponse(REWRITE_DONE, nm->mkConst(encoding));
  }
1105
1106
  };  // namespace constantFold
1107
1108
  /**
1109
   * Initialize the rewriter.
1110
   */
1111
9853
TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u)
1112
{
1113
  /* Set up the pre-rewrite dispatch table */
1114
3221931
  for (uint32_t i = 0; i < kind::LAST_KIND; ++i)
1115
  {
1116
3212078
    d_preRewriteTable[i] = rewrite::notFP;
1117
  }
1118
1119
  /******** Constants ********/
1120
  /* No rewriting possible for constants */
1121
9853
  d_preRewriteTable[kind::CONST_FLOATINGPOINT] = rewrite::identity;
1122
9853
  d_preRewriteTable[kind::CONST_ROUNDINGMODE] = rewrite::identity;
1123
1124
  /******** Sorts(?) ********/
1125
  /* These kinds should only appear in types */
1126
  // d_preRewriteTable[kind::ROUNDINGMODE_TYPE] = rewrite::type;
1127
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TYPE] = rewrite::type;
1128
1129
  /******** Operations ********/
1130
9853
  d_preRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity;
1131
9853
  d_preRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::compactAbs;
1132
9853
  d_preRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation;
1133
9853
  d_preRewriteTable[kind::FLOATINGPOINT_ADD] = rewrite::identity;
1134
9853
  d_preRewriteTable[kind::FLOATINGPOINT_SUB] =
1135
      rewrite::convertSubtractionToAddition;
1136
9853
  d_preRewriteTable[kind::FLOATINGPOINT_MULT] = rewrite::identity;
1137
9853
  d_preRewriteTable[kind::FLOATINGPOINT_DIV] = rewrite::identity;
1138
9853
  d_preRewriteTable[kind::FLOATINGPOINT_FMA] = rewrite::identity;
1139
9853
  d_preRewriteTable[kind::FLOATINGPOINT_SQRT] = rewrite::identity;
1140
9853
  d_preRewriteTable[kind::FLOATINGPOINT_REM] = rewrite::identity;
1141
9853
  d_preRewriteTable[kind::FLOATINGPOINT_RTI] = rewrite::identity;
1142
9853
  d_preRewriteTable[kind::FLOATINGPOINT_MIN] = rewrite::compactMinMax;
1143
9853
  d_preRewriteTable[kind::FLOATINGPOINT_MAX] = rewrite::compactMinMax;
1144
9853
  d_preRewriteTable[kind::FLOATINGPOINT_MIN_TOTAL] = rewrite::compactMinMax;
1145
9853
  d_preRewriteTable[kind::FLOATINGPOINT_MAX_TOTAL] = rewrite::compactMinMax;
1146
1147
  /******** Comparisons ********/
1148
9853
  d_preRewriteTable[kind::FLOATINGPOINT_EQ] =
1149
      rewrite::then<rewrite::breakChain, rewrite::ieeeEqToEq>;
1150
9853
  d_preRewriteTable[kind::FLOATINGPOINT_LEQ] =
1151
      rewrite::then<rewrite::breakChain, rewrite::leqId>;
1152
9853
  d_preRewriteTable[kind::FLOATINGPOINT_LT] =
1153
      rewrite::then<rewrite::breakChain, rewrite::ltId>;
1154
9853
  d_preRewriteTable[kind::FLOATINGPOINT_GEQ] =
1155
      rewrite::then<rewrite::breakChain, rewrite::geqToleq>;
1156
9853
  d_preRewriteTable[kind::FLOATINGPOINT_GT] =
1157
      rewrite::then<rewrite::breakChain, rewrite::gtTolt>;
1158
1159
  /******** Classifications ********/
1160
9853
  d_preRewriteTable[kind::FLOATINGPOINT_ISN] = rewrite::identity;
1161
9853
  d_preRewriteTable[kind::FLOATINGPOINT_ISSN] = rewrite::identity;
1162
9853
  d_preRewriteTable[kind::FLOATINGPOINT_ISZ] = rewrite::identity;
1163
9853
  d_preRewriteTable[kind::FLOATINGPOINT_ISINF] = rewrite::identity;
1164
9853
  d_preRewriteTable[kind::FLOATINGPOINT_ISNAN] = rewrite::identity;
1165
9853
  d_preRewriteTable[kind::FLOATINGPOINT_ISNEG] = rewrite::identity;
1166
9853
  d_preRewriteTable[kind::FLOATINGPOINT_ISPOS] = rewrite::identity;
1167
1168
  /******** Conversions ********/
1169
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] =
1170
      rewrite::identity;
1171
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] =
1172
      rewrite::identity;
1173
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity;
1174
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] =
1175
      rewrite::identity;
1176
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] =
1177
      rewrite::identity;
1178
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::identity;
1179
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_UBV] = rewrite::identity;
1180
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_SBV] = rewrite::identity;
1181
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_REAL] = rewrite::identity;
1182
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = rewrite::identity;
1183
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = rewrite::identity;
1184
9853
  d_preRewriteTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = rewrite::identity;
1185
1186
  /******** Variables ********/
1187
9853
  d_preRewriteTable[kind::VARIABLE] = rewrite::variable;
1188
9853
  d_preRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable;
1189
9853
  d_preRewriteTable[kind::SKOLEM] = rewrite::variable;
1190
9853
  d_preRewriteTable[kind::INST_CONSTANT] = rewrite::variable;
1191
1192
9853
  d_preRewriteTable[kind::EQUAL] = rewrite::equal;
1193
1194
  /******** Components for bit-blasting ********/
1195
9853
  d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_NAN] = rewrite::identity;
1196
9853
  d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_INF] = rewrite::identity;
1197
9853
  d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_ZERO] = rewrite::identity;
1198
9853
  d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGN] = rewrite::identity;
1199
9853
  d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_EXPONENT] = rewrite::identity;
1200
9853
  d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND] =
1201
      rewrite::identity;
1202
9853
  d_preRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity;
1203
1204
  /* Set up the post-rewrite dispatch table */
1205
3221931
  for (uint32_t i = 0; i < kind::LAST_KIND; ++i)
1206
  {
1207
3212078
    d_postRewriteTable[i] = rewrite::notFP;
1208
  }
1209
1210
  /******** Constants ********/
1211
  /* No rewriting possible for constants */
1212
9853
  d_postRewriteTable[kind::CONST_FLOATINGPOINT] = rewrite::identity;
1213
9853
  d_postRewriteTable[kind::CONST_ROUNDINGMODE] = rewrite::identity;
1214
1215
  /******** Sorts(?) ********/
1216
  /* These kinds should only appear in types */
1217
  // d_postRewriteTable[kind::ROUNDINGMODE_TYPE] = rewrite::type;
1218
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TYPE] = rewrite::type;
1219
1220
  /******** Operations ********/
1221
9853
  d_postRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity;
1222
9853
  d_postRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::compactAbs;
1223
9853
  d_postRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation;
1224
9853
  d_postRewriteTable[kind::FLOATINGPOINT_ADD] = rewrite::reorderBinaryOperation;
1225
9853
  d_postRewriteTable[kind::FLOATINGPOINT_SUB] = rewrite::identity;
1226
9853
  d_postRewriteTable[kind::FLOATINGPOINT_MULT] =
1227
      rewrite::reorderBinaryOperation;
1228
9853
  d_postRewriteTable[kind::FLOATINGPOINT_DIV] = rewrite::identity;
1229
9853
  d_postRewriteTable[kind::FLOATINGPOINT_FMA] = rewrite::reorderFMA;
1230
9853
  d_postRewriteTable[kind::FLOATINGPOINT_SQRT] = rewrite::identity;
1231
9853
  d_postRewriteTable[kind::FLOATINGPOINT_REM] = rewrite::compactRemainder;
1232
9853
  d_postRewriteTable[kind::FLOATINGPOINT_RTI] = rewrite::identity;
1233
9853
  d_postRewriteTable[kind::FLOATINGPOINT_MIN] = rewrite::compactMinMax;
1234
9853
  d_postRewriteTable[kind::FLOATINGPOINT_MAX] = rewrite::compactMinMax;
1235
9853
  d_postRewriteTable[kind::FLOATINGPOINT_MIN_TOTAL] = rewrite::compactMinMax;
1236
9853
  d_postRewriteTable[kind::FLOATINGPOINT_MAX_TOTAL] = rewrite::compactMinMax;
1237
1238
  /******** Comparisons ********/
1239
9853
  d_postRewriteTable[kind::FLOATINGPOINT_EQ] = rewrite::identity;
1240
9853
  d_postRewriteTable[kind::FLOATINGPOINT_LEQ] = rewrite::leqId;
1241
9853
  d_postRewriteTable[kind::FLOATINGPOINT_LT] = rewrite::ltId;
1242
9853
  d_postRewriteTable[kind::FLOATINGPOINT_GEQ] = rewrite::identity;
1243
9853
  d_postRewriteTable[kind::FLOATINGPOINT_GT] = rewrite::identity;
1244
1245
  /******** Classifications ********/
1246
9853
  d_postRewriteTable[kind::FLOATINGPOINT_ISN] = rewrite::removeSignOperations;
1247
9853
  d_postRewriteTable[kind::FLOATINGPOINT_ISSN] = rewrite::removeSignOperations;
1248
9853
  d_postRewriteTable[kind::FLOATINGPOINT_ISZ] = rewrite::removeSignOperations;
1249
9853
  d_postRewriteTable[kind::FLOATINGPOINT_ISINF] = rewrite::removeSignOperations;
1250
9853
  d_postRewriteTable[kind::FLOATINGPOINT_ISNAN] = rewrite::removeSignOperations;
1251
9853
  d_postRewriteTable[kind::FLOATINGPOINT_ISNEG] = rewrite::identity;
1252
9853
  d_postRewriteTable[kind::FLOATINGPOINT_ISPOS] = rewrite::identity;
1253
1254
  /******** Conversions ********/
1255
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] =
1256
      rewrite::identity;
1257
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] =
1258
      rewrite::identity;
1259
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity;
1260
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] =
1261
      rewrite::toFPSignedBV;
1262
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] =
1263
      rewrite::identity;
1264
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] =
1265
      rewrite::removeToFPGeneric;
1266
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_UBV] = rewrite::identity;
1267
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_SBV] = rewrite::identity;
1268
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_REAL] = rewrite::identity;
1269
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = rewrite::identity;
1270
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = rewrite::identity;
1271
9853
  d_postRewriteTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = rewrite::identity;
1272
1273
  /******** Variables ********/
1274
9853
  d_postRewriteTable[kind::VARIABLE] = rewrite::variable;
1275
9853
  d_postRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable;
1276
9853
  d_postRewriteTable[kind::SKOLEM] = rewrite::variable;
1277
9853
  d_postRewriteTable[kind::INST_CONSTANT] = rewrite::variable;
1278
1279
9853
  d_postRewriteTable[kind::EQUAL] = rewrite::equal;
1280
1281
  /******** Components for bit-blasting ********/
1282
9853
  d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_NAN] = rewrite::identity;
1283
9853
  d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_INF] = rewrite::identity;
1284
9853
  d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_ZERO] = rewrite::identity;
1285
9853
  d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGN] = rewrite::identity;
1286
9853
  d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_EXPONENT] =
1287
      rewrite::identity;
1288
9853
  d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND] =
1289
      rewrite::identity;
1290
9853
  d_postRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity;
1291
1292
  /* Set up the post-rewrite constant fold table */
1293
3221931
  for (uint32_t i = 0; i < kind::LAST_KIND; ++i)
1294
  {
1295
    // Note that this is identity, not notFP
1296
    // Constant folding is called after post-rewrite
1297
    // So may have to deal with cases of things being
1298
    // re-written to non-floating-point sorts (i.e. true).
1299
3212078
    d_constantFoldTable[i] = rewrite::identity;
1300
  }
1301
1302
  /******** Constants ********/
1303
  /* Already folded! */
1304
9853
  d_constantFoldTable[kind::CONST_FLOATINGPOINT] = rewrite::identity;
1305
9853
  d_constantFoldTable[kind::CONST_ROUNDINGMODE] = rewrite::identity;
1306
1307
  /******** Sorts(?) ********/
1308
  /* These kinds should only appear in types */
1309
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TYPE] = rewrite::type;
1310
1311
  /******** Operations ********/
1312
9853
  d_constantFoldTable[kind::FLOATINGPOINT_FP] = constantFold::fpLiteral;
1313
9853
  d_constantFoldTable[kind::FLOATINGPOINT_ABS] = constantFold::abs;
1314
9853
  d_constantFoldTable[kind::FLOATINGPOINT_NEG] = constantFold::neg;
1315
9853
  d_constantFoldTable[kind::FLOATINGPOINT_ADD] = constantFold::add;
1316
9853
  d_constantFoldTable[kind::FLOATINGPOINT_MULT] = constantFold::mult;
1317
9853
  d_constantFoldTable[kind::FLOATINGPOINT_DIV] = constantFold::div;
1318
9853
  d_constantFoldTable[kind::FLOATINGPOINT_FMA] = constantFold::fma;
1319
9853
  d_constantFoldTable[kind::FLOATINGPOINT_SQRT] = constantFold::sqrt;
1320
9853
  d_constantFoldTable[kind::FLOATINGPOINT_REM] = constantFold::rem;
1321
9853
  d_constantFoldTable[kind::FLOATINGPOINT_RTI] = constantFold::rti;
1322
9853
  d_constantFoldTable[kind::FLOATINGPOINT_MIN] = constantFold::min;
1323
9853
  d_constantFoldTable[kind::FLOATINGPOINT_MAX] = constantFold::max;
1324
9853
  d_constantFoldTable[kind::FLOATINGPOINT_MIN_TOTAL] = constantFold::minTotal;
1325
9853
  d_constantFoldTable[kind::FLOATINGPOINT_MAX_TOTAL] = constantFold::maxTotal;
1326
1327
  /******** Comparisons ********/
1328
9853
  d_constantFoldTable[kind::FLOATINGPOINT_LEQ] = constantFold::leq;
1329
9853
  d_constantFoldTable[kind::FLOATINGPOINT_LT] = constantFold::lt;
1330
1331
  /******** Classifications ********/
1332
9853
  d_constantFoldTable[kind::FLOATINGPOINT_ISN] = constantFold::isNormal;
1333
9853
  d_constantFoldTable[kind::FLOATINGPOINT_ISSN] = constantFold::isSubnormal;
1334
9853
  d_constantFoldTable[kind::FLOATINGPOINT_ISZ] = constantFold::isZero;
1335
9853
  d_constantFoldTable[kind::FLOATINGPOINT_ISINF] = constantFold::isInfinite;
1336
9853
  d_constantFoldTable[kind::FLOATINGPOINT_ISNAN] = constantFold::isNaN;
1337
9853
  d_constantFoldTable[kind::FLOATINGPOINT_ISNEG] = constantFold::isNegative;
1338
9853
  d_constantFoldTable[kind::FLOATINGPOINT_ISPOS] = constantFold::isPositive;
1339
1340
  /******** Conversions ********/
1341
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] =
1342
      constantFold::convertFromIEEEBitVectorLiteral;
1343
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] =
1344
      constantFold::constantConvert;
1345
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_REAL] =
1346
      constantFold::convertFromRealLiteral;
1347
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] =
1348
      constantFold::convertFromSBV;
1349
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] =
1350
      constantFold::convertFromUBV;
1351
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_UBV] = constantFold::convertToUBV;
1352
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_SBV] = constantFold::convertToSBV;
1353
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_REAL] =
1354
      constantFold::convertToReal;
1355
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] =
1356
      constantFold::convertToUBVTotal;
1357
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] =
1358
      constantFold::convertToSBVTotal;
1359
9853
  d_constantFoldTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] =
1360
      constantFold::convertToRealTotal;
1361
1362
  /******** Variables ********/
1363
9853
  d_constantFoldTable[kind::VARIABLE] = rewrite::variable;
1364
9853
  d_constantFoldTable[kind::BOUND_VARIABLE] = rewrite::variable;
1365
1366
9853
  d_constantFoldTable[kind::EQUAL] = constantFold::equal;
1367
1368
  /******** Components for bit-blasting ********/
1369
9853
  d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_NAN] =
1370
      constantFold::componentFlag;
1371
9853
  d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_INF] =
1372
      constantFold::componentFlag;
1373
9853
  d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_ZERO] =
1374
      constantFold::componentFlag;
1375
9853
  d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_SIGN] =
1376
      constantFold::componentFlag;
1377
9853
  d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_EXPONENT] =
1378
      constantFold::componentExponent;
1379
9853
  d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND] =
1380
      constantFold::componentSignificand;
1381
9853
  d_constantFoldTable[kind::ROUNDINGMODE_BITBLAST] =
1382
      constantFold::roundingModeBitBlast;
1383
9853
}
1384
1385
  /**
1386
   * Rewrite a node into the normal form for the theory of fp
1387
   * in pre-order (really topological order)---meaning that the
1388
   * children may not be in the normal form.  This is an optimization
1389
   * for theories with cancelling terms (e.g., 0 * (big-nasty-expression)
1390
   * in arithmetic rewrites to 0 without the need to look at the big
1391
   * nasty expression).  Since it's only an optimization, the
1392
   * implementation here can do nothing.
1393
   */
1394
1395
5130
  RewriteResponse TheoryFpRewriter::preRewrite(TNode node) {
1396
5130
    Trace("fp-rewrite") << "TheoryFpRewriter::preRewrite(): " << node << std::endl;
1397
5130
    RewriteResponse res = d_preRewriteTable[node.getKind()](node, true);
1398
5130
    if (res.d_node != node)
1399
    {
1400
130
      Debug("fp-rewrite") << "TheoryFpRewriter::preRewrite(): before " << node << std::endl;
1401
260
      Debug("fp-rewrite") << "TheoryFpRewriter::preRewrite(): after  "
1402
130
                          << res.d_node << std::endl;
1403
    }
1404
5130
    return res;
1405
  }
1406
1407
1408
  /**
1409
   * Rewrite a node into the normal form for the theory of fp.
1410
   * Called in post-order (really reverse-topological order) when
1411
   * traversing the expression DAG during rewriting.  This is the
1412
   * main function of the rewriter, and because of the ordering,
1413
   * it can assume its children are all rewritten already.
1414
   *
1415
   * This function can return one of three rewrite response codes
1416
   * along with the rewritten node:
1417
   *
1418
   *   REWRITE_DONE indicates that no more rewriting is needed.
1419
   *   REWRITE_AGAIN means that the top-level expression should be
1420
   *     rewritten again, but that its children are in final form.
1421
   *   REWRITE_AGAIN_FULL means that the entire returned expression
1422
   *     should be rewritten again (top-down with preRewrite(), then
1423
   *     bottom-up with postRewrite()).
1424
   *
1425
   * Even if this function returns REWRITE_DONE, if the returned
1426
   * expression belongs to a different theory, it will be fully
1427
   * rewritten by that theory's rewriter.
1428
   */
1429
1430
9237
  RewriteResponse TheoryFpRewriter::postRewrite(TNode node) {
1431
9237
    Trace("fp-rewrite") << "TheoryFpRewriter::postRewrite(): " << node << std::endl;
1432
18474
    RewriteResponse res = d_postRewriteTable[node.getKind()](node, false);
1433
9237
    if (res.d_node != node)
1434
    {
1435
569
      Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): before " << node << std::endl;
1436
1138
      Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): after  "
1437
569
                          << res.d_node << std::endl;
1438
    }
1439
1440
9237
    if (res.d_status == REWRITE_DONE)
1441
    {
1442
9127
      bool allChildrenConst = true;
1443
9127
      bool apartFromRoundingMode = false;
1444
9127
      bool apartFromPartiallyDefinedArgument = false;
1445
15339
      for (Node::const_iterator i = res.d_node.begin(); i != res.d_node.end();
1446
           ++i)
1447
      {
1448
10544
        if ((*i).getMetaKind() != kind::metakind::CONSTANT) {
1449
4959
	  if ((*i).getType().isRoundingMode() && !apartFromRoundingMode) {
1450
523
	    apartFromRoundingMode = true;
1451
          }
1452
8872
          else if ((res.d_node.getKind() == kind::FLOATINGPOINT_MIN_TOTAL
1453
4436
                    || res.d_node.getKind() == kind::FLOATINGPOINT_MAX_TOTAL
1454
4262
                    || res.d_node.getKind() == kind::FLOATINGPOINT_TO_UBV_TOTAL
1455
4262
                    || res.d_node.getKind() == kind::FLOATINGPOINT_TO_SBV_TOTAL
1456
4252
                    || res.d_node.getKind()
1457
                           == kind::FLOATINGPOINT_TO_REAL_TOTAL)
1458
4628
                   && ((*i).getType().isBitVector() || (*i).getType().isReal())
1459
8976
                   && !apartFromPartiallyDefinedArgument)
1460
          {
1461
104
            apartFromPartiallyDefinedArgument = true;
1462
          }
1463
          else
1464
          {
1465
4332
            allChildrenConst = false;
1466
4332
	    break;
1467
          }
1468
        }
1469
      }
1470
1471
9127
      if (allChildrenConst)
1472
      {
1473
4795
        RewriteStatus rs = REWRITE_DONE;  // This is a bit messy because
1474
9590
        Node rn = res.d_node;             // RewriteResponse is too functional..
1475
1476
4795
        if (apartFromRoundingMode)
1477
        {
1478
478
          if (!(res.d_node.getKind() == kind::EQUAL)
1479
376
              &&  // Avoid infinite recursion...
1480
137
              !(res.d_node.getKind() == kind::ROUNDINGMODE_BITBLAST))
1481
          {
1482
            // Don't eliminate the bit-blast
1483
            // We are close to being able to constant fold this
1484
            // and in many cases the rounding mode really doesn't matter.
1485
            // So we can try brute forcing our way through them.
1486
1487
125
            NodeManager* nm = NodeManager::currentNM();
1488
1489
250
            Node rne(nm->mkConst(RoundingMode::ROUND_NEAREST_TIES_TO_EVEN));
1490
250
            Node rna(nm->mkConst(RoundingMode::ROUND_NEAREST_TIES_TO_AWAY));
1491
250
            Node rtz(nm->mkConst(RoundingMode::ROUND_TOWARD_POSITIVE));
1492
250
            Node rtn(nm->mkConst(RoundingMode::ROUND_TOWARD_NEGATIVE));
1493
250
            Node rtp(nm->mkConst(RoundingMode::ROUND_TOWARD_ZERO));
1494
1495
250
            TNode rm(res.d_node[0]);
1496
1497
250
            Node w_rne(res.d_node.substitute(rm, TNode(rne)));
1498
250
            Node w_rna(res.d_node.substitute(rm, TNode(rna)));
1499
250
            Node w_rtz(res.d_node.substitute(rm, TNode(rtz)));
1500
250
            Node w_rtn(res.d_node.substitute(rm, TNode(rtn)));
1501
250
            Node w_rtp(res.d_node.substitute(rm, TNode(rtp)));
1502
1503
125
            rs = REWRITE_AGAIN_FULL;
1504
375
            rn = nm->mkNode(
1505
                kind::ITE,
1506
250
                nm->mkNode(kind::EQUAL, rm, rne),
1507
                w_rne,
1508
500
                nm->mkNode(
1509
                    kind::ITE,
1510
250
                    nm->mkNode(kind::EQUAL, rm, rna),
1511
                    w_rna,
1512
500
                    nm->mkNode(kind::ITE,
1513
250
                               nm->mkNode(kind::EQUAL, rm, rtz),
1514
                               w_rtz,
1515
500
                               nm->mkNode(kind::ITE,
1516
250
                                          nm->mkNode(kind::EQUAL, rm, rtn),
1517
                                          w_rtn,
1518
                                          w_rtp))));
1519
          }
1520
        }
1521
        else
1522
        {
1523
          RewriteResponse tmp =
1524
9112
              d_constantFoldTable[res.d_node.getKind()](res.d_node, false);
1525
4556
          rs = tmp.d_status;
1526
4556
          rn = tmp.d_node;
1527
        }
1528
1529
9590
        RewriteResponse constRes(rs, rn);
1530
1531
4795
        if (constRes.d_node != res.d_node)
1532
        {
1533
4202
          Debug("fp-rewrite")
1534
2101
              << "TheoryFpRewriter::postRewrite(): before constant fold "
1535
2101
              << res.d_node << std::endl;
1536
4202
          Debug("fp-rewrite")
1537
2101
              << "TheoryFpRewriter::postRewrite(): after constant fold "
1538
2101
              << constRes.d_node << std::endl;
1539
        }
1540
1541
4795
        return constRes;
1542
      }
1543
    }
1544
1545
4442
    return res;
1546
  }
1547
2730
  /**
   * Expand the definition of `node`. This is delegated entirely to the
   * d_fpExpDef helper, and its TrustNode result is returned unchanged.
   */
  TrustNode TheoryFpRewriter::expandDefinition(Node node)
  {
    TrustNode expanded = d_fpExpDef.expandDefinition(node);
    return expanded;
  }
1551
1552
  }  // namespace fp
1553
  }  // namespace theory
1554
29340
  }  // namespace cvc5