GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_rewriter.cpp Lines: 443 503 88.1 %
Date: 2021-05-22 Branches: 1023 2443 41.9 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Tim King, Morgan Deters
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * [[ Add one-line brief description here ]]
14
 *
15
 * [[ Add lengthier description here ]]
16
 * \todo document this file
17
 */
18
19
#include <set>
20
#include <sstream>
21
#include <stack>
22
#include <vector>
23
24
#include "smt/logic_exception.h"
25
#include "theory/arith/arith_msum.h"
26
#include "theory/arith/arith_rewriter.h"
27
#include "theory/arith/arith_utilities.h"
28
#include "theory/arith/normal_form.h"
29
#include "theory/arith/operator_elim.h"
30
#include "theory/theory.h"
31
#include "util/iand.h"
32
33
namespace cvc5 {
34
namespace theory {
35
namespace arith {
36
37
9459
ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
38
39
12719964
bool ArithRewriter::isAtom(TNode n) {
40
12719964
  Kind k = n.getKind();
41
16476311
  return arith::isRelationOperator(k) || k == kind::IS_INTEGER
42
16476113
      || k == kind::DIVISIBLE;
43
}
44
45
563146
RewriteResponse ArithRewriter::rewriteConstant(TNode t){
46
563146
  Assert(t.isConst());
47
563146
  Assert(t.getKind() == kind::CONST_RATIONAL);
48
49
563146
  return RewriteResponse(REWRITE_DONE, t);
50
}
51
52
RewriteResponse ArithRewriter::rewriteVariable(TNode t){
53
  Assert(t.isVar());
54
55
  return RewriteResponse(REWRITE_DONE, t);
56
}
57
58
235815
RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
59
235815
  Assert(t.getKind() == kind::MINUS);
60
61
235815
  if(pre){
62
235813
    if(t[0] == t[1]){
63
4672
      Rational zero(0);
64
4672
      Node zeroNode  = mkRationalNode(zero);
65
2336
      return RewriteResponse(REWRITE_DONE, zeroNode);
66
    }else{
67
466954
      Node noMinus = makeSubtractionNode(t[0],t[1]);
68
233477
      return RewriteResponse(REWRITE_DONE, noMinus);
69
    }
70
  }else{
71
4
    Polynomial minuend = Polynomial::parsePolynomial(t[0]);
72
4
    Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
73
4
    Polynomial diff = minuend - subtrahend;
74
2
    return RewriteResponse(REWRITE_DONE, diff.getNode());
75
  }
76
}
77
78
10780
RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
79
10780
  Assert(t.getKind() == kind::UMINUS);
80
81
10780
  if(t[0].getKind() == kind::CONST_RATIONAL){
82
12334
    Rational neg = -(t[0].getConst<Rational>());
83
6167
    return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
84
  }
85
86
9226
  Node noUminus = makeUnaryMinusNode(t[0]);
87
4613
  if(pre)
88
4609
    return RewriteResponse(REWRITE_DONE, noUminus);
89
  else
90
4
    return RewriteResponse(REWRITE_AGAIN, noUminus);
91
}
92
93
1322375
RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
94
1322375
  if(t.isConst()){
95
120257
    return rewriteConstant(t);
96
1202118
  }else if(t.isVar()){
97
    return rewriteVariable(t);
98
  }else{
99
1202118
    switch(Kind k = t.getKind()){
100
235813
    case kind::MINUS:
101
235813
      return rewriteMinus(t, true);
102
10772
    case kind::UMINUS:
103
10772
      return rewriteUMinus(t, true);
104
2771
    case kind::DIVISION:
105
    case kind::DIVISION_TOTAL:
106
2771
      return rewriteDiv(t,true);
107
573955
    case kind::PLUS:
108
573955
      return preRewritePlus(t);
109
365502
    case kind::MULT:
110
365502
    case kind::NONLINEAR_MULT: return preRewriteMult(t);
111
310
    case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
112
2051
    case kind::EXPONENTIAL:
113
    case kind::SINE:
114
    case kind::COSINE:
115
    case kind::TANGENT:
116
    case kind::COSECANT:
117
    case kind::SECANT:
118
    case kind::COTANGENT:
119
    case kind::ARCSINE:
120
    case kind::ARCCOSINE:
121
    case kind::ARCTANGENT:
122
    case kind::ARCCOSECANT:
123
    case kind::ARCSECANT:
124
    case kind::ARCCOTANGENT:
125
2051
    case kind::SQRT: return preRewriteTranscendental(t);
126
776
    case kind::INTS_DIVISION:
127
776
    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
128
5067
    case kind::INTS_DIVISION_TOTAL:
129
    case kind::INTS_MODULUS_TOTAL:
130
5067
      return rewriteIntsDivModTotal(t,true);
131
30
    case kind::ABS:
132
30
      if(t[0].isConst()) {
133
14
        const Rational& rat = t[0].getConst<Rational>();
134
14
        if(rat >= 0) {
135
14
          return RewriteResponse(REWRITE_DONE, t[0]);
136
        } else {
137
          return RewriteResponse(REWRITE_DONE,
138
                                 NodeManager::currentNM()->mkConst(-rat));
139
        }
140
16
      }
141
16
      return RewriteResponse(REWRITE_DONE, t);
142
534
    case kind::IS_INTEGER:
143
    case kind::TO_INTEGER:
144
534
      return RewriteResponse(REWRITE_DONE, t);
145
4356
    case kind::TO_REAL:
146
4356
    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
147
96
    case kind::POW:
148
96
      return RewriteResponse(REWRITE_DONE, t);
149
85
    case kind::PI:
150
85
      return RewriteResponse(REWRITE_DONE, t);
151
    default: Unhandled() << k;
152
    }
153
  }
154
}
155
156
2433774
RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
157
2433774
  if(t.isConst()){
158
442889
    return rewriteConstant(t);
159
1990885
  }else if(t.isVar()){
160
    return rewriteVariable(t);
161
  }else{
162
1990885
    switch(t.getKind()){
163
2
    case kind::MINUS:
164
2
      return rewriteMinus(t, false);
165
8
    case kind::UMINUS:
166
8
      return rewriteUMinus(t, false);
167
1185
    case kind::DIVISION:
168
    case kind::DIVISION_TOTAL:
169
1185
      return rewriteDiv(t, false);
170
1462508
    case kind::PLUS:
171
1462508
      return postRewritePlus(t);
172
512479
    case kind::MULT:
173
512479
    case kind::NONLINEAR_MULT: return postRewriteMult(t);
174
537
    case kind::IAND: return postRewriteIAnd(t);
175
3930
    case kind::EXPONENTIAL:
176
    case kind::SINE:
177
    case kind::COSINE:
178
    case kind::TANGENT:
179
    case kind::COSECANT:
180
    case kind::SECANT:
181
    case kind::COTANGENT:
182
    case kind::ARCSINE:
183
    case kind::ARCCOSINE:
184
    case kind::ARCTANGENT:
185
    case kind::ARCCOSECANT:
186
    case kind::ARCSECANT:
187
    case kind::ARCCOTANGENT:
188
3930
    case kind::SQRT: return postRewriteTranscendental(t);
189
1101
    case kind::INTS_DIVISION:
190
1101
    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
191
7691
    case kind::INTS_DIVISION_TOTAL:
192
    case kind::INTS_MODULUS_TOTAL:
193
7691
      return rewriteIntsDivModTotal(t, false);
194
38
    case kind::ABS:
195
38
      if(t[0].isConst()) {
196
        const Rational& rat = t[0].getConst<Rational>();
197
        if(rat >= 0) {
198
          return RewriteResponse(REWRITE_DONE, t[0]);
199
        } else {
200
          return RewriteResponse(REWRITE_DONE,
201
                                 NodeManager::currentNM()->mkConst(-rat));
202
        }
203
38
      }
204
38
      return RewriteResponse(REWRITE_DONE, t);
205
    case kind::TO_REAL:
206
    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
207
1013
    case kind::TO_INTEGER:
208
1013
      if(t[0].isConst()) {
209
66
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
210
      }
211
947
      if(t[0].getType().isInteger()) {
212
2
        return RewriteResponse(REWRITE_DONE, t[0]);
213
      }
214
      //Unimplemented() << "TO_INTEGER, nonconstant";
215
      //return rewriteToInteger(t);
216
945
      return RewriteResponse(REWRITE_DONE, t);
217
    case kind::IS_INTEGER:
218
      if(t[0].isConst()) {
219
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
220
      }
221
      if(t[0].getType().isInteger()) {
222
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
223
      }
224
      //Unimplemented() << "IS_INTEGER, nonconstant";
225
      //return rewriteIsInteger(t);
226
      return RewriteResponse(REWRITE_DONE, t);
227
158
    case kind::POW:
228
      {
229
158
        if(t[1].getKind() == kind::CONST_RATIONAL){
230
157
          const Rational& exp = t[1].getConst<Rational>();
231
158
          TNode base = t[0];
232
157
          if(exp.sgn() == 0){
233
52
            return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
234
105
          }else if(exp.sgn() > 0 && exp.isIntegral()){
235
106
            cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
236
105
            if (exp <= r)
237
            {
238
104
              unsigned num = exp.getNumerator().toUnsignedInt();
239
104
              if( num==1 ){
240
4
                return RewriteResponse(REWRITE_AGAIN, base);
241
              }else{
242
200
                NodeBuilder nb(kind::MULT);
243
500
                for(unsigned i=0; i < num; ++i){
244
400
                  nb << base;
245
                }
246
100
                Assert(nb.getNumChildren() > 0);
247
200
                Node mult = nb;
248
100
                return RewriteResponse(REWRITE_AGAIN, mult);
249
              }
250
            }
251
          }
252
        }
253
254
        // Todo improve the exception thrown
255
4
        std::stringstream ss;
256
        ss << "The exponent of the POW(^) operator can only be a positive "
257
2
              "integral constant below "
258
2
           << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
259
2
        ss << "Exception occurred in:" << std::endl;
260
2
        ss << "  " << t;
261
2
        throw LogicException(ss.str());
262
      }
263
235
    case kind::PI:
264
235
      return RewriteResponse(REWRITE_DONE, t);
265
    default:
266
      Unreachable();
267
    }
268
  }
269
}
270
271
272
365502
RewriteResponse ArithRewriter::preRewriteMult(TNode t){
273
365502
  Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
274
275
365502
  if(t.getNumChildren() == 2){
276
1081935
    if(t[0].getKind() == kind::CONST_RATIONAL
277
1081935
       && t[0].getConst<Rational>().isOne()){
278
32266
      return RewriteResponse(REWRITE_DONE, t[1]);
279
    }
280
985137
    if(t[1].getKind() == kind::CONST_RATIONAL
281
985137
       && t[1].getConst<Rational>().isOne()){
282
5649
      return RewriteResponse(REWRITE_DONE, t[0]);
283
    }
284
  }
285
286
  // Rewrite multiplications with a 0 argument and to 0
287
989651
  for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
288
675103
    if((*i).getKind() == kind::CONST_RATIONAL) {
289
376170
      if((*i).getConst<Rational>().isZero()) {
290
26078
        TNode zero = (*i);
291
13039
        return RewriteResponse(REWRITE_DONE, zero);
292
      }
293
    }
294
  }
295
314548
  return RewriteResponse(REWRITE_DONE, t);
296
}
297
298
573955
static bool canFlatten(Kind k, TNode t){
299
2080683
  for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
300
3052428
    TNode child = *i;
301
1545700
    if(child.getKind() == k){
302
38972
      return true;
303
    }
304
  }
305
534983
  return false;
306
}
307
308
82906
static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
309
82906
  if(t.getKind() == k){
310
260759
    for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
311
355706
      TNode child = *i;
312
177853
      if(child.getKind() == k){
313
43934
        flatten(pb, k, child);
314
      }else{
315
133919
        pb.push_back(child);
316
      }
317
    }
318
  }else{
319
    pb.push_back(t);
320
  }
321
82906
}
322
323
38972
static Node flatten(Kind k, TNode t){
324
77944
  std::vector<TNode> pb;
325
38972
  flatten(pb, k, t);
326
38972
  Assert(pb.size() >= 2);
327
77944
  return NodeManager::currentNM()->mkNode(k, pb);
328
}
329
330
573955
RewriteResponse ArithRewriter::preRewritePlus(TNode t){
331
573955
  Assert(t.getKind() == kind::PLUS);
332
333
573955
  if(canFlatten(kind::PLUS, t)){
334
38972
    return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
335
  }else{
336
534983
    return RewriteResponse(REWRITE_DONE, t);
337
  }
338
}
339
340
1462508
RewriteResponse ArithRewriter::postRewritePlus(TNode t){
341
1462508
  Assert(t.getKind() == kind::PLUS);
342
343
2925016
  std::vector<Monomial> monomials;
344
2925016
  std::vector<Polynomial> polynomials;
345
346
5248190
  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
347
7571364
    TNode curr = *i;
348
3785682
    if(Monomial::isMember(curr)){
349
3544859
      monomials.push_back(Monomial::parseMonomial(curr));
350
    }else{
351
240823
      polynomials.push_back(Polynomial::parsePolynomial(curr));
352
    }
353
  }
354
355
1462508
  if(!monomials.empty()){
356
1443745
    Monomial::sort(monomials);
357
1443745
    Monomial::combineAdjacentMonomials(monomials);
358
1443745
    polynomials.push_back(Polynomial::mkPolynomial(monomials));
359
  }
360
361
2925016
  Polynomial res = Polynomial::sumPolynomials(polynomials);
362
363
2925016
  return RewriteResponse(REWRITE_DONE, res.getNode());
364
}
365
366
512479
RewriteResponse ArithRewriter::postRewriteMult(TNode t){
367
512479
  Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
368
369
1024958
  Polynomial res = Polynomial::mkOne();
370
371
1567674
  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
372
2110390
    Node curr = *i;
373
2110390
    Polynomial currPoly = Polynomial::parsePolynomial(curr);
374
375
1055195
    res = res * currPoly;
376
  }
377
378
1024958
  return RewriteResponse(REWRITE_DONE, res.getNode());
379
}
380
381
537
RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
382
{
383
537
  Assert(t.getKind() == kind::IAND);
384
537
  NodeManager* nm = NodeManager::currentNM();
385
  // if constant, we eliminate
386
537
  if (t[0].isConst() && t[1].isConst())
387
  {
388
164
    size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
389
328
    Node iToBvop = nm->mkConst(IntToBitVector(bsize));
390
328
    Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
391
328
    Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
392
328
    Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
393
328
    Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
394
164
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
395
  }
396
373
  else if (t[0] > t[1])
397
  {
398
    // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
399
60
    Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
400
30
    return RewriteResponse(REWRITE_AGAIN, ret);
401
  }
402
343
  else if (t[0] == t[1])
403
  {
404
    // ((_ iand k) x x) ---> x
405
6
    return RewriteResponse(REWRITE_DONE, t[0]);
406
  }
407
  // simplifications involving constants
408
1005
  for (unsigned i = 0; i < 2; i++)
409
  {
410
671
    if (!t[i].isConst())
411
    {
412
642
      continue;
413
    }
414
29
    if (t[i].getConst<Rational>().sgn() == 0)
415
    {
416
      // ((_ iand k) 0 y) ---> 0
417
3
      return RewriteResponse(REWRITE_DONE, t[i]);
418
    }
419
  }
420
334
  return RewriteResponse(REWRITE_DONE, t);
421
}
422
423
2051
RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
424
2051
  return RewriteResponse(REWRITE_DONE, t);
425
}
426
427
3930
RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
428
3930
  Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
429
3930
  NodeManager* nm = NodeManager::currentNM();
430
3930
  switch( t.getKind() ){
431
458
  case kind::EXPONENTIAL: {
432
458
    if(t[0].getKind() == kind::CONST_RATIONAL){
433
572
      Node one = nm->mkConst(Rational(1));
434
286
      if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
435
        return RewriteResponse(
436
            REWRITE_AGAIN,
437
31
            nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
438
      }else{
439
255
        return RewriteResponse(REWRITE_DONE, t);
440
      }
441
    }
442
172
    else if (t[0].getKind() == kind::PLUS)
443
    {
444
48
      std::vector<Node> product;
445
72
      for (const Node tc : t[0])
446
      {
447
48
        product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
448
      }
449
      // We need to do a full rewrite here, since we can get exponentials of
450
      // constants, e.g. when we are rewriting exp(2 + x)
451
      return RewriteResponse(REWRITE_AGAIN_FULL,
452
24
                             nm->mkNode(kind::MULT, product));
453
148
    }
454
  }
455
148
    break;
456
3258
  case kind::SINE:
457
3258
    if(t[0].getKind() == kind::CONST_RATIONAL){
458
1762
      const Rational& rat = t[0].getConst<Rational>();
459
1762
      if(rat.sgn() == 0){
460
50
        return RewriteResponse(REWRITE_DONE, nm->mkConst(Rational(0)));
461
      }
462
1712
      else if (rat.sgn() == -1)
463
      {
464
        Node ret =
465
1122
            nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, nm->mkConst(-rat)));
466
561
        return RewriteResponse(REWRITE_AGAIN_FULL, ret);
467
      }
468
    }else{
469
      // get the factor of PI in the argument
470
2940
      Node pi_factor;
471
2940
      Node pi;
472
2940
      Node rem;
473
2940
      std::map<Node, Node> msum;
474
1496
      if (ArithMSum::getMonomialSum(t[0], msum))
475
      {
476
1496
        pi = mkPi();
477
1496
        std::map<Node, Node>::iterator itm = msum.find(pi);
478
1496
        if (itm != msum.end())
479
        {
480
400
          if (itm->second.isNull())
481
          {
482
            pi_factor = mkRationalNode(Rational(1));
483
          }
484
          else
485
          {
486
400
            pi_factor = itm->second;
487
          }
488
400
          msum.erase(pi);
489
400
          if (!msum.empty())
490
          {
491
348
            rem = ArithMSum::mkNode(msum);
492
          }
493
        }
494
      }
495
      else
496
      {
497
        Assert(false);
498
      }
499
500
      // if there is a factor of PI
501
1496
      if( !pi_factor.isNull() ){
502
400
        Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
503
748
        Rational r = pi_factor.getConst<Rational>();
504
748
        Rational r_abs = r.abs();
505
748
        Rational rone = Rational(1);
506
748
        Node ntwo = mkRationalNode(Rational(2));
507
400
        if (r_abs > rone)
508
        {
509
          //add/substract 2*pi beyond scope
510
          Node ra_div_two = nm->mkNode(
511
              kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo);
512
          Node new_pi_factor;
513
          if( r.sgn()==1 ){
514
            new_pi_factor =
515
                nm->mkNode(kind::MINUS,
516
                           pi_factor,
517
                           nm->mkNode(kind::MULT, ntwo, ra_div_two));
518
          }else{
519
            Assert(r.sgn() == -1);
520
            new_pi_factor =
521
                nm->mkNode(kind::PLUS,
522
                           pi_factor,
523
                           nm->mkNode(kind::MULT, ntwo, ra_div_two));
524
          }
525
          Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
526
          if (!rem.isNull())
527
          {
528
            new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
529
          }
530
          // sin( 2*n*PI + x ) = sin( x )
531
          return RewriteResponse(REWRITE_AGAIN_FULL,
532
                                 nm->mkNode(kind::SINE, new_arg));
533
        }
534
400
        else if (r_abs == rone)
535
        {
536
          // sin( PI + x ) = -sin( x )
537
          if (rem.isNull())
538
          {
539
            return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
540
          }
541
          else
542
          {
543
            return RewriteResponse(
544
                REWRITE_AGAIN_FULL,
545
                nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem)));
546
          }
547
        }
548
400
        else if (rem.isNull())
549
        {
550
          // other rational cases based on Niven's theorem
551
          // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
552
52
          Integer one = Integer(1);
553
52
          Integer two = Integer(2);
554
52
          Integer six = Integer(6);
555
52
          if (r_abs.getDenominator() == two)
556
          {
557
52
            Assert(r_abs.getNumerator() == one);
558
            return RewriteResponse(REWRITE_DONE,
559
52
                                   mkRationalNode(Rational(r.sgn())));
560
          }
561
          else if (r_abs.getDenominator() == six)
562
          {
563
            Integer five = Integer(5);
564
            if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
565
            {
566
              return RewriteResponse(
567
                  REWRITE_DONE,
568
                  mkRationalNode(Rational(r.sgn()) / Rational(2)));
569
            }
570
          }
571
        }
572
      }
573
    }
574
2595
    break;
575
86
  case kind::COSINE: {
576
    return RewriteResponse(
577
        REWRITE_AGAIN_FULL,
578
172
        nm->mkNode(kind::SINE,
579
344
                   nm->mkNode(kind::MINUS,
580
344
                              nm->mkNode(kind::MULT,
581
172
                                         nm->mkConst(Rational(1) / Rational(2)),
582
172
                                         mkPi()),
583
86
                              t[0])));
584
  }
585
  break;
586
16
  case kind::TANGENT:
587
  {
588
    return RewriteResponse(REWRITE_AGAIN_FULL,
589
64
                           nm->mkNode(kind::DIVISION,
590
32
                                      nm->mkNode(kind::SINE, t[0]),
591
48
                                      nm->mkNode(kind::COSINE, t[0])));
592
  }
593
  break;
594
4
  case kind::COSECANT:
595
  {
596
    return RewriteResponse(REWRITE_AGAIN_FULL,
597
16
                           nm->mkNode(kind::DIVISION,
598
8
                                      mkRationalNode(Rational(1)),
599
12
                                      nm->mkNode(kind::SINE, t[0])));
600
  }
601
  break;
602
4
  case kind::SECANT:
603
  {
604
    return RewriteResponse(REWRITE_AGAIN_FULL,
605
16
                           nm->mkNode(kind::DIVISION,
606
8
                                      mkRationalNode(Rational(1)),
607
12
                                      nm->mkNode(kind::COSINE, t[0])));
608
  }
609
  break;
610
8
  case kind::COTANGENT:
611
  {
612
    return RewriteResponse(REWRITE_AGAIN_FULL,
613
32
                           nm->mkNode(kind::DIVISION,
614
16
                                      nm->mkNode(kind::COSINE, t[0]),
615
24
                                      nm->mkNode(kind::SINE, t[0])));
616
  }
617
  break;
618
96
  default:
619
96
    break;
620
  }
621
2839
  return RewriteResponse(REWRITE_DONE, t);
622
}
623
624
2237423
RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
625
2237423
  if(atom.getKind() == kind::IS_INTEGER) {
626
48
    if(atom[0].isConst()) {
627
6
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
628
    }
629
42
    if(atom[0].getType().isInteger()) {
630
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
631
    }
632
    // not supported, but this isn't the right place to complain
633
42
    return RewriteResponse(REWRITE_DONE, atom);
634
2237375
  } else if(atom.getKind() == kind::DIVISIBLE) {
635
    if(atom[0].isConst()) {
636
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
637
    }
638
    if(atom.getOperator().getConst<Divisible>().k.isOne()) {
639
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
640
    }
641
    return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
642
  }
643
644
  // left |><| right
645
4474750
  TNode left = atom[0];
646
4474750
  TNode right = atom[1];
647
648
4474750
  Polynomial pleft = Polynomial::parsePolynomial(left);
649
4474750
  Polynomial pright = Polynomial::parsePolynomial(right);
650
651
2237375
  Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
652
2237375
  Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
653
654
4474750
  Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
655
2237375
  Assert(cmp.isNormalForm());
656
2237375
  return RewriteResponse(REWRITE_DONE, cmp.getNode());
657
}
658
659
1496323
RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
660
1496323
  Assert(isAtom(atom));
661
662
1496323
  NodeManager* currNM = NodeManager::currentNM();
663
664
1496323
  if(atom.getKind() == kind::EQUAL) {
665
839596
    if(atom[0] == atom[1]) {
666
27666
      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
667
    }
668
656727
  }else if(atom.getKind() == kind::GT){
669
111234
    Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
670
55617
    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
671
601110
  }else if(atom.getKind() == kind::LT){
672
177184
    Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
673
88592
    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
674
512518
  }else if(atom.getKind() == kind::IS_INTEGER){
675
34
    if(atom[0].getType().isInteger()){
676
7
      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
677
    }
678
512484
  }else if(atom.getKind() == kind::DIVISIBLE){
679
    if(atom.getOperator().getConst<Divisible>().k.isOne()){
680
      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
681
    }
682
  }
683
684
1324441
  return RewriteResponse(REWRITE_DONE, atom);
685
}
686
687
4671197
RewriteResponse ArithRewriter::postRewrite(TNode t){
688
4671197
  if(isTerm(t)){
689
4867548
    RewriteResponse response = postRewriteTerm(t);
690
2433772
    if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
691
    {
692
      Polynomial::parsePolynomial(response.d_node);
693
    }
694
2433772
    return response;
695
2237423
  }else if(isAtom(t)){
696
4474846
    RewriteResponse response = postRewriteAtom(t);
697
2237423
    if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
698
    {
699
      Comparison::parseNormalForm(response.d_node);
700
    }
701
2237423
    return response;
702
  }else{
703
    Unreachable();
704
  }
705
}
706
707
2818698
RewriteResponse ArithRewriter::preRewrite(TNode t){
708
2818698
  if(isTerm(t)){
709
1322375
    return preRewriteTerm(t);
710
1496323
  }else if(isAtom(t)){
711
1496323
    return preRewriteAtom(t);
712
  }else{
713
    Unreachable();
714
  }
715
}
716
717
238090
Node ArithRewriter::makeUnaryMinusNode(TNode n){
718
476180
  Rational qNegOne(-1);
719
476180
  return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
720
}
721
722
233477
Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
723
466954
  Node negR = makeUnaryMinusNode(r);
724
233477
  Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
725
726
466954
  return diff;
727
}
728
729
3956
RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
730
3956
  Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
731
732
7912
  Node left = t[0];
733
7912
  Node right = t[1];
734
3956
  if(right.getKind() == kind::CONST_RATIONAL){
735
2422
    const Rational& den = right.getConst<Rational>();
736
737
2422
    if(den.isZero()){
738
116
      if(t.getKind() == kind::DIVISION_TOTAL){
739
15
        return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
740
      }else{
741
        // This is unsupported, but this is not a good place to complain
742
101
        return RewriteResponse(REWRITE_DONE, t);
743
      }
744
    }
745
2306
    Assert(den != Rational(0));
746
747
2306
    if(left.getKind() == kind::CONST_RATIONAL){
748
755
      const Rational& num = left.getConst<Rational>();
749
1510
      Rational div = num / den;
750
1510
      Node result =  mkRationalNode(div);
751
755
      return RewriteResponse(REWRITE_DONE, result);
752
    }
753
754
3102
    Rational div = den.inverse();
755
756
3102
    Node result = mkRationalNode(div);
757
758
3102
    Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
759
1551
    if(pre){
760
1541
      return RewriteResponse(REWRITE_DONE, mult);
761
    }else{
762
10
      return RewriteResponse(REWRITE_AGAIN, mult);
763
    }
764
  }else{
765
1534
    return RewriteResponse(REWRITE_DONE, t);
766
  }
767
}
768
769
1877
RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
770
{
771
1877
  NodeManager* nm = NodeManager::currentNM();
772
1877
  Kind k = t.getKind();
773
3754
  Node zero = nm->mkConst(Rational(0));
774
1877
  if (k == kind::INTS_MODULUS)
775
  {
776
892
    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
777
    {
778
      // can immediately replace by INTS_MODULUS_TOTAL
779
128
      Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
780
64
      return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
781
    }
782
  }
783
1813
  if (k == kind::INTS_DIVISION)
784
  {
785
985
    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
786
    {
787
      // can immediately replace by INTS_DIVISION_TOTAL
788
378
      Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
789
189
      return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
790
    }
791
  }
792
1624
  return RewriteResponse(REWRITE_DONE, t);
793
}
794
795
12758
RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
796
{
797
12758
  if (pre)
798
  {
799
    // do not rewrite at prewrite.
800
5067
    return RewriteResponse(REWRITE_DONE, t);
801
  }
802
7691
  NodeManager* nm = NodeManager::currentNM();
803
7691
  Kind k = t.getKind();
804
7691
  Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
805
15382
  TNode n = t[0];
806
15382
  TNode d = t[1];
807
7691
  bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
808
7691
  if(dIsConstant && d.getConst<Rational>().isZero()){
809
    // (div x 0) ---> 0 or (mod x 0) ---> 0
810
60
    return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
811
7631
  }else if(dIsConstant && d.getConst<Rational>().isOne()){
812
253
    if (k == kind::INTS_MODULUS_TOTAL)
813
    {
814
      // (mod x 1) --> 0
815
19
      return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE);
816
    }
817
234
    Assert(k == kind::INTS_DIVISION_TOTAL);
818
    // (div x 1) --> x
819
234
    return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
820
  }
821
7378
  else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
822
  {
823
    // pull negation
824
    // (div x (- c)) ---> (- (div x c))
825
    // (mod x (- c)) ---> (mod x c)
826
78
    Node nn = nm->mkNode(k, t[0], nm->mkConst(-t[1].getConst<Rational>()));
827
39
    Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
828
39
                   ? nm->mkNode(kind::UMINUS, nn)
829
92
                   : nn;
830
39
    return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
831
  }
832
7339
  else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL)
833
  {
834
1937
    Assert(d.getConst<Rational>().isIntegral());
835
1937
    Assert(n.getConst<Rational>().isIntegral());
836
1937
    Assert(!d.getConst<Rational>().isZero());
837
3874
    Integer di = d.getConst<Rational>().getNumerator();
838
3874
    Integer ni = n.getConst<Rational>().getNumerator();
839
840
1937
    bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
841
842
3874
    Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
843
844
    // constant evaluation
845
    // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
846
3874
    Node resultNode = mkRationalNode(Rational(result));
847
1937
    return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
848
  }
849
5402
  if (k == kind::INTS_MODULUS_TOTAL)
850
  {
851
    // Note these rewrites do not need to account for modulus by zero as being
852
    // a UF, which is handled by the reduction of INTS_MODULUS.
853
2959
    Kind k0 = t[0].getKind();
854
2959
    if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
855
    {
856
      // (mod (mod x c) c) --> (mod x c)
857
      return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
858
    }
859
2959
    else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS)
860
    {
861
      // can drop all
862
2516
      std::vector<Node> newChildren;
863
1276
      bool childChanged = false;
864
4808
      for (const Node& tc : t[0])
865
      {
866
3668
        if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
867
        {
868
68
          newChildren.push_back(tc[0]);
869
68
          childChanged = true;
870
68
          continue;
871
        }
872
3532
        newChildren.push_back(tc);
873
      }
874
1276
      if (childChanged)
875
      {
876
        // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
877
        // op is one of { NONLINEAR_MULT, MULT, PLUS }.
878
72
        Node ret = nm->mkNode(k0, newChildren);
879
36
        ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
880
36
        return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
881
      }
882
    }
883
  }
884
  else
885
  {
886
2443
    Assert(k == kind::INTS_DIVISION_TOTAL);
887
    // Note these rewrites do not need to account for division by zero as being
888
    // a UF, which is handled by the reduction of INTS_DIVISION.
889
2443
    if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
890
    {
891
      // (div (mod x c) c) --> 0
892
      Node ret = mkRationalNode(0);
893
      return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
894
    }
895
  }
896
5366
  return RewriteResponse(REWRITE_DONE, t);
897
}
898
899
2014
TrustNode ArithRewriter::expandDefinition(Node node)
900
{
901
  // call eliminate operators, to eliminate partial operators only
902
4028
  std::vector<SkolemLemma> lems;
903
2014
  TrustNode ret = d_opElim.eliminate(node, lems, true);
904
2014
  Assert(lems.empty());
905
4028
  return ret;
906
}
907
908
2578
RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
909
{
910
5156
  Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
911
2578
                         << r << std::endl;
912
2578
  return RewriteResponse(REWRITE_AGAIN_FULL, ret);
913
}
914
915
}  // namespace arith
916
}  // namespace theory
917
28191
}  // namespace cvc5