GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/arith_rewriter.cpp Lines: 461 520 88.7 %
Date: 2021-09-29 Branches: 1082 2567 42.2 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Tim King, Morgan Deters
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * [[ Add one-line brief description here ]]
14
 *
15
 * [[ Add lengthier description here ]]
16
 * \todo document this file
17
 */
18
19
#include "theory/arith/arith_rewriter.h"
20
21
#include <set>
22
#include <sstream>
23
#include <stack>
24
#include <vector>
25
26
#include "smt/logic_exception.h"
27
#include "theory/arith/arith_msum.h"
28
#include "theory/arith/arith_utilities.h"
29
#include "theory/arith/normal_form.h"
30
#include "theory/arith/operator_elim.h"
31
#include "theory/theory.h"
32
#include "util/bitvector.h"
33
#include "util/divisible.h"
34
#include "util/iand.h"
35
36
namespace cvc5 {
37
namespace theory {
38
namespace arith {
39
40
6271
ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
41
42
8135702
bool ArithRewriter::isAtom(TNode n) {
43
8135702
  Kind k = n.getKind();
44
10429300
  return arith::isRelationOperator(k) || k == kind::IS_INTEGER
45
10429191
      || k == kind::DIVISIBLE;
46
}
47
48
333375
RewriteResponse ArithRewriter::rewriteConstant(TNode t){
49
333375
  Assert(t.isConst());
50
333375
  Assert(t.getKind() == kind::CONST_RATIONAL);
51
52
333375
  return RewriteResponse(REWRITE_DONE, t);
53
}
54
55
RewriteResponse ArithRewriter::rewriteVariable(TNode t){
56
  Assert(t.isVar());
57
58
  return RewriteResponse(REWRITE_DONE, t);
59
}
60
61
147773
RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
62
147773
  Assert(t.getKind() == kind::MINUS);
63
64
147773
  if(pre){
65
147771
    if(t[0] == t[1]){
66
4654
      Rational zero(0);
67
4654
      Node zeroNode  = mkRationalNode(zero);
68
2327
      return RewriteResponse(REWRITE_DONE, zeroNode);
69
    }else{
70
290888
      Node noMinus = makeSubtractionNode(t[0],t[1]);
71
145444
      return RewriteResponse(REWRITE_DONE, noMinus);
72
    }
73
  }else{
74
4
    Polynomial minuend = Polynomial::parsePolynomial(t[0]);
75
4
    Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
76
4
    Polynomial diff = minuend - subtrahend;
77
2
    return RewriteResponse(REWRITE_DONE, diff.getNode());
78
  }
79
}
80
81
7826
RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
82
7826
  Assert(t.getKind() == kind::UMINUS);
83
84
7826
  if(t[0].getKind() == kind::CONST_RATIONAL){
85
5792
    Rational neg = -(t[0].getConst<Rational>());
86
2896
    return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
87
  }
88
89
9860
  Node noUminus = makeUnaryMinusNode(t[0]);
90
4930
  if(pre)
91
4929
    return RewriteResponse(REWRITE_DONE, noUminus);
92
  else
93
1
    return RewriteResponse(REWRITE_AGAIN, noUminus);
94
}
95
96
804102
RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
97
804102
  if(t.isConst()){
98
71557
    return rewriteConstant(t);
99
732545
  }else if(t.isVar()){
100
    return rewriteVariable(t);
101
  }else{
102
732545
    switch(Kind k = t.getKind()){
103
147771
    case kind::MINUS:
104
147771
      return rewriteMinus(t, true);
105
7824
    case kind::UMINUS:
106
7824
      return rewriteUMinus(t, true);
107
1212
    case kind::DIVISION:
108
    case kind::DIVISION_TOTAL:
109
1212
      return rewriteDiv(t,true);
110
377473
    case kind::PLUS:
111
377473
      return preRewritePlus(t);
112
192312
    case kind::MULT:
113
192312
    case kind::NONLINEAR_MULT: return preRewriteMult(t);
114
303
    case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
115
21
    case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
116
438
    case kind::EXPONENTIAL:
117
    case kind::SINE:
118
    case kind::COSINE:
119
    case kind::TANGENT:
120
    case kind::COSECANT:
121
    case kind::SECANT:
122
    case kind::COTANGENT:
123
    case kind::ARCSINE:
124
    case kind::ARCCOSINE:
125
    case kind::ARCTANGENT:
126
    case kind::ARCCOSECANT:
127
    case kind::ARCSECANT:
128
    case kind::ARCCOTANGENT:
129
438
    case kind::SQRT: return preRewriteTranscendental(t);
130
422
    case kind::INTS_DIVISION:
131
422
    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
132
3491
    case kind::INTS_DIVISION_TOTAL:
133
    case kind::INTS_MODULUS_TOTAL:
134
3491
      return rewriteIntsDivModTotal(t,true);
135
22
    case kind::ABS:
136
22
      if(t[0].isConst()) {
137
11
        const Rational& rat = t[0].getConst<Rational>();
138
11
        if(rat >= 0) {
139
11
          return RewriteResponse(REWRITE_DONE, t[0]);
140
        } else {
141
          return RewriteResponse(REWRITE_DONE,
142
                                 NodeManager::currentNM()->mkConst(-rat));
143
        }
144
11
      }
145
11
      return RewriteResponse(REWRITE_DONE, t);
146
257
    case kind::IS_INTEGER:
147
    case kind::TO_INTEGER:
148
257
      return RewriteResponse(REWRITE_DONE, t);
149
907
    case kind::TO_REAL:
150
907
    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
151
65
    case kind::POW:
152
65
      return RewriteResponse(REWRITE_DONE, t);
153
27
    case kind::PI:
154
27
      return RewriteResponse(REWRITE_DONE, t);
155
    default: Unhandled() << k;
156
    }
157
  }
158
}
159
160
1489387
RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
161
1489387
  if(t.isConst()){
162
261818
    return rewriteConstant(t);
163
1227569
  }else if(t.isVar()){
164
    return rewriteVariable(t);
165
  }else{
166
1227569
    switch(t.getKind()){
167
2
    case kind::MINUS:
168
2
      return rewriteMinus(t, false);
169
2
    case kind::UMINUS:
170
2
      return rewriteUMinus(t, false);
171
745
    case kind::DIVISION:
172
    case kind::DIVISION_TOTAL:
173
745
      return rewriteDiv(t, false);
174
915609
    case kind::PLUS:
175
915609
      return postRewritePlus(t);
176
303496
    case kind::MULT:
177
303496
    case kind::NONLINEAR_MULT: return postRewriteMult(t);
178
458
    case kind::IAND: return postRewriteIAnd(t);
179
39
    case kind::POW2: return postRewritePow2(t);
180
829
    case kind::EXPONENTIAL:
181
    case kind::SINE:
182
    case kind::COSINE:
183
    case kind::TANGENT:
184
    case kind::COSECANT:
185
    case kind::SECANT:
186
    case kind::COTANGENT:
187
    case kind::ARCSINE:
188
    case kind::ARCCOSINE:
189
    case kind::ARCTANGENT:
190
    case kind::ARCCOSECANT:
191
    case kind::ARCSECANT:
192
    case kind::ARCCOTANGENT:
193
829
    case kind::SQRT: return postRewriteTranscendental(t);
194
605
    case kind::INTS_DIVISION:
195
605
    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
196
5116
    case kind::INTS_DIVISION_TOTAL:
197
    case kind::INTS_MODULUS_TOTAL:
198
5116
      return rewriteIntsDivModTotal(t, false);
199
24
    case kind::ABS:
200
24
      if(t[0].isConst()) {
201
        const Rational& rat = t[0].getConst<Rational>();
202
        if(rat >= 0) {
203
          return RewriteResponse(REWRITE_DONE, t[0]);
204
        } else {
205
          return RewriteResponse(REWRITE_DONE,
206
                                 NodeManager::currentNM()->mkConst(-rat));
207
        }
208
24
      }
209
24
      return RewriteResponse(REWRITE_DONE, t);
210
    case kind::TO_REAL:
211
    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
212
484
    case kind::TO_INTEGER:
213
484
      if(t[0].isConst()) {
214
46
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
215
      }
216
438
      if(t[0].getType().isInteger()) {
217
2
        return RewriteResponse(REWRITE_DONE, t[0]);
218
      }
219
      //Unimplemented() << "TO_INTEGER, nonconstant";
220
      //return rewriteToInteger(t);
221
436
      return RewriteResponse(REWRITE_DONE, t);
222
    case kind::IS_INTEGER:
223
      if(t[0].isConst()) {
224
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
225
      }
226
      if(t[0].getType().isInteger()) {
227
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
228
      }
229
      //Unimplemented() << "IS_INTEGER, nonconstant";
230
      //return rewriteIsInteger(t);
231
      return RewriteResponse(REWRITE_DONE, t);
232
93
    case kind::POW:
233
      {
234
93
        if(t[1].getKind() == kind::CONST_RATIONAL){
235
89
          const Rational& exp = t[1].getConst<Rational>();
236
90
          TNode base = t[0];
237
89
          if(exp.sgn() == 0){
238
30
            return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
239
59
          }else if(exp.sgn() > 0 && exp.isIntegral()){
240
60
            cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
241
59
            if (exp <= r)
242
            {
243
58
              unsigned num = exp.getNumerator().toUnsignedInt();
244
58
              if( num==1 ){
245
4
                return RewriteResponse(REWRITE_AGAIN, base);
246
              }else{
247
108
                NodeBuilder nb(kind::MULT);
248
306
                for(unsigned i=0; i < num; ++i){
249
252
                  nb << base;
250
                }
251
54
                Assert(nb.getNumChildren() > 0);
252
108
                Node mult = nb;
253
54
                return RewriteResponse(REWRITE_AGAIN, mult);
254
              }
255
            }
256
          }
257
        }
258
12
        else if (t[0].getKind() == kind::CONST_RATIONAL
259
12
                 && t[0].getConst<Rational>().getNumerator().toUnsignedInt() == 2)
260
        {
261
          return RewriteResponse(
262
3
              REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1]));
263
        }
264
265
        // Todo improve the exception thrown
266
4
        std::stringstream ss;
267
        ss << "The exponent of the POW(^) operator can only be a positive "
268
2
              "integral constant below "
269
2
           << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
270
2
        ss << "Exception occurred in:" << std::endl;
271
2
        ss << "  " << t;
272
2
        throw LogicException(ss.str());
273
      }
274
67
    case kind::PI:
275
67
      return RewriteResponse(REWRITE_DONE, t);
276
    default:
277
      Unreachable();
278
    }
279
  }
280
}
281
282
283
192312
RewriteResponse ArithRewriter::preRewriteMult(TNode t){
284
192312
  Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
285
286
192312
  if(t.getNumChildren() == 2){
287
568866
    if(t[0].getKind() == kind::CONST_RATIONAL
288
568866
       && t[0].getConst<Rational>().isOne()){
289
6857
      return RewriteResponse(REWRITE_DONE, t[1]);
290
    }
291
548295
    if(t[1].getKind() == kind::CONST_RATIONAL
292
548295
       && t[1].getConst<Rational>().isOne()){
293
2470
      return RewriteResponse(REWRITE_DONE, t[0]);
294
    }
295
  }
296
297
  // Rewrite multiplications with a 0 argument and to 0
298
559038
  for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
299
379282
    if((*i).getKind() == kind::CONST_RATIONAL) {
300
205788
      if((*i).getConst<Rational>().isZero()) {
301
6458
        TNode zero = (*i);
302
3229
        return RewriteResponse(REWRITE_DONE, zero);
303
      }
304
    }
305
  }
306
179756
  return RewriteResponse(REWRITE_DONE, t);
307
}
308
309
377473
static bool canFlatten(Kind k, TNode t){
310
1258193
  for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
311
1795870
    TNode child = *i;
312
915150
    if(child.getKind() == k){
313
34430
      return true;
314
    }
315
  }
316
343043
  return false;
317
}
318
319
72303
static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
320
72303
  if(t.getKind() == k){
321
228221
    for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
322
311836
      TNode child = *i;
323
155918
      if(child.getKind() == k){
324
37873
        flatten(pb, k, child);
325
      }else{
326
118045
        pb.push_back(child);
327
      }
328
    }
329
  }else{
330
    pb.push_back(t);
331
  }
332
72303
}
333
334
34430
static Node flatten(Kind k, TNode t){
335
68860
  std::vector<TNode> pb;
336
34430
  flatten(pb, k, t);
337
34430
  Assert(pb.size() >= 2);
338
68860
  return NodeManager::currentNM()->mkNode(k, pb);
339
}
340
341
377473
RewriteResponse ArithRewriter::preRewritePlus(TNode t){
342
377473
  Assert(t.getKind() == kind::PLUS);
343
344
377473
  if(canFlatten(kind::PLUS, t)){
345
34430
    return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
346
  }else{
347
343043
    return RewriteResponse(REWRITE_DONE, t);
348
  }
349
}
350
351
915609
RewriteResponse ArithRewriter::postRewritePlus(TNode t){
352
915609
  Assert(t.getKind() == kind::PLUS);
353
354
1831218
  std::vector<Monomial> monomials;
355
1831218
  std::vector<Polynomial> polynomials;
356
357
3170212
  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
358
4509206
    TNode curr = *i;
359
2254603
    if(Monomial::isMember(curr)){
360
2167591
      monomials.push_back(Monomial::parseMonomial(curr));
361
    }else{
362
87012
      polynomials.push_back(Polynomial::parsePolynomial(curr));
363
    }
364
  }
365
366
915609
  if(!monomials.empty()){
367
910273
    Monomial::sort(monomials);
368
910273
    Monomial::combineAdjacentMonomials(monomials);
369
910273
    polynomials.push_back(Polynomial::mkPolynomial(monomials));
370
  }
371
372
1831218
  Polynomial res = Polynomial::sumPolynomials(polynomials);
373
374
1831218
  return RewriteResponse(REWRITE_DONE, res.getNode());
375
}
376
377
303496
RewriteResponse ArithRewriter::postRewriteMult(TNode t){
378
303496
  Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
379
380
606992
  Polynomial res = Polynomial::mkOne();
381
382
929575
  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
383
1252158
    Node curr = *i;
384
1252158
    Polynomial currPoly = Polynomial::parsePolynomial(curr);
385
386
626079
    res = res * currPoly;
387
  }
388
389
606992
  return RewriteResponse(REWRITE_DONE, res.getNode());
390
}
391
392
39
RewriteResponse ArithRewriter::postRewritePow2(TNode t)
393
{
394
39
  Assert(t.getKind() == kind::POW2);
395
39
  NodeManager* nm = NodeManager::currentNM();
396
  // if constant, we eliminate
397
39
  if (t[0].isConst())
398
  {
399
    // pow2 is only supported for integers
400
12
    Assert(t[0].getType().isInteger());
401
24
    Integer i = t[0].getConst<Rational>().getNumerator();
402
12
    if (i < 0)
403
    {
404
      return RewriteResponse(
405
          REWRITE_DONE,
406
4
          nm->mkConst<Rational>(Rational(Integer(0), Integer(1))));
407
    }
408
8
    unsigned long k = i.getUnsignedLong();
409
16
    Node ret = nm->mkConst<Rational>(Rational(Integer(2).pow(k), Integer(1)));
410
8
    return RewriteResponse(REWRITE_DONE, ret);
411
  }
412
27
  return RewriteResponse(REWRITE_DONE, t);
413
}
414
415
458
RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
416
{
417
458
  Assert(t.getKind() == kind::IAND);
418
458
  NodeManager* nm = NodeManager::currentNM();
419
  // if constant, we eliminate
420
458
  if (t[0].isConst() && t[1].isConst())
421
  {
422
218
    size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
423
436
    Node iToBvop = nm->mkConst(IntToBitVector(bsize));
424
436
    Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
425
436
    Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
426
436
    Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
427
436
    Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
428
218
    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
429
  }
430
240
  else if (t[0] > t[1])
431
  {
432
    // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
433
52
    Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
434
26
    return RewriteResponse(REWRITE_AGAIN, ret);
435
  }
436
214
  else if (t[0] == t[1])
437
  {
438
    // ((_ iand k) x x) ---> x
439
2
    return RewriteResponse(REWRITE_DONE, t[0]);
440
  }
441
  // simplifications involving constants
442
626
  for (unsigned i = 0; i < 2; i++)
443
  {
444
419
    if (!t[i].isConst())
445
    {
446
350
      continue;
447
    }
448
69
    if (t[i].getConst<Rational>().sgn() == 0)
449
    {
450
      // ((_ iand k) 0 y) ---> 0
451
5
      return RewriteResponse(REWRITE_DONE, t[i]);
452
    }
453
  }
454
207
  return RewriteResponse(REWRITE_DONE, t);
455
}
456
457
438
RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
458
438
  return RewriteResponse(REWRITE_DONE, t);
459
}
460
461
829
RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
462
829
  Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
463
829
  NodeManager* nm = NodeManager::currentNM();
464
829
  switch( t.getKind() ){
465
162
  case kind::EXPONENTIAL: {
466
162
    if(t[0].getKind() == kind::CONST_RATIONAL){
467
206
      Node one = nm->mkConst(Rational(1));
468
103
      if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
469
        return RewriteResponse(
470
            REWRITE_AGAIN,
471
14
            nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
472
      }else{
473
89
        return RewriteResponse(REWRITE_DONE, t);
474
      }
475
    }
476
59
    else if (t[0].getKind() == kind::PLUS)
477
    {
478
18
      std::vector<Node> product;
479
27
      for (const Node tc : t[0])
480
      {
481
18
        product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
482
      }
483
      // We need to do a full rewrite here, since we can get exponentials of
484
      // constants, e.g. when we are rewriting exp(2 + x)
485
      return RewriteResponse(REWRITE_AGAIN_FULL,
486
9
                             nm->mkNode(kind::MULT, product));
487
50
    }
488
  }
489
50
    break;
490
596
  case kind::SINE:
491
596
    if(t[0].getKind() == kind::CONST_RATIONAL){
492
276
      const Rational& rat = t[0].getConst<Rational>();
493
276
      if(rat.sgn() == 0){
494
5
        return RewriteResponse(REWRITE_DONE, nm->mkConst(Rational(0)));
495
      }
496
271
      else if (rat.sgn() == -1)
497
      {
498
        Node ret =
499
128
            nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, nm->mkConst(-rat)));
500
64
        return RewriteResponse(REWRITE_AGAIN_FULL, ret);
501
      }
502
    }else{
503
      // get the factor of PI in the argument
504
635
      Node pi_factor;
505
635
      Node pi;
506
635
      Node rem;
507
635
      std::map<Node, Node> msum;
508
320
      if (ArithMSum::getMonomialSum(t[0], msum))
509
      {
510
320
        pi = mkPi();
511
320
        std::map<Node, Node>::iterator itm = msum.find(pi);
512
320
        if (itm != msum.end())
513
        {
514
107
          if (itm->second.isNull())
515
          {
516
            pi_factor = mkRationalNode(Rational(1));
517
          }
518
          else
519
          {
520
107
            pi_factor = itm->second;
521
          }
522
107
          msum.erase(pi);
523
107
          if (!msum.empty())
524
          {
525
102
            rem = ArithMSum::mkNode(msum);
526
          }
527
        }
528
      }
529
      else
530
      {
531
        Assert(false);
532
      }
533
534
      // if there is a factor of PI
535
320
      if( !pi_factor.isNull() ){
536
107
        Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
537
209
        Rational r = pi_factor.getConst<Rational>();
538
209
        Rational r_abs = r.abs();
539
209
        Rational rone = Rational(1);
540
209
        Node ntwo = mkRationalNode(Rational(2));
541
107
        if (r_abs > rone)
542
        {
543
          //add/substract 2*pi beyond scope
544
          Node ra_div_two = nm->mkNode(
545
              kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo);
546
          Node new_pi_factor;
547
          if( r.sgn()==1 ){
548
            new_pi_factor =
549
                nm->mkNode(kind::MINUS,
550
                           pi_factor,
551
                           nm->mkNode(kind::MULT, ntwo, ra_div_two));
552
          }else{
553
            Assert(r.sgn() == -1);
554
            new_pi_factor =
555
                nm->mkNode(kind::PLUS,
556
                           pi_factor,
557
                           nm->mkNode(kind::MULT, ntwo, ra_div_two));
558
          }
559
          Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
560
          if (!rem.isNull())
561
          {
562
            new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
563
          }
564
          // sin( 2*n*PI + x ) = sin( x )
565
          return RewriteResponse(REWRITE_AGAIN_FULL,
566
                                 nm->mkNode(kind::SINE, new_arg));
567
        }
568
107
        else if (r_abs == rone)
569
        {
570
          // sin( PI + x ) = -sin( x )
571
          if (rem.isNull())
572
          {
573
            return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
574
          }
575
          else
576
          {
577
            return RewriteResponse(
578
                REWRITE_AGAIN_FULL,
579
                nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem)));
580
          }
581
        }
582
107
        else if (rem.isNull())
583
        {
584
          // other rational cases based on Niven's theorem
585
          // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
586
5
          Integer one = Integer(1);
587
5
          Integer two = Integer(2);
588
5
          Integer six = Integer(6);
589
5
          if (r_abs.getDenominator() == two)
590
          {
591
5
            Assert(r_abs.getNumerator() == one);
592
            return RewriteResponse(REWRITE_DONE,
593
5
                                   mkRationalNode(Rational(r.sgn())));
594
          }
595
          else if (r_abs.getDenominator() == six)
596
          {
597
            Integer five = Integer(5);
598
            if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
599
            {
600
              return RewriteResponse(
601
                  REWRITE_DONE,
602
                  mkRationalNode(Rational(r.sgn()) / Rational(2)));
603
            }
604
          }
605
        }
606
      }
607
    }
608
522
    break;
609
23
  case kind::COSINE: {
610
    return RewriteResponse(
611
        REWRITE_AGAIN_FULL,
612
46
        nm->mkNode(kind::SINE,
613
92
                   nm->mkNode(kind::MINUS,
614
92
                              nm->mkNode(kind::MULT,
615
46
                                         nm->mkConst(Rational(1) / Rational(2)),
616
46
                                         mkPi()),
617
23
                              t[0])));
618
  }
619
  break;
620
4
  case kind::TANGENT:
621
  {
622
    return RewriteResponse(REWRITE_AGAIN_FULL,
623
16
                           nm->mkNode(kind::DIVISION,
624
8
                                      nm->mkNode(kind::SINE, t[0]),
625
12
                                      nm->mkNode(kind::COSINE, t[0])));
626
  }
627
  break;
628
1
  case kind::COSECANT:
629
  {
630
    return RewriteResponse(REWRITE_AGAIN_FULL,
631
4
                           nm->mkNode(kind::DIVISION,
632
2
                                      mkRationalNode(Rational(1)),
633
3
                                      nm->mkNode(kind::SINE, t[0])));
634
  }
635
  break;
636
1
  case kind::SECANT:
637
  {
638
    return RewriteResponse(REWRITE_AGAIN_FULL,
639
4
                           nm->mkNode(kind::DIVISION,
640
2
                                      mkRationalNode(Rational(1)),
641
3
                                      nm->mkNode(kind::COSINE, t[0])));
642
  }
643
  break;
644
2
  case kind::COTANGENT:
645
  {
646
    return RewriteResponse(REWRITE_AGAIN_FULL,
647
8
                           nm->mkNode(kind::DIVISION,
648
4
                                      nm->mkNode(kind::COSINE, t[0]),
649
6
                                      nm->mkNode(kind::SINE, t[0])));
650
  }
651
  break;
652
40
  default:
653
40
    break;
654
  }
655
612
  return RewriteResponse(REWRITE_DONE, t);
656
}
657
658
1467122
RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
659
1467122
  if(atom.getKind() == kind::IS_INTEGER) {
660
26
    if(atom[0].isConst()) {
661
4
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
662
    }
663
22
    if(atom[0].getType().isInteger()) {
664
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
665
    }
666
    // not supported, but this isn't the right place to complain
667
22
    return RewriteResponse(REWRITE_DONE, atom);
668
1467096
  } else if(atom.getKind() == kind::DIVISIBLE) {
669
    if(atom[0].isConst()) {
670
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
671
    }
672
    if(atom.getOperator().getConst<Divisible>().k.isOne()) {
673
      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
674
    }
675
    return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
676
  }
677
678
  // left |><| right
679
2934192
  TNode left = atom[0];
680
2934192
  TNode right = atom[1];
681
682
2934192
  Polynomial pleft = Polynomial::parsePolynomial(left);
683
2934192
  Polynomial pright = Polynomial::parsePolynomial(right);
684
685
1467096
  Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
686
1467096
  Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
687
688
2934192
  Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
689
1467096
  Assert(cmp.isNormalForm());
690
1467096
  return RewriteResponse(REWRITE_DONE, cmp.getNode());
691
}
692
693
969323
RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
694
969323
  Assert(isAtom(atom));
695
696
969323
  NodeManager* currNM = NodeManager::currentNM();
697
698
969323
  if(atom.getKind() == kind::EQUAL) {
699
595235
    if(atom[0] == atom[1]) {
700
14242
      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
701
    }
702
374088
  }else if(atom.getKind() == kind::GT){
703
61916
    Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
704
30958
    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
705
343130
  }else if(atom.getKind() == kind::LT){
706
74230
    Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
707
37115
    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
708
306015
  }else if(atom.getKind() == kind::IS_INTEGER){
709
19
    if(atom[0].getType().isInteger()){
710
4
      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
711
    }
712
305996
  }else if(atom.getKind() == kind::DIVISIBLE){
713
    if(atom.getOperator().getConst<Divisible>().k.isOne()){
714
      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
715
    }
716
  }
717
718
887004
  return RewriteResponse(REWRITE_DONE, atom);
719
}
720
721
2956509
RewriteResponse ArithRewriter::postRewrite(TNode t){
722
2956509
  if(isTerm(t)){
723
2978774
    RewriteResponse response = postRewriteTerm(t);
724
1489385
    if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
725
    {
726
      Polynomial::parsePolynomial(response.d_node);
727
    }
728
1489385
    return response;
729
1467122
  }else if(isAtom(t)){
730
2934244
    RewriteResponse response = postRewriteAtom(t);
731
1467122
    if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
732
    {
733
      Comparison::parseNormalForm(response.d_node);
734
    }
735
1467122
    return response;
736
  }else{
737
    Unreachable();
738
  }
739
}
740
741
1773425
RewriteResponse ArithRewriter::preRewrite(TNode t){
742
1773425
  if(isTerm(t)){
743
804102
    return preRewriteTerm(t);
744
969323
  }else if(isAtom(t)){
745
969323
    return preRewriteAtom(t);
746
  }else{
747
    Unreachable();
748
  }
749
}
750
751
150374
Node ArithRewriter::makeUnaryMinusNode(TNode n){
752
300748
  Rational qNegOne(-1);
753
300748
  return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
754
}
755
756
145444
Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
757
290888
  Node negR = makeUnaryMinusNode(r);
758
145444
  Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
759
760
290888
  return diff;
761
}
762
763
1957
RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
764
1957
  Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
765
766
3914
  Node left = t[0];
767
3914
  Node right = t[1];
768
1957
  if(right.getKind() == kind::CONST_RATIONAL){
769
989
    const Rational& den = right.getConst<Rational>();
770
771
989
    if(den.isZero()){
772
73
      if(t.getKind() == kind::DIVISION_TOTAL){
773
9
        return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
774
      }else{
775
        // This is unsupported, but this is not a good place to complain
776
64
        return RewriteResponse(REWRITE_DONE, t);
777
      }
778
    }
779
916
    Assert(den != Rational(0));
780
781
916
    if(left.getKind() == kind::CONST_RATIONAL){
782
359
      const Rational& num = left.getConst<Rational>();
783
718
      Rational div = num / den;
784
718
      Node result =  mkRationalNode(div);
785
359
      return RewriteResponse(REWRITE_DONE, result);
786
    }
787
788
1114
    Rational div = den.inverse();
789
790
1114
    Node result = mkRationalNode(div);
791
792
1114
    Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
793
557
    if(pre){
794
553
      return RewriteResponse(REWRITE_DONE, mult);
795
    }else{
796
4
      return RewriteResponse(REWRITE_AGAIN, mult);
797
    }
798
  }else{
799
968
    return RewriteResponse(REWRITE_DONE, t);
800
  }
801
}
802
803
1027
RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
804
{
805
1027
  NodeManager* nm = NodeManager::currentNM();
806
1027
  Kind k = t.getKind();
807
2054
  Node zero = nm->mkConst(Rational(0));
808
1027
  if (k == kind::INTS_MODULUS)
809
  {
810
468
    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
811
    {
812
      // can immediately replace by INTS_MODULUS_TOTAL
813
102
      Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
814
51
      return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
815
    }
816
  }
817
976
  if (k == kind::INTS_DIVISION)
818
  {
819
559
    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
820
    {
821
      // can immediately replace by INTS_DIVISION_TOTAL
822
162
      Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
823
81
      return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
824
    }
825
  }
826
895
  return RewriteResponse(REWRITE_DONE, t);
827
}
828
829
8607
RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
{
  // Post-rewrite simplification of the total integer operators
  // INTS_MODULUS_TOTAL and INTS_DIVISION_TOTAL.
  //
  // Rules are tried in this order; the first that applies wins:
  //   1. divisor is constant 0        : (div x 0), (mod x 0) ---> 0
  //   2. divisor is constant 1        : (mod x 1) ---> 0, (div x 1) ---> x
  //   3. divisor is a negative const  : pull the negation out of the divisor
  //   4. both arguments are constants : evaluate (Euclidean quotient/remainder)
  //   5. operator-specific simplifications involving a nested mod
  //
  // @param t   the div/mod term to rewrite
  // @param pre true at prewrite, where nothing is rewritten
  // @return REWRITE_AGAIN_FULL with the simplified term (via returnRewrite), or
  //         REWRITE_DONE with t unchanged
  if (pre)
  {
    // do not rewrite at prewrite.
    return RewriteResponse(REWRITE_DONE, t);
  }

  NodeManager* nm = NodeManager::currentNM();
  const Kind op = t.getKind();
  Assert(op == kind::INTS_MODULUS_TOTAL || op == kind::INTS_DIVISION_TOTAL);
  TNode n = t[0];
  TNode d = t[1];
  const bool isDivOp =
      (op == kind::INTS_DIVISION || op == kind::INTS_DIVISION_TOTAL);

  if (d.getKind() == kind::CONST_RATIONAL)
  {
    const Rational& c = d.getConst<Rational>();

    if (c.isZero())
    {
      // (div x 0) ---> 0 or (mod x 0) ---> 0
      return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
    }

    if (c.isOne())
    {
      if (op == kind::INTS_MODULUS_TOTAL)
      {
        // (mod x 1) --> 0
        return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE);
      }
      Assert(op == kind::INTS_DIVISION_TOTAL);
      // (div x 1) --> x
      return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
    }

    if (c.sgn() < 0)
    {
      // pull negation
      // (div x (- c)) ---> (- (div x c))
      // (mod x (- c)) ---> (mod x c)
      Node flipped = nm->mkNode(op, n, nm->mkConst(-c));
      Node ret = isDivOp ? nm->mkNode(kind::UMINUS, flipped) : flipped;
      return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
    }

    if (n.getKind() == kind::CONST_RATIONAL)
    {
      Assert(c.isIntegral());
      Assert(n.getConst<Rational>().isIntegral());
      Assert(!c.isZero());
      Integer divisor = c.getNumerator();
      Integer dividend = n.getConst<Rational>().getNumerator();
      Integer value = isDivOp ? dividend.euclidianDivideQuotient(divisor)
                              : dividend.euclidianDivideRemainder(divisor);
      // constant evaluation
      // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
      return returnRewrite(
          t, mkRationalNode(Rational(value)), Rewrite::CONST_EVAL);
    }
  }

  if (op == kind::INTS_MODULUS_TOTAL)
  {
    // Note these rewrites do not need to account for modulus by zero as being
    // a UF, which is handled by the reduction of INTS_MODULUS.
    const Kind nOp = n.getKind();
    if (nOp == kind::INTS_MODULUS_TOTAL && n[1] == d)
    {
      // (mod (mod x c) c) --> (mod x c)
      return returnRewrite(t, n, Rewrite::MOD_OVER_MOD);
    }
    if (nOp == kind::NONLINEAR_MULT || nOp == kind::MULT || nOp == kind::PLUS)
    {
      // Strip every direct child of the form (mod x c), with the same c as
      // the outer mod, down to x.
      std::vector<Node> stripped;
      bool anyStripped = false;
      for (TNode child : n)
      {
        const bool sameMod =
            child.getKind() == kind::INTS_MODULUS_TOTAL && child[1] == d;
        stripped.push_back(sameMod ? child[0] : child);
        anyStripped = anyStripped || sameMod;
      }
      if (anyStripped)
      {
        // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
        // op is one of { NONLINEAR_MULT, MULT, PLUS }.
        Node inner = nm->mkNode(nOp, stripped);
        Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, inner, d);
        return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
      }
    }
  }
  else
  {
    Assert(op == kind::INTS_DIVISION_TOTAL);
    // Note these rewrites do not need to account for division by zero as being
    // a UF, which is handled by the reduction of INTS_DIVISION.
    if (n.getKind() == kind::INTS_MODULUS_TOTAL && n[1] == d)
    {
      // (div (mod x c) c) --> 0
      return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_OVER_MOD);
    }
  }

  return RewriteResponse(REWRITE_DONE, t);
}
932
933
2172
TrustNode ArithRewriter::expandDefinition(Node node)
{
  // Expands `node` by eliminating its partial arithmetic operators through the
  // operator eliminator. This path is expected not to produce any skolem
  // lemmas (checked by the assertion below).
  //
  // @param node the term whose partial operators should be eliminated
  // @return the TrustNode returned by OperatorElim::eliminate
  std::vector<SkolemLemma> skolemLemmas;
  const bool partialOnly = true;  // eliminate partial operators only
  TrustNode expanded = d_opElim.eliminate(node, skolemLemmas, partialOnly);
  Assert(skolemLemmas.empty());
  return expanded;
}
941
942
1912
RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
{
  // Common exit point for rules that rewrote `t` into `ret`.
  //
  // Logs the applied rule `r` on the "arith-rewrite" trace tag, then returns
  // `ret` with status REWRITE_AGAIN_FULL (presumably so the rewriter fully
  // re-rewrites the result).
  Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret
                         << " by " << r << std::endl;
  return RewriteResponse(REWRITE_AGAIN_FULL, ret);
}
948
949
}  // namespace arith
950
}  // namespace theory
951
22746
}  // namespace cvc5