GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_rewriter.cpp Lines: 461 520 88.7 %
Date: 2021-08-06 Branches: 1082 2569 42.1 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Tim King, Morgan Deters
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * [[ Add one-line brief description here ]]
14
 *
15
 * [[ Add lengthier description here ]]
16
 * \todo document this file
17
 */
18
19
#include "theory/arith/arith_rewriter.h"
20
21
#include <set>
22
#include <sstream>
23
#include <stack>
24
#include <vector>
25
26
#include "smt/logic_exception.h"
27
#include "theory/arith/arith_msum.h"
28
#include "theory/arith/arith_utilities.h"
29
#include "theory/arith/normal_form.h"
30
#include "theory/arith/operator_elim.h"
31
#include "theory/theory.h"
32
#include "util/bitvector.h"
33
#include "util/divisible.h"
34
#include "util/iand.h"
35
36
namespace cvc5 {
37
namespace theory {
38
namespace arith {
39
40
9853
ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
41
42
13222713
bool ArithRewriter::isAtom(TNode n) {
43
13222713
  Kind k = n.getKind();
44
16948096
  return arith::isRelationOperator(k) || k == kind::IS_INTEGER
45
16947898
      || k == kind::DIVISIBLE;
46
}
47
48
553962
RewriteResponse ArithRewriter::rewriteConstant(TNode t){
49
553962
  Assert(t.isConst());
50
553962
  Assert(t.getKind() == kind::CONST_RATIONAL);
51
52
553962
  return RewriteResponse(REWRITE_DONE, t);
53
}
54
55
RewriteResponse ArithRewriter::rewriteVariable(TNode t){
56
  Assert(t.isVar());
57
58
  return RewriteResponse(REWRITE_DONE, t);
59
}
60
61
185975
RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
62
185975
  Assert(t.getKind() == kind::MINUS);
63
64
185975
  if(pre){
65
185973
    if(t[0] == t[1]){
66
6312
      Rational zero(0);
67
6312
      Node zeroNode  = mkRationalNode(zero);
68
3156
      return RewriteResponse(REWRITE_DONE, zeroNode);
69
    }else{
70
365634
      Node noMinus = makeSubtractionNode(t[0],t[1]);
71
182817
      return RewriteResponse(REWRITE_DONE, noMinus);
72
    }
73
  }else{
74
4
    Polynomial minuend = Polynomial::parsePolynomial(t[0]);
75
4
    Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
76
4
    Polynomial diff = minuend - subtrahend;
77
2
    return RewriteResponse(REWRITE_DONE, diff.getNode());
78
  }
79
}
80
81
12665
RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
82
12665
  Assert(t.getKind() == kind::UMINUS);
83
84
12665
  if(t[0].getKind() == kind::CONST_RATIONAL){
85
12398
    Rational neg = -(t[0].getConst<Rational>());
86
6199
    return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
87
  }
88
89
12932
  Node noUminus = makeUnaryMinusNode(t[0]);
90
6466
  if(pre)
91
6462
    return RewriteResponse(REWRITE_DONE, noUminus);
92
  else
93
4
    return RewriteResponse(REWRITE_AGAIN, noUminus);
94
}
95
96
1319864
RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
97
1319864
  if(t.isConst()){
98
116468
    return rewriteConstant(t);
99
1203396
  }else if(t.isVar()){
100
    return rewriteVariable(t);
101
  }else{
102
1203396
    switch(Kind k = t.getKind()){
103
185973
    case kind::MINUS:
104
185973
      return rewriteMinus(t, true);
105
12657
    case kind::UMINUS:
106
12657
      return rewriteUMinus(t, true);
107
2755
    case kind::DIVISION:
108
    case kind::DIVISION_TOTAL:
109
2755
      return rewriteDiv(t,true);
110
600091
    case kind::PLUS:
111
600091
      return preRewritePlus(t);
112
387703
    case kind::MULT:
113
387703
    case kind::NONLINEAR_MULT: return preRewriteMult(t);
114
308
    case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
115
65
    case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
116
2052
    case kind::EXPONENTIAL:
117
    case kind::SINE:
118
    case kind::COSINE:
119
    case kind::TANGENT:
120
    case kind::COSECANT:
121
    case kind::SECANT:
122
    case kind::COTANGENT:
123
    case kind::ARCSINE:
124
    case kind::ARCCOSINE:
125
    case kind::ARCTANGENT:
126
    case kind::ARCCOSECANT:
127
    case kind::ARCSECANT:
128
    case kind::ARCCOTANGENT:
129
2052
    case kind::SQRT: return preRewriteTranscendental(t);
130
812
    case kind::INTS_DIVISION:
131
812
    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
132
5786
    case kind::INTS_DIVISION_TOTAL:
133
    case kind::INTS_MODULUS_TOTAL:
134
5786
      return rewriteIntsDivModTotal(t,true);
135
32
    case kind::ABS:
136
32
      if(t[0].isConst()) {
137
16
        const Rational& rat = t[0].getConst<Rational>();
138
16
        if(rat >= 0) {
139
16
          return RewriteResponse(REWRITE_DONE, t[0]);
140
        } else {
141
          return RewriteResponse(REWRITE_DONE,
142
                                 NodeManager::currentNM()->mkConst(-rat));
143
        }
144
16
      }
145
16
      return RewriteResponse(REWRITE_DONE, t);
146
601
    case kind::IS_INTEGER:
147
    case kind::TO_INTEGER:
148
601
      return RewriteResponse(REWRITE_DONE, t);
149
4357
    case kind::TO_REAL:
150
4357
    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
151
119
    case kind::POW:
152
119
      return RewriteResponse(REWRITE_DONE, t);
153
85
    case kind::PI:
154
85
      return RewriteResponse(REWRITE_DONE, t);
155
    default: Unhandled() << k;
156
    }
157
  }
158
}
159
160
2405321
RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
161
2405321
  if(t.isConst()){
162
437494
    return rewriteConstant(t);
163
1967827
  }else if(t.isVar()){
164
    return rewriteVariable(t);
165
  }else{
166
1967827
    switch(t.getKind()){
167
2
    case kind::MINUS:
168
2
      return rewriteMinus(t, false);
169
8
    case kind::UMINUS:
170
8
      return rewriteUMinus(t, false);
171
1198
    case kind::DIVISION:
172
    case kind::DIVISION_TOTAL:
173
1198
      return rewriteDiv(t, false);
174
1403988
    case kind::PLUS:
175
1403988
      return postRewritePlus(t);
176
546601
    case kind::MULT:
177
546601
    case kind::NONLINEAR_MULT: return postRewriteMult(t);
178
539
    case kind::IAND: return postRewriteIAnd(t);
179
127
    case kind::POW2: return postRewritePow2(t);
180
3918
    case kind::EXPONENTIAL:
181
    case kind::SINE:
182
    case kind::COSINE:
183
    case kind::TANGENT:
184
    case kind::COSECANT:
185
    case kind::SECANT:
186
    case kind::COTANGENT:
187
    case kind::ARCSINE:
188
    case kind::ARCCOSINE:
189
    case kind::ARCTANGENT:
190
    case kind::ARCCOSECANT:
191
    case kind::ARCSECANT:
192
    case kind::ARCCOTANGENT:
193
3918
    case kind::SQRT: return postRewriteTranscendental(t);
194
1164
    case kind::INTS_DIVISION:
195
1164
    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
196
8669
    case kind::INTS_DIVISION_TOTAL:
197
    case kind::INTS_MODULUS_TOTAL:
198
8669
      return rewriteIntsDivModTotal(t, false);
199
38
    case kind::ABS:
200
38
      if(t[0].isConst()) {
201
        const Rational& rat = t[0].getConst<Rational>();
202
        if(rat >= 0) {
203
          return RewriteResponse(REWRITE_DONE, t[0]);
204
        } else {
205
          return RewriteResponse(REWRITE_DONE,
206
                                 NodeManager::currentNM()->mkConst(-rat));
207
        }
208
38
      }
209
38
      return RewriteResponse(REWRITE_DONE, t);
210
    case kind::TO_REAL:
211
    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
212
1160
    case kind::TO_INTEGER:
213
1160
      if(t[0].isConst()) {
214
66
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
215
      }
216
1094
      if(t[0].getType().isInteger()) {
217
2
        return RewriteResponse(REWRITE_DONE, t[0]);
218
      }
219
      //Unimplemented() << "TO_INTEGER, nonconstant";
220
      //return rewriteToInteger(t);
221
1092
      return RewriteResponse(REWRITE_DONE, t);
222
    case kind::IS_INTEGER:
223
      if(t[0].isConst()) {
224
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
225
      }
226
      if(t[0].getType().isInteger()) {
227
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
228
      }
229
      //Unimplemented() << "IS_INTEGER, nonconstant";
230
      //return rewriteIsInteger(t);
231
      return RewriteResponse(REWRITE_DONE, t);
232
181
    case kind::POW:
233
      {
234
181
        if(t[1].getKind() == kind::CONST_RATIONAL){
235
168
          const Rational& exp = t[1].getConst<Rational>();
236
169
          TNode base = t[0];
237
168
          if(exp.sgn() == 0){
238
52
            return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
239
116
          }else if(exp.sgn() > 0 && exp.isIntegral()){
240
117
            cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
241
116
            if (exp <= r)
242
            {
243
115
              unsigned num = exp.getNumerator().toUnsignedInt();
244
115
              if( num==1 ){
245
4
                return RewriteResponse(REWRITE_AGAIN, base);
246
              }else{
247
222
                NodeBuilder nb(kind::MULT);
248
731
                for(unsigned i=0; i < num; ++i){
249
620
                  nb << base;
250
                }
251
111
                Assert(nb.getNumChildren() > 0);
252
222
                Node mult = nb;
253
111
                return RewriteResponse(REWRITE_AGAIN, mult);
254
              }
255
            }
256
          }
257
        }
258
39
        else if (t[0].getKind() == kind::CONST_RATIONAL
259
39
                 && t[0].getConst<Rational>().getNumerator().toUnsignedInt() == 2)
260
        {
261
          return RewriteResponse(
262
12
              REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1]));
263
        }
264
265
        // Todo improve the exception thrown
266
4
        std::stringstream ss;
267
        ss << "The exponent of the POW(^) operator can only be a positive "
268
2
              "integral constant below "
269
2
           << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
270
2
        ss << "Exception occurred in:" << std::endl;
271
2
        ss << "  " << t;
272
2
        throw LogicException(ss.str());
273
      }
274
234
    case kind::PI:
275
234
      return RewriteResponse(REWRITE_DONE, t);
276
    default:
277
      Unreachable();
278
    }
279
  }
280
}
281
282
283
387703
RewriteResponse ArithRewriter::preRewriteMult(TNode t){
284
387703
  Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
285
286
387703
  if(t.getNumChildren() == 2){
287
1149312
    if(t[0].getKind() == kind::CONST_RATIONAL
288
1149312
       && t[0].getConst<Rational>().isOne()){
289
36759
      return RewriteResponse(REWRITE_DONE, t[1]);
290
    }
291
1039035
    if(t[1].getKind() == kind::CONST_RATIONAL
292
1039035
       && t[1].getConst<Rational>().isOne()){
293
6019
      return RewriteResponse(REWRITE_DONE, t[0]);
294
    }
295
  }
296
297
  // Rewrite multiplications with a 0 argument and to 0
298
1040004
  for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
299
708580
    if((*i).getKind() == kind::CONST_RATIONAL) {
300
390380
      if((*i).getConst<Rational>().isZero()) {
301
27002
        TNode zero = (*i);
302
13501
        return RewriteResponse(REWRITE_DONE, zero);
303
      }
304
    }
305
  }
306
331424
  return RewriteResponse(REWRITE_DONE, t);
307
}
308
309
600091
static bool canFlatten(Kind k, TNode t){
310
2123035
  for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
311
3097392
    TNode child = *i;
312
1574448
    if(child.getKind() == k){
313
51504
      return true;
314
    }
315
  }
316
548587
  return false;
317
}
318
319
108378
static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
320
108378
  if(t.getKind() == k){
321
346698
    for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
322
476640
      TNode child = *i;
323
238320
      if(child.getKind() == k){
324
56874
        flatten(pb, k, child);
325
      }else{
326
181446
        pb.push_back(child);
327
      }
328
    }
329
  }else{
330
    pb.push_back(t);
331
  }
332
108378
}
333
334
51504
static Node flatten(Kind k, TNode t){
335
103008
  std::vector<TNode> pb;
336
51504
  flatten(pb, k, t);
337
51504
  Assert(pb.size() >= 2);
338
103008
  return NodeManager::currentNM()->mkNode(k, pb);
339
}
340
341
600091
RewriteResponse ArithRewriter::preRewritePlus(TNode t){
342
600091
  Assert(t.getKind() == kind::PLUS);
343
344
600091
  if(canFlatten(kind::PLUS, t)){
345
51504
    return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
346
  }else{
347
548587
    return RewriteResponse(REWRITE_DONE, t);
348
  }
349
}
350
351
1403988
RewriteResponse ArithRewriter::postRewritePlus(TNode t){
352
1403988
  Assert(t.getKind() == kind::PLUS);
353
354
2807976
  std::vector<Monomial> monomials;
355
2807976
  std::vector<Polynomial> polynomials;
356
357
4968472
  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
358
7128968
    TNode curr = *i;
359
3564484
    if(Monomial::isMember(curr)){
360
3351476
      monomials.push_back(Monomial::parseMonomial(curr));
361
    }else{
362
213008
      polynomials.push_back(Polynomial::parsePolynomial(curr));
363
    }
364
  }
365
366
1403988
  if(!monomials.empty()){
367
1388818
    Monomial::sort(monomials);
368
1388818
    Monomial::combineAdjacentMonomials(monomials);
369
1388818
    polynomials.push_back(Polynomial::mkPolynomial(monomials));
370
  }
371
372
2807976
  Polynomial res = Polynomial::sumPolynomials(polynomials);
373
374
2807976
  return RewriteResponse(REWRITE_DONE, res.getNode());
375
}
376
377
546601
RewriteResponse ArithRewriter::postRewriteMult(TNode t){
378
546601
  Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
379
380
1093202
  Polynomial res = Polynomial::mkOne();
381
382
1668382
  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
383
2243562
    Node curr = *i;
384
2243562
    Polynomial currPoly = Polynomial::parsePolynomial(curr);
385
386
1121781
    res = res * currPoly;
387
  }
388
389
1093202
  return RewriteResponse(REWRITE_DONE, res.getNode());
390
}
391
392
127
RewriteResponse ArithRewriter::postRewritePow2(TNode t)
393
{
394
127
  Assert(t.getKind() == kind::POW2);
395
127
  NodeManager* nm = NodeManager::currentNM();
396
  // if constant, we eliminate
397
127
  if (t[0].isConst())
398
  {
399
    // pow2 is only supported for integers
400
35
    Assert(t[0].getType().isInteger());
401
70
    Integer i = t[0].getConst<Rational>().getNumerator();
402
35
    if (i < 0)
403
    {
404
      return RewriteResponse(
405
          REWRITE_DONE,
406
15
          nm->mkConst<Rational>(Rational(Integer(0), Integer(1))));
407
    }
408
20
    unsigned long k = i.getUnsignedLong();
409
40
    Node ret = nm->mkConst<Rational>(Rational(Integer(2).pow(k), Integer(1)));
410
20
    return RewriteResponse(REWRITE_DONE, ret);
411
  }
412
92
  return RewriteResponse(REWRITE_DONE, t);
413
}
414
415
539
RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
416
{
417
539
  Assert(t.getKind() == kind::IAND);
418
539
  NodeManager* nm = NodeManager::currentNM();
419
  // if constant, we eliminate
420
539
  if (t[0].isConst() && t[1].isConst())
421
  {
422
158
    size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
423
316
    Node iToBvop = nm->mkConst(IntToBitVector(bsize));
424
316
    Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
425
316
    Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
426
316
    Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
427
316
    Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
428
158
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
429
  }
430
381
  else if (t[0] > t[1])
431
  {
432
    // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
433
60
    Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
434
30
    return RewriteResponse(REWRITE_AGAIN, ret);
435
  }
436
351
  else if (t[0] == t[1])
437
  {
438
    // ((_ iand k) x x) ---> x
439
6
    return RewriteResponse(REWRITE_DONE, t[0]);
440
  }
441
  // simplifications involving constants
442
1029
  for (unsigned i = 0; i < 2; i++)
443
  {
444
687
    if (!t[i].isConst())
445
    {
446
658
      continue;
447
    }
448
29
    if (t[i].getConst<Rational>().sgn() == 0)
449
    {
450
      // ((_ iand k) 0 y) ---> 0
451
3
      return RewriteResponse(REWRITE_DONE, t[i]);
452
    }
453
  }
454
342
  return RewriteResponse(REWRITE_DONE, t);
455
}
456
457
2052
RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
458
2052
  return RewriteResponse(REWRITE_DONE, t);
459
}
460
461
3918
RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
462
3918
  Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
463
3918
  NodeManager* nm = NodeManager::currentNM();
464
3918
  switch( t.getKind() ){
465
452
  case kind::EXPONENTIAL: {
466
452
    if(t[0].getKind() == kind::CONST_RATIONAL){
467
560
      Node one = nm->mkConst(Rational(1));
468
280
      if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
469
        return RewriteResponse(
470
            REWRITE_AGAIN,
471
31
            nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
472
      }else{
473
249
        return RewriteResponse(REWRITE_DONE, t);
474
      }
475
    }
476
172
    else if (t[0].getKind() == kind::PLUS)
477
    {
478
48
      std::vector<Node> product;
479
72
      for (const Node tc : t[0])
480
      {
481
48
        product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
482
      }
483
      // We need to do a full rewrite here, since we can get exponentials of
484
      // constants, e.g. when we are rewriting exp(2 + x)
485
      return RewriteResponse(REWRITE_AGAIN_FULL,
486
24
                             nm->mkNode(kind::MULT, product));
487
148
    }
488
  }
489
148
    break;
490
3252
  case kind::SINE:
491
3252
    if(t[0].getKind() == kind::CONST_RATIONAL){
492
1757
      const Rational& rat = t[0].getConst<Rational>();
493
1757
      if(rat.sgn() == 0){
494
49
        return RewriteResponse(REWRITE_DONE, nm->mkConst(Rational(0)));
495
      }
496
1708
      else if (rat.sgn() == -1)
497
      {
498
        Node ret =
499
1134
            nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, nm->mkConst(-rat)));
500
567
        return RewriteResponse(REWRITE_AGAIN_FULL, ret);
501
      }
502
    }else{
503
      // get the factor of PI in the argument
504
2939
      Node pi_factor;
505
2939
      Node pi;
506
2939
      Node rem;
507
2939
      std::map<Node, Node> msum;
508
1495
      if (ArithMSum::getMonomialSum(t[0], msum))
509
      {
510
1495
        pi = mkPi();
511
1495
        std::map<Node, Node>::iterator itm = msum.find(pi);
512
1495
        if (itm != msum.end())
513
        {
514
399
          if (itm->second.isNull())
515
          {
516
            pi_factor = mkRationalNode(Rational(1));
517
          }
518
          else
519
          {
520
399
            pi_factor = itm->second;
521
          }
522
399
          msum.erase(pi);
523
399
          if (!msum.empty())
524
          {
525
348
            rem = ArithMSum::mkNode(msum);
526
          }
527
        }
528
      }
529
      else
530
      {
531
        Assert(false);
532
      }
533
534
      // if there is a factor of PI
535
1495
      if( !pi_factor.isNull() ){
536
399
        Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
537
747
        Rational r = pi_factor.getConst<Rational>();
538
747
        Rational r_abs = r.abs();
539
747
        Rational rone = Rational(1);
540
747
        Node ntwo = mkRationalNode(Rational(2));
541
399
        if (r_abs > rone)
542
        {
543
          //add/substract 2*pi beyond scope
544
          Node ra_div_two = nm->mkNode(
545
              kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo);
546
          Node new_pi_factor;
547
          if( r.sgn()==1 ){
548
            new_pi_factor =
549
                nm->mkNode(kind::MINUS,
550
                           pi_factor,
551
                           nm->mkNode(kind::MULT, ntwo, ra_div_two));
552
          }else{
553
            Assert(r.sgn() == -1);
554
            new_pi_factor =
555
                nm->mkNode(kind::PLUS,
556
                           pi_factor,
557
                           nm->mkNode(kind::MULT, ntwo, ra_div_two));
558
          }
559
          Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
560
          if (!rem.isNull())
561
          {
562
            new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
563
          }
564
          // sin( 2*n*PI + x ) = sin( x )
565
          return RewriteResponse(REWRITE_AGAIN_FULL,
566
                                 nm->mkNode(kind::SINE, new_arg));
567
        }
568
399
        else if (r_abs == rone)
569
        {
570
          // sin( PI + x ) = -sin( x )
571
          if (rem.isNull())
572
          {
573
            return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
574
          }
575
          else
576
          {
577
            return RewriteResponse(
578
                REWRITE_AGAIN_FULL,
579
                nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem)));
580
          }
581
        }
582
399
        else if (rem.isNull())
583
        {
584
          // other rational cases based on Niven's theorem
585
          // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
586
51
          Integer one = Integer(1);
587
51
          Integer two = Integer(2);
588
51
          Integer six = Integer(6);
589
51
          if (r_abs.getDenominator() == two)
590
          {
591
51
            Assert(r_abs.getNumerator() == one);
592
            return RewriteResponse(REWRITE_DONE,
593
51
                                   mkRationalNode(Rational(r.sgn())));
594
          }
595
          else if (r_abs.getDenominator() == six)
596
          {
597
            Integer five = Integer(5);
598
            if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
599
            {
600
              return RewriteResponse(
601
                  REWRITE_DONE,
602
                  mkRationalNode(Rational(r.sgn()) / Rational(2)));
603
            }
604
          }
605
        }
606
      }
607
    }
608
2585
    break;
609
86
  case kind::COSINE: {
610
    return RewriteResponse(
611
        REWRITE_AGAIN_FULL,
612
172
        nm->mkNode(kind::SINE,
613
344
                   nm->mkNode(kind::MINUS,
614
344
                              nm->mkNode(kind::MULT,
615
172
                                         nm->mkConst(Rational(1) / Rational(2)),
616
172
                                         mkPi()),
617
86
                              t[0])));
618
  }
619
  break;
620
16
  case kind::TANGENT:
621
  {
622
    return RewriteResponse(REWRITE_AGAIN_FULL,
623
64
                           nm->mkNode(kind::DIVISION,
624
32
                                      nm->mkNode(kind::SINE, t[0]),
625
48
                                      nm->mkNode(kind::COSINE, t[0])));
626
  }
627
  break;
628
4
  case kind::COSECANT:
629
  {
630
    return RewriteResponse(REWRITE_AGAIN_FULL,
631
16
                           nm->mkNode(kind::DIVISION,
632
8
                                      mkRationalNode(Rational(1)),
633
12
                                      nm->mkNode(kind::SINE, t[0])));
634
  }
635
  break;
636
4
  case kind::SECANT:
637
  {
638
    return RewriteResponse(REWRITE_AGAIN_FULL,
639
16
                           nm->mkNode(kind::DIVISION,
640
8
                                      mkRationalNode(Rational(1)),
641
12
                                      nm->mkNode(kind::COSINE, t[0])));
642
  }
643
  break;
644
8
  case kind::COTANGENT:
645
  {
646
    return RewriteResponse(REWRITE_AGAIN_FULL,
647
32
                           nm->mkNode(kind::DIVISION,
648
16
                                      nm->mkNode(kind::COSINE, t[0]),
649
24
                                      nm->mkNode(kind::SINE, t[0])));
650
  }
651
  break;
652
96
  default:
653
96
    break;
654
  }
655
2829
  return RewriteResponse(REWRITE_DONE, t);
656
}
657
658
2365699
RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
659
2365699
  if(atom.getKind() == kind::IS_INTEGER) {
660
48
    if(atom[0].isConst()) {
661
6
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
662
    }
663
42
    if(atom[0].getType().isInteger()) {
664
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
665
    }
666
    // not supported, but this isn't the right place to complain
667
42
    return RewriteResponse(REWRITE_DONE, atom);
668
2365651
  } else if(atom.getKind() == kind::DIVISIBLE) {
669
    if(atom[0].isConst()) {
670
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
671
    }
672
    if(atom.getOperator().getConst<Divisible>().k.isOne()) {
673
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
674
    }
675
    return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
676
  }
677
678
  // left |><| right
679
4731302
  TNode left = atom[0];
680
4731302
  TNode right = atom[1];
681
682
4731302
  Polynomial pleft = Polynomial::parsePolynomial(left);
683
4731302
  Polynomial pright = Polynomial::parsePolynomial(right);
684
685
2365651
  Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
686
2365651
  Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
687
688
4731302
  Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
689
2365651
  Assert(cmp.isNormalForm());
690
2365651
  return RewriteResponse(REWRITE_DONE, cmp.getNode());
691
}
692
693
1588710
RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
694
1588710
  Assert(isAtom(atom));
695
696
1588710
  NodeManager* currNM = NodeManager::currentNM();
697
698
1588710
  if(atom.getKind() == kind::EQUAL) {
699
890708
    if(atom[0] == atom[1]) {
700
27980
      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
701
    }
702
698002
  }else if(atom.getKind() == kind::GT){
703
112196
    Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
704
56098
    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
705
641904
  }else if(atom.getKind() == kind::LT){
706
196606
    Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
707
98303
    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
708
543601
  }else if(atom.getKind() == kind::IS_INTEGER){
709
34
    if(atom[0].getType().isInteger()){
710
7
      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
711
    }
712
543567
  }else if(atom.getKind() == kind::DIVISIBLE){
713
    if(atom.getOperator().getConst<Divisible>().k.isOne()){
714
      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
715
    }
716
  }
717
718
1406322
  return RewriteResponse(REWRITE_DONE, atom);
719
}
720
721
4771020
RewriteResponse ArithRewriter::postRewrite(TNode t){
722
4771020
  if(isTerm(t)){
723
4810642
    RewriteResponse response = postRewriteTerm(t);
724
2405319
    if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
725
    {
726
      Polynomial::parsePolynomial(response.d_node);
727
    }
728
2405319
    return response;
729
2365699
  }else if(isAtom(t)){
730
4731398
    RewriteResponse response = postRewriteAtom(t);
731
2365699
    if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
732
    {
733
      Comparison::parseNormalForm(response.d_node);
734
    }
735
2365699
    return response;
736
  }else{
737
    Unreachable();
738
  }
739
}
740
741
2908574
RewriteResponse ArithRewriter::preRewrite(TNode t){
742
2908574
  if(isTerm(t)){
743
1319864
    return preRewriteTerm(t);
744
1588710
  }else if(isAtom(t)){
745
1588710
    return preRewriteAtom(t);
746
  }else{
747
    Unreachable();
748
  }
749
}
750
751
189283
Node ArithRewriter::makeUnaryMinusNode(TNode n){
752
378566
  Rational qNegOne(-1);
753
378566
  return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
754
}
755
756
182817
Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
757
365634
  Node negR = makeUnaryMinusNode(r);
758
182817
  Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
759
760
365634
  return diff;
761
}
762
763
3953
RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
764
3953
  Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
765
766
7906
  Node left = t[0];
767
7906
  Node right = t[1];
768
3953
  if(right.getKind() == kind::CONST_RATIONAL){
769
2400
    const Rational& den = right.getConst<Rational>();
770
771
2400
    if(den.isZero()){
772
116
      if(t.getKind() == kind::DIVISION_TOTAL){
773
15
        return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
774
      }else{
775
        // This is unsupported, but this is not a good place to complain
776
101
        return RewriteResponse(REWRITE_DONE, t);
777
      }
778
    }
779
2284
    Assert(den != Rational(0));
780
781
2284
    if(left.getKind() == kind::CONST_RATIONAL){
782
735
      const Rational& num = left.getConst<Rational>();
783
1470
      Rational div = num / den;
784
1470
      Node result =  mkRationalNode(div);
785
735
      return RewriteResponse(REWRITE_DONE, result);
786
    }
787
788
3098
    Rational div = den.inverse();
789
790
3098
    Node result = mkRationalNode(div);
791
792
3098
    Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
793
1549
    if(pre){
794
1538
      return RewriteResponse(REWRITE_DONE, mult);
795
    }else{
796
11
      return RewriteResponse(REWRITE_AGAIN, mult);
797
    }
798
  }else{
799
1553
    return RewriteResponse(REWRITE_DONE, t);
800
  }
801
}
802
803
1976
RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
804
{
805
1976
  NodeManager* nm = NodeManager::currentNM();
806
1976
  Kind k = t.getKind();
807
3952
  Node zero = nm->mkConst(Rational(0));
808
1976
  if (k == kind::INTS_MODULUS)
809
  {
810
782
    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
811
    {
812
      // can immediately replace by INTS_MODULUS_TOTAL
813
136
      Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
814
68
      return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
815
    }
816
  }
817
1908
  if (k == kind::INTS_DIVISION)
818
  {
819
1194
    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
820
    {
821
      // can immediately replace by INTS_DIVISION_TOTAL
822
378
      Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
823
189
      return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
824
    }
825
  }
826
1719
  return RewriteResponse(REWRITE_DONE, t);
827
}
828
829
14455
RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
830
{
831
14455
  if (pre)
832
  {
833
    // do not rewrite at prewrite.
834
5786
    return RewriteResponse(REWRITE_DONE, t);
835
  }
836
8669
  NodeManager* nm = NodeManager::currentNM();
837
8669
  Kind k = t.getKind();
838
8669
  Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
839
17338
  TNode n = t[0];
840
17338
  TNode d = t[1];
841
8669
  bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
842
8669
  if(dIsConstant && d.getConst<Rational>().isZero()){
843
    // (div x 0) ---> 0 or (mod x 0) ---> 0
844
69
    return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
845
8600
  }else if(dIsConstant && d.getConst<Rational>().isOne()){
846
246
    if (k == kind::INTS_MODULUS_TOTAL)
847
    {
848
      // (mod x 1) --> 0
849
19
      return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE);
850
    }
851
227
    Assert(k == kind::INTS_DIVISION_TOTAL);
852
    // (div x 1) --> x
853
227
    return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
854
  }
855
8354
  else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
856
  {
857
    // pull negation
858
    // (div x (- c)) ---> (- (div x c))
859
    // (mod x (- c)) ---> (mod x c)
860
78
    Node nn = nm->mkNode(k, t[0], nm->mkConst(-t[1].getConst<Rational>()));
861
39
    Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
862
39
                   ? nm->mkNode(kind::UMINUS, nn)
863
92
                   : nn;
864
39
    return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
865
  }
866
8315
  else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL)
867
  {
868
2420
    Assert(d.getConst<Rational>().isIntegral());
869
2420
    Assert(n.getConst<Rational>().isIntegral());
870
2420
    Assert(!d.getConst<Rational>().isZero());
871
4840
    Integer di = d.getConst<Rational>().getNumerator();
872
4840
    Integer ni = n.getConst<Rational>().getNumerator();
873
874
2420
    bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
875
876
4840
    Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
877
878
    // constant evaluation
879
    // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
880
4840
    Node resultNode = mkRationalNode(Rational(result));
881
2420
    return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
882
  }
883
5895
  if (k == kind::INTS_MODULUS_TOTAL)
884
  {
885
    // Note these rewrites do not need to account for modulus by zero as being
886
    // a UF, which is handled by the reduction of INTS_MODULUS.
887
2727
    Kind k0 = t[0].getKind();
888
2727
    if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
889
    {
890
      // (mod (mod x c) c) --> (mod x c)
891
1
      return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
892
    }
893
2726
    else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS)
894
    {
895
      // can drop all
896
2808
      std::vector<Node> newChildren;
897
1422
      bool childChanged = false;
898
5974
      for (const Node& tc : t[0])
899
      {
900
4684
        if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
901
        {
902
66
          newChildren.push_back(tc[0]);
903
66
          childChanged = true;
904
66
          continue;
905
        }
906
4552
        newChildren.push_back(tc);
907
      }
908
1422
      if (childChanged)
909
      {
910
        // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
911
        // op is one of { NONLINEAR_MULT, MULT, PLUS }.
912
72
        Node ret = nm->mkNode(k0, newChildren);
913
36
        ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
914
36
        return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
915
      }
916
    }
917
  }
918
  else
919
  {
920
3168
    Assert(k == kind::INTS_DIVISION_TOTAL);
921
    // Note these rewrites do not need to account for division by zero as being
922
    // a UF, which is handled by the reduction of INTS_DIVISION.
923
3168
    if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
924
    {
925
      // (div (mod x c) c) --> 0
926
      Node ret = mkRationalNode(0);
927
      return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
928
    }
929
  }
930
5858
  return RewriteResponse(REWRITE_DONE, t);
931
}
932
933
2139
TrustNode ArithRewriter::expandDefinition(Node node)
934
{
935
  // call eliminate operators, to eliminate partial operators only
936
4278
  std::vector<SkolemLemma> lems;
937
2139
  TrustNode ret = d_opElim.eliminate(node, lems, true);
938
2139
  Assert(lems.empty());
939
4278
  return ret;
940
}
941
942
3068
RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
943
{
944
6136
  Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
945
3068
                         << r << std::endl;
946
3068
  return RewriteResponse(REWRITE_AGAIN_FULL, ret);
947
}
948
949
}  // namespace arith
950
}  // namespace theory
951
29322
}  // namespace cvc5