GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_rewriter.cpp Lines: 437 497 87.9 %
Date: 2021-03-22 Branches: 1022 2423 42.2 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file arith_rewriter.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Tim King, Morgan Deters
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Rewriter for the theory of arithmetic
13
 **
14
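 ** Implements the pre- and post-rewrite steps of ArithRewriter for
 ** arithmetic terms (sums, products, division, integer div/mod, iand, pow,
 ** transcendental functions) and atoms (relations, IS_INTEGER, DIVISIBLE).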
15
 ** \todo document this file
16
 **/
17
18
#include <set>
19
#include <stack>
20
#include <vector>
21
22
#include "smt/logic_exception.h"
23
#include "theory/arith/arith_msum.h"
24
#include "theory/arith/arith_rewriter.h"
25
#include "theory/arith/arith_utilities.h"
26
#include "theory/arith/normal_form.h"
27
#include "theory/theory.h"
28
#include "util/iand.h"
29
30
namespace CVC4 {
31
namespace theory {
32
namespace arith {
33
34
13265943
/**
 * Returns true iff the kind of n is an arithmetic relation operator (as
 * decided by arith::isRelationOperator), IS_INTEGER, or DIVISIBLE.
 */
bool ArithRewriter::isAtom(TNode n)
{
  const Kind nodeKind = n.getKind();
  if (arith::isRelationOperator(nodeKind))
  {
    return true;
  }
  return nodeKind == kind::IS_INTEGER || nodeKind == kind::DIVISIBLE;
}
39
40
608328
/**
 * Rewrites a rational constant. The node is returned unchanged with status
 * REWRITE_DONE; the assertions only check that t is a CONST_RATIONAL.
 */
RewriteResponse ArithRewriter::rewriteConstant(TNode t)
{
  Assert(t.isConst());
  Assert(t.getKind() == kind::CONST_RATIONAL);

  RewriteResponse unchanged(REWRITE_DONE, t);
  return unchanged;
}
46
47
/**
 * Rewrites a variable. The node is returned unchanged with status
 * REWRITE_DONE; the assertion only checks that t is a variable.
 */
RewriteResponse ArithRewriter::rewriteVariable(TNode t)
{
  Assert(t.isVar());

  RewriteResponse unchanged(REWRITE_DONE, t);
  return unchanged;
}
52
53
263055
/**
 * Rewrites a binary MINUS term.
 *
 * During pre-rewrite, (- x x) becomes the constant 0 and any other (- a b)
 * is replaced by the node built by makeSubtractionNode(a, b). During
 * post-rewrite both operands are parsed as polynomials and their difference
 * is returned.
 *
 * @param t a node of kind MINUS
 * @param pre true when called from pre-rewrite, false from post-rewrite
 * @return the rewritten node, always with status REWRITE_DONE
 */
RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre)
{
  Assert(t.getKind() == kind::MINUS);

  if (!pre)
  {
    // Post-rewrite: subtract on the polynomial normal form.
    Polynomial lhs = Polynomial::parsePolynomial(t[0]);
    Polynomial rhs = Polynomial::parsePolynomial(t[1]);
    Polynomial difference = lhs - rhs;
    return RewriteResponse(REWRITE_DONE, difference.getNode());
  }

  if (t[0] == t[1])
  {
    // Identical operands cancel out.
    Node zeroNode = mkRationalNode(Rational(0));
    return RewriteResponse(REWRITE_DONE, zeroNode);
  }

  Node withoutMinus = makeSubtractionNode(t[0], t[1]);
  return RewriteResponse(REWRITE_DONE, withoutMinus);
}
72
73
13447
/**
 * Rewrites a unary minus term.
 *
 * A negated rational constant is folded into a constant. Otherwise the term
 * is replaced by the node built by makeUnaryMinusNode; the status is
 * REWRITE_DONE during pre-rewrite and REWRITE_AGAIN during post-rewrite.
 *
 * @param t a node of kind UMINUS
 * @param pre true when called from pre-rewrite, false from post-rewrite
 */
RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre)
{
  Assert(t.getKind() == kind::UMINUS);

  if (t[0].getKind() == kind::CONST_RATIONAL)
  {
    Rational negated = -(t[0].getConst<Rational>());
    return RewriteResponse(REWRITE_DONE, mkRationalNode(negated));
  }

  Node withoutUminus = makeUnaryMinusNode(t[0]);
  return RewriteResponse(pre ? REWRITE_DONE : REWRITE_AGAIN, withoutUminus);
}
87
88
1407158
/**
 * Pre-rewrite dispatcher for arithmetic terms.
 *
 * Constants and variables go to rewriteConstant / rewriteVariable. Every
 * other term is dispatched on its kind; kinds that are not listed end in
 * Unhandled().
 *
 * @param t the arithmetic term to pre-rewrite
 * @return the rewritten node together with the rewrite status
 */
RewriteResponse ArithRewriter::preRewriteTerm(TNode t)
{
  if (t.isConst())
  {
    return rewriteConstant(t);
  }
  if (t.isVar())
  {
    return rewriteVariable(t);
  }

  const Kind k = t.getKind();
  switch (k)
  {
    case kind::MINUS: return rewriteMinus(t, true);

    case kind::UMINUS: return rewriteUMinus(t, true);

    case kind::DIVISION:
    case kind::DIVISION_TOTAL: return rewriteDiv(t, true);

    case kind::PLUS: return preRewritePlus(t);

    case kind::MULT:
    case kind::NONLINEAR_MULT: return preRewriteMult(t);

    // iand is left untouched at pre-rewrite
    case kind::IAND: return RewriteResponse(REWRITE_DONE, t);

    case kind::EXPONENTIAL:
    case kind::SINE:
    case kind::COSINE:
    case kind::TANGENT:
    case kind::COSECANT:
    case kind::SECANT:
    case kind::COTANGENT:
    case kind::ARCSINE:
    case kind::ARCCOSINE:
    case kind::ARCTANGENT:
    case kind::ARCCOSECANT:
    case kind::ARCSECANT:
    case kind::ARCCOTANGENT:
    case kind::SQRT: return preRewriteTranscendental(t);

    case kind::INTS_DIVISION:
    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);

    case kind::INTS_DIVISION_TOTAL:
    case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);

    case kind::ABS:
    {
      if (!t[0].isConst())
      {
        return RewriteResponse(REWRITE_DONE, t);
      }
      // |c| for a constant c is evaluated immediately.
      const Rational& value = t[0].getConst<Rational>();
      if (value >= 0)
      {
        return RewriteResponse(REWRITE_DONE, t[0]);
      }
      return RewriteResponse(REWRITE_DONE,
                             NodeManager::currentNM()->mkConst(-value));
    }

    case kind::IS_INTEGER:
    case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);

    // the casts are dropped and their argument is returned
    case kind::TO_REAL:
    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);

    case kind::POW: return RewriteResponse(REWRITE_DONE, t);

    case kind::PI: return RewriteResponse(REWRITE_DONE, t);

    default: Unhandled() << k;
  }
}
150
151
2595711
/**
 * Post-rewrite dispatcher for arithmetic terms.
 *
 * Constants and variables go to rewriteConstant / rewriteVariable. Every
 * other term is dispatched on its kind; kinds that are not listed end in
 * Unreachable().
 *
 * @param t the arithmetic term to post-rewrite
 * @return the rewritten node together with the rewrite status
 * @throws LogicException for a POW term whose exponent is not a
 *         non-negative integral constant of at most
 *         expr::NodeValue::MAX_CHILDREN
 */
RewriteResponse ArithRewriter::postRewriteTerm(TNode t)
{
  if (t.isConst())
  {
    return rewriteConstant(t);
  }
  if (t.isVar())
  {
    return rewriteVariable(t);
  }

  switch (t.getKind())
  {
    case kind::MINUS: return rewriteMinus(t, false);

    case kind::UMINUS: return rewriteUMinus(t, false);

    case kind::DIVISION:
    case kind::DIVISION_TOTAL: return rewriteDiv(t, false);

    case kind::PLUS: return postRewritePlus(t);

    case kind::MULT:
    case kind::NONLINEAR_MULT: return postRewriteMult(t);

    case kind::IAND: return postRewriteIAnd(t);

    case kind::EXPONENTIAL:
    case kind::SINE:
    case kind::COSINE:
    case kind::TANGENT:
    case kind::COSECANT:
    case kind::SECANT:
    case kind::COTANGENT:
    case kind::ARCSINE:
    case kind::ARCCOSINE:
    case kind::ARCTANGENT:
    case kind::ARCCOSECANT:
    case kind::ARCSECANT:
    case kind::ARCCOTANGENT:
    case kind::SQRT: return postRewriteTranscendental(t);

    case kind::INTS_DIVISION:
    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);

    case kind::INTS_DIVISION_TOTAL:
    case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);

    case kind::ABS:
    {
      if (!t[0].isConst())
      {
        return RewriteResponse(REWRITE_DONE, t);
      }
      // |c| for a constant c is evaluated immediately.
      const Rational& value = t[0].getConst<Rational>();
      if (value >= 0)
      {
        return RewriteResponse(REWRITE_DONE, t[0]);
      }
      return RewriteResponse(REWRITE_DONE,
                             NodeManager::currentNM()->mkConst(-value));
    }

    // the casts are dropped and their argument is returned
    case kind::TO_REAL:
    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);

    case kind::TO_INTEGER:
    {
      if (t[0].isConst())
      {
        // constant argument: evaluate to its floor
        return RewriteResponse(
            REWRITE_DONE,
            NodeManager::currentNM()->mkConst(
                Rational(t[0].getConst<Rational>().floor())));
      }
      if (t[0].getType().isInteger())
      {
        // the argument already has integer type
        return RewriteResponse(REWRITE_DONE, t[0]);
      }
      // non-constant, non-integer argument: left as is
      return RewriteResponse(REWRITE_DONE, t);
    }

    case kind::IS_INTEGER:
    {
      if (t[0].isConst())
      {
        return RewriteResponse(
            REWRITE_DONE,
            NodeManager::currentNM()->mkConst(
                t[0].getConst<Rational>().getDenominator() == 1));
      }
      if (t[0].getType().isInteger())
      {
        return RewriteResponse(REWRITE_DONE,
                               NodeManager::currentNM()->mkConst(true));
      }
      // non-constant, non-integer argument: left as is
      return RewriteResponse(REWRITE_DONE, t);
    }

    case kind::POW:
    {
      if (t[1].getKind() == kind::CONST_RATIONAL)
      {
        const Rational& exponent = t[1].getConst<Rational>();
        TNode base = t[0];
        if (exponent.sgn() == 0)
        {
          // x^0 ---> 1
          return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
        }
        if (exponent.sgn() > 0 && exponent.isIntegral()
            && exponent <= CVC4::Rational(expr::NodeValue::MAX_CHILDREN))
        {
          const unsigned count = exponent.getNumerator().toUnsignedInt();
          if (count == 1)
          {
            // x^1 ---> x
            return RewriteResponse(REWRITE_AGAIN, base);
          }
          // x^n ---> x * ... * x (n copies)
          NodeBuilder<> product(kind::MULT);
          for (unsigned added = 0; added < count; ++added)
          {
            product << base;
          }
          Assert(product.getNumChildren() > 0);
          Node mult = product;
          return RewriteResponse(REWRITE_AGAIN, mult);
        }
      }

      // Every remaining exponent is rejected.
      // Todo improve the exception thrown
      std::stringstream ss;
      ss << "The exponent of the POW(^) operator can only be a positive "
            "integral constant below "
         << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
      ss << "Exception occurred in:" << std::endl;
      ss << "  " << t;
      throw LogicException(ss.str());
    }

    case kind::PI: return RewriteResponse(REWRITE_DONE, t);

    default: Unreachable();
  }
}
265
266
267
375225
/**
 * Pre-rewrites a MULT / NONLINEAR_MULT term.
 *
 * A binary product with a constant factor 1 is replaced by its other
 * factor. Otherwise, if any factor is the constant 0, that factor is
 * returned. All other products are returned unchanged. The status is always
 * REWRITE_DONE.
 */
RewriteResponse ArithRewriter::preRewriteMult(TNode t)
{
  Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);

  if (t.getNumChildren() == 2)
  {
    // (* 1 x) ---> x
    const bool firstIsOne = t[0].getKind() == kind::CONST_RATIONAL
                            && t[0].getConst<Rational>().isOne();
    if (firstIsOne)
    {
      return RewriteResponse(REWRITE_DONE, t[1]);
    }
    // (* x 1) ---> x
    const bool secondIsOne = t[1].getKind() == kind::CONST_RATIONAL
                             && t[1].getConst<Rational>().isOne();
    if (secondIsOne)
    {
      return RewriteResponse(REWRITE_DONE, t[0]);
    }
  }

  // A product with a 0 factor is rewritten to that 0.
  for (TNode factor : t)
  {
    if (factor.getKind() != kind::CONST_RATIONAL)
    {
      continue;
    }
    if (factor.getConst<Rational>().isZero())
    {
      return RewriteResponse(REWRITE_DONE, factor);
    }
  }
  return RewriteResponse(REWRITE_DONE, t);
}
292
293
614422
/** Returns true iff at least one direct child of t has kind k. */
static bool canFlatten(Kind k, TNode t)
{
  for (TNode child : t)
  {
    if (child.getKind() == k)
    {
      return true;
    }
  }
  return false;
}
302
303
86875
static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
304
86875
  if(t.getKind() == k){
305
275188
    for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
306
376626
      TNode child = *i;
307
188313
      if(child.getKind() == k){
308
46009
        flatten(pb, k, child);
309
      }else{
310
142304
        pb.push_back(child);
311
      }
312
    }
313
  }else{
314
    pb.push_back(t);
315
  }
316
86875
}
317
318
40866
/**
 * Builds a single node of kind k whose children are the operands collected
 * from t by expanding all nested applications of k.
 */
static Node flatten(Kind k, TNode t)
{
  std::vector<TNode> operands;
  flatten(operands, k, t);
  Assert(operands.size() >= 2);
  NodeManager* nm = NodeManager::currentNM();
  return nm->mkNode(k, operands);
}
324
325
614422
/**
 * Pre-rewrites a PLUS term: a sum that has another sum as a direct child is
 * flattened into a single PLUS node; any other sum is returned unchanged.
 * The status is always REWRITE_DONE.
 */
RewriteResponse ArithRewriter::preRewritePlus(TNode t)
{
  Assert(t.getKind() == kind::PLUS);

  if (!canFlatten(kind::PLUS, t))
  {
    return RewriteResponse(REWRITE_DONE, t);
  }
  return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
}
334
335
1563950
/**
 * Post-rewrites a PLUS term into polynomial form.
 *
 * Children accepted by Monomial::isMember are collected as monomials,
 * sorted, combined, and turned into one polynomial. All other children are
 * parsed as polynomials. The sum of all collected polynomials is returned
 * with status REWRITE_DONE.
 */
RewriteResponse ArithRewriter::postRewritePlus(TNode t)
{
  Assert(t.getKind() == kind::PLUS);

  std::vector<Monomial> monos;
  std::vector<Polynomial> polys;

  for (TNode summand : t)
  {
    if (Monomial::isMember(summand))
    {
      monos.push_back(Monomial::parseMonomial(summand));
    }
    else
    {
      polys.push_back(Polynomial::parsePolynomial(summand));
    }
  }

  if (!monos.empty())
  {
    Monomial::sort(monos);
    Monomial::combineAdjacentMonomials(monos);
    polys.push_back(Polynomial::mkPolynomial(monos));
  }

  Polynomial total = Polynomial::sumPolynomials(polys);
  return RewriteResponse(REWRITE_DONE, total.getNode());
}
360
361
533862
/**
 * Post-rewrites a MULT / NONLINEAR_MULT term: every factor is parsed as a
 * polynomial and the factors are multiplied together, starting from the
 * polynomial one. The product is returned with status REWRITE_DONE.
 */
RewriteResponse ArithRewriter::postRewriteMult(TNode t)
{
  Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);

  Polynomial product = Polynomial::mkOne();
  for (const Node& factor : t)
  {
    Polynomial factorPoly = Polynomial::parsePolynomial(factor);
    product = product * factorPoly;
  }

  return RewriteResponse(REWRITE_DONE, product.getNode());
}
375
376
475
/**
 * Post-rewrites an IAND term ((_ iand k) x y).
 *
 * - Two constant arguments: rewritten to
 *   bv2nat(bvand(int2bv_k(x), int2bv_k(y))), status REWRITE_AGAIN_FULL.
 * - x > y in the node ordering: arguments are swapped, status REWRITE_AGAIN.
 * - x == y: rewritten to x.
 * - One argument is the constant 0: rewritten to that 0.
 * - Otherwise the term is returned unchanged.
 */
RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
{
  Assert(t.getKind() == kind::IAND);
  NodeManager* nm = NodeManager::currentNM();

  // if constant, we eliminate
  if (t[0].isConst() && t[1].isConst())
  {
    size_t width = t.getOperator().getConst<IntAnd>().d_size;
    Node toBvOp = nm->mkConst(IntToBitVector(width));
    Node lhsBv = nm->mkNode(kind::INT_TO_BITVECTOR, toBvOp, t[0]);
    Node rhsBv = nm->mkNode(kind::INT_TO_BITVECTOR, toBvOp, t[1]);
    Node conj = nm->mkNode(kind::BITVECTOR_AND, lhsBv, rhsBv);
    Node asNat = nm->mkNode(kind::BITVECTOR_TO_NAT, conj);
    return RewriteResponse(REWRITE_AGAIN_FULL, asNat);
  }

  if (t[0] > t[1])
  {
    // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
    Node swapped = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
    return RewriteResponse(REWRITE_AGAIN, swapped);
  }

  if (t[0] == t[1])
  {
    // ((_ iand k) x x) ---> x
    return RewriteResponse(REWRITE_DONE, t[0]);
  }

  // simplifications involving constants
  for (unsigned pos = 0; pos < 2; pos++)
  {
    if (t[pos].isConst() && t[pos].getConst<Rational>().sgn() == 0)
    {
      // ((_ iand k) 0 y) ---> 0
      return RewriteResponse(REWRITE_DONE, t[pos]);
    }
  }
  return RewriteResponse(REWRITE_DONE, t);
}
417
418
2025
/**
 * Pre-rewrite for transcendental function applications: the term is
 * returned unchanged with status REWRITE_DONE.
 */
RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t)
{
  RewriteResponse unchanged(REWRITE_DONE, t);
  return unchanged;
}
421
422
3961
/**
 * Post-rewrite for transcendental function applications.
 *
 * Rewrites performed:
 * - exp(c), c a non-negative integer constant other than 1:
 *   (pow (exp 1) c)
 * - exp(a + b + ...): exp(a) * exp(b) * ...
 * - sin(0): 0; sin(c) with c a negative constant: -sin(-c)
 * - sin(r*PI + rem) with |r| > 1: r is shifted by a multiple of 2 into
 *   [-1, 1] (sin has period 2*PI)
 * - sin(+-PI + rem): -sin(rem), or 0 when there is no remainder
 * - sin(r*PI) for r = +-1/2, +-1/6, +-5/6: the exact rational value
 * - cos(x): sin(PI/2 - x)
 * - tan, csc, sec, cot: expressed as quotients of sin and cos
 *
 * Any other term is returned unchanged with status REWRITE_DONE.
 *
 * @param t a transcendental function application
 * @return the rewritten node together with the rewrite status
 */
RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t)
{
  Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
  NodeManager* nm = NodeManager::currentNM();
  switch (t.getKind())
  {
    case kind::EXPONENTIAL:
    {
      if (t[0].getKind() == kind::CONST_RATIONAL)
      {
        Node one = nm->mkConst(Rational(1));
        if (t[0].getConst<Rational>().sgn() >= 0 && t[0].getType().isInteger()
            && t[0] != one)
        {
          // exp(c) ---> (pow (exp 1) c)
          return RewriteResponse(
              REWRITE_AGAIN,
              nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
        }
        else
        {
          return RewriteResponse(REWRITE_DONE, t);
        }
      }
      else if (t[0].getKind() == kind::PLUS)
      {
        std::vector<Node> product;
        for (const Node& tc : t[0])
        {
          product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
        }
        // We need to do a full rewrite here, since we can get exponentials of
        // constants, e.g. when we are rewriting exp(2 + x)
        return RewriteResponse(REWRITE_AGAIN_FULL,
                               nm->mkNode(kind::MULT, product));
      }
    }
    break;

    case kind::SINE:
      if (t[0].getKind() == kind::CONST_RATIONAL)
      {
        const Rational& rat = t[0].getConst<Rational>();
        if (rat.sgn() == 0)
        {
          // sin(0) ---> 0
          return RewriteResponse(REWRITE_DONE, nm->mkConst(Rational(0)));
        }
        else if (rat.sgn() == -1)
        {
          // sin(-c) ---> -sin(c)
          Node ret =
              nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, nm->mkConst(-rat)));
          return RewriteResponse(REWRITE_AGAIN_FULL, ret);
        }
      }
      else
      {
        // get the factor of PI in the argument
        Node pi_factor;
        Node pi;
        Node rem;
        std::map<Node, Node> msum;
        if (ArithMSum::getMonomialSum(t[0], msum))
        {
          pi = mkPi();
          std::map<Node, Node>::iterator itm = msum.find(pi);
          if (itm != msum.end())
          {
            if (itm->second.isNull())
            {
              // a null coefficient is treated as 1
              pi_factor = mkRationalNode(Rational(1));
            }
            else
            {
              pi_factor = itm->second;
            }
            msum.erase(pi);
            if (!msum.empty())
            {
              // everything in the argument besides the PI monomial
              rem = ArithMSum::mkNode(msum);
            }
          }
        }
        else
        {
          Assert(false);
        }

        // if there is a factor of PI
        if (!pi_factor.isNull())
        {
          Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
          Rational r = pi_factor.getConst<Rational>();
          Rational r_abs = r.abs();
          Rational rone = Rational(1);
          Rational rtwo = Rational(2);
          if (r_abs > rone)
          {
            // add/subtract 2*pi beyond scope
            // The number of full periods is floor((|r| + 1) / 2). It is
            // computed on the rational value directly, so that non-integral
            // factors such as 3/2 are handled: an INTS_DIVISION node must
            // not be built over non-integer operands.
            Rational periods = Rational(((r_abs + rone) / rtwo).floor());
            Node new_pi_factor;
            if (r.sgn() == 1)
            {
              new_pi_factor = mkRationalNode(r - rtwo * periods);
            }
            else
            {
              Assert(r.sgn() == -1);
              new_pi_factor = mkRationalNode(r + rtwo * periods);
            }
            Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
            if (!rem.isNull())
            {
              new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
            }
            // sin( 2*n*PI + x ) = sin( x )
            return RewriteResponse(REWRITE_AGAIN_FULL,
                                   nm->mkNode(kind::SINE, new_arg));
          }
          else if (r_abs == rone)
          {
            // sin( PI + x ) = -sin( x )
            if (rem.isNull())
            {
              return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
            }
            else
            {
              return RewriteResponse(
                  REWRITE_AGAIN_FULL,
                  nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem)));
            }
          }
          else if (rem.isNull())
          {
            // other rational cases based on Niven's theorem
            // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
            Integer one = Integer(1);
            Integer two = Integer(2);
            Integer six = Integer(6);
            if (r_abs.getDenominator() == two)
            {
              // sin( +-PI/2 ) = +-1
              Assert(r_abs.getNumerator() == one);
              return RewriteResponse(REWRITE_DONE,
                                     mkRationalNode(Rational(r.sgn())));
            }
            else if (r_abs.getDenominator() == six)
            {
              Integer five = Integer(5);
              if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
              {
                // sin( +-PI/6 ) = sin( +-5*PI/6 ) = +-1/2
                return RewriteResponse(
                    REWRITE_DONE,
                    mkRationalNode(Rational(r.sgn()) / Rational(2)));
              }
            }
          }
        }
      }
      break;

    case kind::COSINE:
    {
      // cos( x ) = sin( PI/2 - x )
      return RewriteResponse(
          REWRITE_AGAIN_FULL,
          nm->mkNode(kind::SINE,
                     nm->mkNode(kind::MINUS,
                                nm->mkNode(kind::MULT,
                                           nm->mkConst(Rational(1) / Rational(2)),
                                           mkPi()),
                                t[0])));
    }

    case kind::TANGENT:
    {
      // tan( x ) = sin( x ) / cos( x )
      return RewriteResponse(REWRITE_AGAIN_FULL,
                             nm->mkNode(kind::DIVISION,
                                        nm->mkNode(kind::SINE, t[0]),
                                        nm->mkNode(kind::COSINE, t[0])));
    }

    case kind::COSECANT:
    {
      // csc( x ) = 1 / sin( x )
      return RewriteResponse(REWRITE_AGAIN_FULL,
                             nm->mkNode(kind::DIVISION,
                                        mkRationalNode(Rational(1)),
                                        nm->mkNode(kind::SINE, t[0])));
    }

    case kind::SECANT:
    {
      // sec( x ) = 1 / cos( x )
      return RewriteResponse(REWRITE_AGAIN_FULL,
                             nm->mkNode(kind::DIVISION,
                                        mkRationalNode(Rational(1)),
                                        nm->mkNode(kind::COSINE, t[0])));
    }

    case kind::COTANGENT:
    {
      // cot( x ) = cos( x ) / sin( x )
      return RewriteResponse(REWRITE_AGAIN_FULL,
                             nm->mkNode(kind::DIVISION,
                                        nm->mkNode(kind::COSINE, t[0]),
                                        nm->mkNode(kind::SINE, t[0])));
    }

    default: break;
  }
  return RewriteResponse(REWRITE_DONE, t);
}
618
619
2304002
/**
 * Post-rewrites an arithmetic atom.
 *
 * - IS_INTEGER: evaluated for constant arguments, true for arguments of
 *   integer type, otherwise unchanged.
 * - DIVISIBLE by k: evaluated for constant arguments, true when k is 1,
 *   otherwise rewritten to (= (mod_total x k) 0) with status REWRITE_AGAIN.
 * - Any other atom: both sides are parsed as polynomials and the result of
 *   Comparison::mkComparison is returned.
 */
RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
{
  if (atom.getKind() == kind::IS_INTEGER)
  {
    if (atom[0].isConst())
    {
      const bool integral = atom[0].getConst<Rational>().isIntegral();
      return RewriteResponse(REWRITE_DONE,
                             NodeManager::currentNM()->mkConst(integral));
    }
    if (atom[0].getType().isInteger())
    {
      return RewriteResponse(REWRITE_DONE,
                             NodeManager::currentNM()->mkConst(true));
    }
    // not supported, but this isn't the right place to complain
    return RewriteResponse(REWRITE_DONE, atom);
  }

  if (atom.getKind() == kind::DIVISIBLE)
  {
    if (atom[0].isConst())
    {
      const bool divides = bool((atom[0].getConst<Rational>()
                                 / atom.getOperator().getConst<Divisible>().k)
                                    .isIntegral());
      return RewriteResponse(REWRITE_DONE,
                             NodeManager::currentNM()->mkConst(divides));
    }
    if (atom.getOperator().getConst<Divisible>().k.isOne())
    {
      return RewriteResponse(REWRITE_DONE,
                             NodeManager::currentNM()->mkConst(true));
    }
    // (divisible k x) ---> (= (mod_total x k) 0)
    Node divisor = NodeManager::currentNM()->mkConst(
        Rational(atom.getOperator().getConst<Divisible>().k));
    Node remainder = NodeManager::currentNM()->mkNode(
        kind::INTS_MODULUS_TOTAL, atom[0], divisor);
    Node zeroNode = NodeManager::currentNM()->mkConst(Rational(0));
    return RewriteResponse(
        REWRITE_AGAIN,
        NodeManager::currentNM()->mkNode(kind::EQUAL, remainder, zeroNode));
  }

  // left |><| right
  TNode lhs = atom[0];
  TNode rhs = atom[1];

  Polynomial lhsPoly = Polynomial::parsePolynomial(lhs);
  Polynomial rhsPoly = Polynomial::parsePolynomial(rhs);

  Debug("arith::rewriter") << "pleft " << lhsPoly.getNode() << std::endl;
  Debug("arith::rewriter") << "pright " << rhsPoly.getNode() << std::endl;

  Comparison normalized =
      Comparison::mkComparison(atom.getKind(), lhsPoly, rhsPoly);
  Assert(normalized.isNormalForm());
  return RewriteResponse(REWRITE_DONE, normalized.getNode());
}
653
654
1551690
/**
 * Pre-rewrites an arithmetic atom.
 *
 * - (= x x) ---> true
 * - (> a b) ---> (not (<= a b))
 * - (< a b) ---> (not (>= a b))
 * - IS_INTEGER on an argument of integer type ---> true
 * - DIVISIBLE by 1 ---> true
 *
 * Any other atom is returned unchanged. The status is always REWRITE_DONE.
 */
RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
{
  Assert(isAtom(atom));

  NodeManager* nm = NodeManager::currentNM();

  switch (atom.getKind())
  {
    case kind::EQUAL:
      if (atom[0] == atom[1])
      {
        return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
      }
      break;

    case kind::GT:
    {
      Node leq = nm->mkNode(kind::LEQ, atom[0], atom[1]);
      return RewriteResponse(REWRITE_DONE, nm->mkNode(kind::NOT, leq));
    }

    case kind::LT:
    {
      Node geq = nm->mkNode(kind::GEQ, atom[0], atom[1]);
      return RewriteResponse(REWRITE_DONE, nm->mkNode(kind::NOT, geq));
    }

    case kind::IS_INTEGER:
      if (atom[0].getType().isInteger())
      {
        return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
      }
      break;

    case kind::DIVISIBLE:
      if (atom.getOperator().getConst<Divisible>().k.isOne())
      {
        return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
      }
      break;

    default: break;
  }

  return RewriteResponse(REWRITE_DONE, atom);
}
681
682
4899713
/**
 * Post-rewrite entry point: terms go to postRewriteTerm, atoms to
 * postRewriteAtom. When the "arith::rewriter" debug tag is on and the
 * rewrite is reported as done, the result is additionally parsed back as a
 * Polynomial (terms) or a normal-form Comparison (atoms). A node that is
 * neither a term nor an atom ends in Unreachable().
 */
RewriteResponse ArithRewriter::postRewrite(TNode t)
{
  if (isTerm(t))
  {
    RewriteResponse termResult = postRewriteTerm(t);
    const bool recheck =
        Debug.isOn("arith::rewriter") && termResult.d_status == REWRITE_DONE;
    if (recheck)
    {
      Polynomial::parsePolynomial(termResult.d_node);
    }
    return termResult;
  }

  if (isAtom(t))
  {
    RewriteResponse atomResult = postRewriteAtom(t);
    const bool recheck =
        Debug.isOn("arith::rewriter") && atomResult.d_status == REWRITE_DONE;
    if (recheck)
    {
      Comparison::parseNormalForm(atomResult.d_node);
    }
    return atomResult;
  }

  Unreachable();
}
701
702
2958848
/**
 * Pre-rewrite entry point: terms go to preRewriteTerm, atoms to
 * preRewriteAtom. A node that is neither ends in Unreachable().
 */
RewriteResponse ArithRewriter::preRewrite(TNode t)
{
  if (isTerm(t))
  {
    return preRewriteTerm(t);
  }
  if (isAtom(t))
  {
    return preRewriteAtom(t);
  }
  Unreachable();
}
711
712
268775
/** Returns the node (* -1 n), which represents the negation of n. */
Node ArithRewriter::makeUnaryMinusNode(TNode n)
{
  Node minusOne = mkRationalNode(Rational(-1));
  return NodeManager::currentNM()->mkNode(kind::MULT, minusOne, n);
}
716
717
260544
/** Returns the node (+ l (* -1 r)), which represents l - r. */
Node ArithRewriter::makeSubtractionNode(TNode l, TNode r)
{
  Node negatedRight = makeUnaryMinusNode(r);
  return NodeManager::currentNM()->mkNode(kind::PLUS, l, negatedRight);
}
723
724
3994
/**
 * Rewrites a DIVISION / DIVISION_TOTAL term.
 *
 * - Non-constant denominator: unchanged.
 * - Denominator 0: DIVISION_TOTAL becomes 0, DIVISION is left unchanged.
 * - Constant numerator and denominator: the quotient is evaluated.
 * - Otherwise (/ x c) becomes (* x 1/c), with status REWRITE_DONE during
 *   pre-rewrite and REWRITE_AGAIN during post-rewrite.
 *
 * @param t a node of kind DIVISION or DIVISION_TOTAL
 * @param pre true when called from pre-rewrite, false from post-rewrite
 */
RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
{
  Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);

  Node numerator = t[0];
  Node denominator = t[1];

  if (denominator.getKind() != kind::CONST_RATIONAL)
  {
    return RewriteResponse(REWRITE_DONE, t);
  }

  const Rational& den = denominator.getConst<Rational>();

  if (den.isZero())
  {
    if (t.getKind() == kind::DIVISION_TOTAL)
    {
      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
    }
    // This is unsupported, but this is not a good place to complain
    return RewriteResponse(REWRITE_DONE, t);
  }
  Assert(den != Rational(0));

  if (numerator.getKind() == kind::CONST_RATIONAL)
  {
    // constant / constant is evaluated
    const Rational& num = numerator.getConst<Rational>();
    Rational quotient = num / den;
    Node quotientNode = mkRationalNode(quotient);
    return RewriteResponse(REWRITE_DONE, quotientNode);
  }

  // (/ x c) ---> (* x 1/c)
  Rational reciprocal = den.inverse();
  Node reciprocalNode = mkRationalNode(reciprocal);
  Node scaled = NodeManager::currentNM()->mkNode(
      kind::MULT, numerator, reciprocalNode);
  return RewriteResponse(pre ? REWRITE_DONE : REWRITE_AGAIN, scaled);
}
763
764
1757
/**
 * Rewrites an INTS_DIVISION / INTS_MODULUS term.
 *
 * When the second argument is a non-zero constant, the term is replaced by
 * its total counterpart (INTS_DIVISION_TOTAL / INTS_MODULUS_TOTAL) through
 * returnRewrite. Any other term is returned unchanged with status
 * REWRITE_DONE.
 *
 * @param t a node of kind INTS_DIVISION or INTS_MODULUS
 * @param pre whether this is called from pre-rewrite; currently unused, the
 *            same rewrites apply in both phases
 */
RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
{
  NodeManager* nm = NodeManager::currentNM();
  Kind k = t.getKind();
  if (k == kind::INTS_MODULUS)
  {
    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
    {
      // can immediately replace by INTS_MODULUS_TOTAL
      Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
      return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
    }
  }
  if (k == kind::INTS_DIVISION)
  {
    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
    {
      // can immediately replace by INTS_DIVISION_TOTAL
      Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
      return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
    }
  }
  return RewriteResponse(REWRITE_DONE, t);
}
789
790
12597
/**
 * Rewrites an INTS_DIVISION_TOTAL / INTS_MODULUS_TOTAL term.
 *
 * Nothing is done during pre-rewrite. During post-rewrite:
 * - (div x 0) and (mod x 0) become 0
 * - (mod x 1) becomes 0, (div x 1) becomes x
 * - a negative constant denominator is negated, with the negation pulled
 *   out of a division
 * - constant arguments are evaluated with Euclidean division
 * - (mod (mod x c) c) becomes (mod x c), and inner (mod y c) children of a
 *   PLUS / MULT / NONLINEAR_MULT under (mod . c) are replaced by y
 * - (div (mod x c) c) becomes 0
 *
 * Every applied rewrite is returned through returnRewrite; a term to which
 * none applies is returned unchanged with status REWRITE_DONE.
 *
 * @param t a node of kind INTS_DIVISION_TOTAL or INTS_MODULUS_TOTAL
 * @param pre true when called from pre-rewrite, false from post-rewrite
 */
RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
{
  if (pre)
  {
    // do not rewrite at prewrite.
    return RewriteResponse(REWRITE_DONE, t);
  }
  NodeManager* nm = NodeManager::currentNM();
  Kind k = t.getKind();
  Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
  TNode n = t[0];
  TNode d = t[1];

  if (d.getKind() == kind::CONST_RATIONAL)
  {
    const Rational& dValue = d.getConst<Rational>();
    if (dValue.isZero())
    {
      // (div x 0) ---> 0 or (mod x 0) ---> 0
      return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
    }
    if (dValue.isOne())
    {
      if (k == kind::INTS_MODULUS_TOTAL)
      {
        // (mod x 1) --> 0
        return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE);
      }
      Assert(k == kind::INTS_DIVISION_TOTAL);
      // (div x 1) --> x
      return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
    }
    if (dValue.sgn() < 0)
    {
      // pull negation
      // (div x (- c)) ---> (- (div x c))
      // (mod x (- c)) ---> (mod x c)
      Node positiveDen =
          nm->mkNode(k, t[0], nm->mkConst(-t[1].getConst<Rational>()));
      const bool isDivision =
          (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
      Node ret =
          isDivision ? nm->mkNode(kind::UMINUS, positiveDen) : positiveDen;
      return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
    }
    if (n.getKind() == kind::CONST_RATIONAL)
    {
      Assert(d.getConst<Rational>().isIntegral());
      Assert(n.getConst<Rational>().isIntegral());
      Assert(!d.getConst<Rational>().isZero());
      Integer denInt = d.getConst<Rational>().getNumerator();
      Integer numInt = n.getConst<Rational>().getNumerator();

      const bool isDivision =
          (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);

      Integer value = isDivision ? numInt.euclidianDivideQuotient(denInt)
                                 : numInt.euclidianDivideRemainder(denInt);

      // constant evaluation
      // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
      Node valueNode = mkRationalNode(Rational(value));
      return returnRewrite(t, valueNode, Rewrite::CONST_EVAL);
    }
  }

  if (k == kind::INTS_MODULUS_TOTAL)
  {
    // Note these rewrites do not need to account for modulus by zero as being
    // a UF, which is handled by the reduction of INTS_MODULUS.
    Kind innerKind = t[0].getKind();
    if (innerKind == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
    {
      // (mod (mod x c) c) --> (mod x c)
      return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
    }
    else if (innerKind == kind::NONLINEAR_MULT || innerKind == kind::MULT
             || innerKind == kind::PLUS)
    {
      // can drop all
      std::vector<Node> strippedChildren;
      bool anyStripped = false;
      for (const Node& child : t[0])
      {
        if (child.getKind() == kind::INTS_MODULUS_TOTAL && child[1] == t[1])
        {
          strippedChildren.push_back(child[0]);
          anyStripped = true;
        }
        else
        {
          strippedChildren.push_back(child);
        }
      }
      if (anyStripped)
      {
        // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
        // op is one of { NONLINEAR_MULT, MULT, PLUS }.
        Node ret = nm->mkNode(innerKind, strippedChildren);
        ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
        return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
      }
    }
  }
  else
  {
    Assert(k == kind::INTS_DIVISION_TOTAL);
    // Note these rewrites do not need to account for division by zero as being
    // a UF, which is handled by the reduction of INTS_DIVISION.
    if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
    {
      // (div (mod x c) c) --> 0
      Node ret = mkRationalNode(0);
      return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
    }
  }
  return RewriteResponse(REWRITE_DONE, t);
}
893
894
2690
/**
 * Traces the rewrite of t to ret by rule r under the "arith-rewrite" tag
 * and returns ret with status REWRITE_AGAIN_FULL.
 */
RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
{
  Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
                         << r << std::endl;
  RewriteResponse again(REWRITE_AGAIN_FULL, ret);
  return again;
}
900
901
}/* CVC4::theory::arith namespace */
902
}/* CVC4::theory namespace */
903
26676
}/* CVC4 namespace */