GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/strings/arith_entail.cpp Lines: 344 384 89.6 %
Date: 2021-05-21 Branches: 878 1918 45.8 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of arithmetic entailment computation for string terms.
14
 */
15
16
#include "theory/strings/arith_entail.h"
17
18
#include "expr/attribute.h"
19
#include "expr/node_algorithm.h"
20
#include "theory/arith/arith_msum.h"
21
#include "theory/rewriter.h"
22
#include "theory/strings/theory_strings_utils.h"
23
#include "theory/strings/word.h"
24
#include "theory/theory.h"
25
26
using namespace cvc5::kind;
27
28
namespace cvc5 {
29
namespace theory {
30
namespace strings {
31
32
1449
bool ArithEntail::checkEq(Node a, Node b)
33
{
34
1449
  if (a == b)
35
  {
36
    return true;
37
  }
38
2898
  Node ar = Rewriter::rewrite(a);
39
2898
  Node br = Rewriter::rewrite(b);
40
1449
  return ar == br;
41
}
42
43
496936
bool ArithEntail::check(Node a, Node b, bool strict)
44
{
45
496936
  if (a == b)
46
  {
47
13201
    return !strict;
48
  }
49
967470
  Node diff = NodeManager::currentNM()->mkNode(kind::MINUS, a, b);
50
483735
  return check(diff, strict);
51
}
52
53
struct StrCheckEntailArithTag
54
{
55
};
56
struct StrCheckEntailArithComputedTag
57
{
58
};
59
/** Attribute true for expressions for which check returned true */
60
typedef expr::Attribute<StrCheckEntailArithTag, bool> StrCheckEntailArithAttr;
61
typedef expr::Attribute<StrCheckEntailArithComputedTag, bool>
62
    StrCheckEntailArithComputedAttr;
63
64
732516
bool ArithEntail::check(Node a, bool strict)
65
{
66
732516
  if (a.isConst())
67
  {
68
78284
    return a.getConst<Rational>().sgn() >= (strict ? 1 : 0);
69
  }
70
71
  Node ar = strict ? NodeManager::currentNM()->mkNode(
72
875976
                kind::MINUS, a, NodeManager::currentNM()->mkConst(Rational(1)))
73
1530208
                   : a;
74
654232
  ar = Rewriter::rewrite(ar);
75
76
654232
  if (ar.getAttribute(StrCheckEntailArithComputedAttr()))
77
  {
78
502939
    return ar.getAttribute(StrCheckEntailArithAttr());
79
  }
80
81
151293
  bool ret = checkInternal(ar);
82
151293
  if (!ret)
83
  {
84
    // try with approximations
85
140770
    ret = checkApprox(ar);
86
  }
87
  // cache the result
88
151293
  ar.setAttribute(StrCheckEntailArithAttr(), ret);
89
151293
  ar.setAttribute(StrCheckEntailArithComputedAttr(), true);
90
151293
  return ret;
91
}
92
93
140770
bool ArithEntail::checkApprox(Node ar)
94
{
95
140770
  Assert(Rewriter::rewrite(ar) == ar);
96
140770
  NodeManager* nm = NodeManager::currentNM();
97
281540
  std::map<Node, Node> msum;
98
281540
  Trace("strings-ent-approx-debug")
99
140770
      << "Setup arithmetic approximations for " << ar << std::endl;
100
140770
  if (!ArithMSum::getMonomialSum(ar, msum))
101
  {
102
    Trace("strings-ent-approx-debug")
103
        << "...failed to get monomial sum!" << std::endl;
104
    return false;
105
  }
106
  // for each monomial v*c, mApprox[v] a list of
107
  // possibilities for how the term can be soundly approximated, that is,
108
  // if mApprox[v] contains av, then v*c > av*c. Notice that if c
109
  // is positive, then v > av, otherwise if c is negative, then v < av.
110
  // In other words, av is an under-approximation if c is positive, and an
111
  // over-approximation if c is negative.
112
140770
  bool changed = false;
113
281540
  std::map<Node, std::vector<Node> > mApprox;
114
  // map from approximations to their monomial sums
115
281540
  std::map<Node, std::map<Node, Node> > approxMsums;
116
  // aarSum stores each monomial that does not have multiple approximations
117
281540
  std::vector<Node> aarSum;
118
513112
  for (std::pair<const Node, Node>& m : msum)
119
  {
120
744684
    Node v = m.first;
121
744684
    Node c = m.second;
122
744684
    Trace("strings-ent-approx-debug")
123
372342
        << "Get approximations " << v << "..." << std::endl;
124
372342
    if (v.isNull())
125
    {
126
208110
      Node mn = c.isNull() ? nm->mkConst(Rational(1)) : c;
127
104055
      aarSum.push_back(mn);
128
    }
129
    else
130
    {
131
      // c.isNull() means c = 1
132
268287
      bool isOverApprox = !c.isNull() && c.getConst<Rational>().sgn() == -1;
133
268287
      std::vector<Node>& approx = mApprox[v];
134
536574
      std::unordered_set<Node> visited;
135
536574
      std::vector<Node> toProcess;
136
268287
      toProcess.push_back(v);
137
47378
      do
138
      {
139
631330
        Node curr = toProcess.back();
140
315665
        Trace("strings-ent-approx-debug") << "  process " << curr << std::endl;
141
315665
        curr = Rewriter::rewrite(curr);
142
315665
        toProcess.pop_back();
143
315665
        if (visited.find(curr) == visited.end())
144
        {
145
312142
          visited.insert(curr);
146
624284
          std::vector<Node> currApprox;
147
312142
          getArithApproximations(curr, currApprox, isOverApprox);
148
312142
          if (currApprox.empty())
149
          {
150
558036
            Trace("strings-ent-approx-debug")
151
279018
                << "...approximation: " << curr << std::endl;
152
            // no approximations, thus curr is a possibility
153
279018
            approx.push_back(curr);
154
          }
155
          else
156
          {
157
33124
            toProcess.insert(
158
66248
                toProcess.end(), currApprox.begin(), currApprox.end());
159
          }
160
        }
161
315665
      } while (!toProcess.empty());
162
268287
      Assert(!approx.empty());
163
      // if we have only one approximation, move it to final
164
268287
      if (approx.size() == 1)
165
      {
166
257572
        changed = v != approx[0];
167
515144
        Node mn = ArithMSum::mkCoeffTerm(c, approx[0]);
168
257572
        aarSum.push_back(mn);
169
257572
        mApprox.erase(v);
170
      }
171
      else
172
      {
173
        // compute monomial sum form for each approximation, used below
174
32161
        for (const Node& aa : approx)
175
        {
176
21446
          if (approxMsums.find(aa) == approxMsums.end())
177
          {
178
            CVC5_UNUSED bool ret =
179
20434
                ArithMSum::getMonomialSum(aa, approxMsums[aa]);
180
20434
            Assert(ret);
181
          }
182
        }
183
10715
        changed = true;
184
      }
185
    }
186
  }
187
140770
  if (!changed)
188
  {
189
    // approximations had no effect, return
190
123996
    Trace("strings-ent-approx-debug") << "...no approximations" << std::endl;
191
123996
    return false;
192
  }
193
  // get the current "fixed" sum for the abstraction of ar
194
16774
  Node aar = aarSum.empty()
195
18221
                 ? nm->mkConst(Rational(0))
196
34995
                 : (aarSum.size() == 1 ? aarSum[0] : nm->mkNode(PLUS, aarSum));
197
16774
  aar = Rewriter::rewrite(aar);
198
33548
  Trace("strings-ent-approx-debug")
199
33548
      << "...processed fixed sum " << aar << " with " << mApprox.size()
200
16774
      << " approximated monomials." << std::endl;
201
  // if we have a choice of how to approximate
202
16774
  if (!mApprox.empty())
203
  {
204
    // convert aar back to monomial sum
205
14038
    std::map<Node, Node> msumAar;
206
7019
    if (!ArithMSum::getMonomialSum(aar, msumAar))
207
    {
208
      return false;
209
    }
210
7019
    if (Trace.isOn("strings-ent-approx"))
211
    {
212
      Trace("strings-ent-approx")
213
          << "---- Check arithmetic entailment by under-approximation " << ar
214
          << " >= 0" << std::endl;
215
      Trace("strings-ent-approx") << "FIXED:" << std::endl;
216
      ArithMSum::debugPrintMonomialSum(msumAar, "strings-ent-approx");
217
      Trace("strings-ent-approx") << "APPROX:" << std::endl;
218
      for (std::pair<const Node, std::vector<Node> >& a : mApprox)
219
      {
220
        Node c = msum[a.first];
221
        Trace("strings-ent-approx") << "  ";
222
        if (!c.isNull())
223
        {
224
          Trace("strings-ent-approx") << c << " * ";
225
        }
226
        Trace("strings-ent-approx")
227
            << a.second << " ...from " << a.first << std::endl;
228
      }
229
      Trace("strings-ent-approx") << std::endl;
230
    }
231
14038
    Rational one(1);
232
    // incorporate monomials one at a time that have a choice of approximations
233
21215
    while (!mApprox.empty())
234
    {
235
14196
      Node v;
236
14196
      Node vapprox;
237
7098
      int maxScore = -1;
238
      // Look at each approximation, take the one with the best score.
239
      // Notice that we are in the process of trying to prove
240
      // ( c1*t1 + .. + cn*tn ) + ( approx_1 | ... | approx_m ) >= 0,
241
      // where c1*t1 + .. + cn*tn is the "fixed" component of our sum (aar)
242
      // and approx_1 ... approx_m are possible approximations. The
243
      // intution here is that we want coefficients c1...cn to be positive.
244
      // This is because arithmetic string terms t1...tn (which may be
245
      // applications of len, indexof, str.to.int) are never entailed to be
246
      // negative. Hence, we add the approx_i that contributes the "most"
247
      // towards making all constants c1...cn positive and cancelling negative
248
      // monomials in approx_i itself.
249
7098
      for (std::pair<const Node, std::vector<Node> >& nam : mApprox)
250
      {
251
7098
        Node cr = msum[nam.first];
252
21310
        for (const Node& aa : nam.second)
253
        {
254
14212
          unsigned helpsCancelCount = 0;
255
14212
          unsigned addsObligationCount = 0;
256
14212
          std::map<Node, Node>::iterator it;
257
          // we are processing an approximation cr*( c1*t1 + ... + cn*tn )
258
31584
          for (std::pair<const Node, Node>& aam : approxMsums[aa])
259
          {
260
            // Say aar is of the form t + c*ti, and aam is the monomial ci*ti
261
            // where ci != 0. We say aam:
262
            // (1) helps cancel if c != 0 and c>0 != ci>0
263
            // (2) adds obligation if c>=0 and c+ci<0
264
34744
            Node ti = aam.first;
265
34744
            Node ci = aam.second;
266
17372
            if (!cr.isNull())
267
            {
268
26498
              ci = ci.isNull() ? cr
269
26498
                               : Rewriter::rewrite(nm->mkNode(MULT, ci, cr));
270
            }
271
17372
            Trace("strings-ent-approx-debug") << ci << "*" << ti << " ";
272
17372
            int ciSgn = ci.isNull() ? 1 : ci.getConst<Rational>().sgn();
273
17372
            it = msumAar.find(ti);
274
17372
            if (it != msumAar.end())
275
            {
276
18696
              Node c = it->second;
277
9348
              int cSgn = c.isNull() ? 1 : c.getConst<Rational>().sgn();
278
9348
              if (cSgn == 0)
279
              {
280
1887
                addsObligationCount += (ciSgn == -1 ? 1 : 0);
281
              }
282
7461
              else if (cSgn != ciSgn)
283
              {
284
6260
                helpsCancelCount++;
285
12520
                Rational r1 = c.isNull() ? one : c.getConst<Rational>();
286
12520
                Rational r2 = ci.isNull() ? one : ci.getConst<Rational>();
287
12520
                Rational r12 = r1 + r2;
288
6260
                if (r12.sgn() == -1)
289
                {
290
4021
                  addsObligationCount++;
291
                }
292
              }
293
            }
294
            else
295
            {
296
8024
              addsObligationCount += (ciSgn == -1 ? 1 : 0);
297
            }
298
          }
299
28424
          Trace("strings-ent-approx-debug")
300
14212
              << "counts=" << helpsCancelCount << "," << addsObligationCount
301
14212
              << " for " << aa << " into " << aar << std::endl;
302
28424
          int score = (addsObligationCount > 0 ? 0 : 2)
303
14212
                      + (helpsCancelCount > 0 ? 1 : 0);
304
          // if its the best, update v and vapprox
305
14212
          if (v.isNull() || score > maxScore)
306
          {
307
9679
            v = nam.first;
308
9679
            vapprox = aa;
309
9679
            maxScore = score;
310
          }
311
        }
312
7098
        if (!v.isNull())
313
        {
314
7098
          break;
315
        }
316
      }
317
14196
      Trace("strings-ent-approx")
318
7098
          << "- Decide " << v << " = " << vapprox << std::endl;
319
      // we incorporate v approximated by vapprox into the overall approximation
320
      // for ar
321
7098
      Assert(!v.isNull() && !vapprox.isNull());
322
7098
      Assert(msum.find(v) != msum.end());
323
14196
      Node mn = ArithMSum::mkCoeffTerm(msum[v], vapprox);
324
7098
      aar = nm->mkNode(PLUS, aar, mn);
325
      // update the msumAar map
326
7098
      aar = Rewriter::rewrite(aar);
327
7098
      msumAar.clear();
328
7098
      if (!ArithMSum::getMonomialSum(aar, msumAar))
329
      {
330
        Assert(false);
331
        Trace("strings-ent-approx")
332
            << "...failed to get monomial sum!" << std::endl;
333
        return false;
334
      }
335
      // we have processed the approximation for v
336
7098
      mApprox.erase(v);
337
    }
338
7019
    Trace("strings-ent-approx") << "-----------------" << std::endl;
339
  }
340
16774
  if (aar == ar)
341
  {
342
    Trace("strings-ent-approx-debug")
343
        << "...approximation had no effect" << std::endl;
344
    // this should never happen, but we avoid the infinite loop for sanity here
345
    Assert(false);
346
    return false;
347
  }
348
  // Check entailment on the approximation of ar.
349
  // Notice that this may trigger further reasoning by approximation. For
350
  // example, len( replace( x ++ y, substr( x, 0, n ), z ) ) may be
351
  // under-approximated as len( x ) + len( y ) - len( substr( x, 0, n ) ) on
352
  // this call, where in the recursive call we may over-approximate
353
  // len( substr( x, 0, n ) ) as len( x ). In this example, we can infer
354
  // that len( replace( x ++ y, substr( x, 0, n ), z ) ) >= len( y ) in two
355
  // steps.
356
16774
  if (check(aar))
357
  {
358
3240
    Trace("strings-ent-approx")
359
1620
        << "*** StrArithApprox: showed " << ar
360
1620
        << " >= 0 using under-approximation!" << std::endl;
361
3240
    Trace("strings-ent-approx")
362
1620
        << "*** StrArithApprox: under-approximation was " << aar << std::endl;
363
1620
    return true;
364
  }
365
15154
  return false;
366
}
367
368
312628
void ArithEntail::getArithApproximations(Node a,
369
                                         std::vector<Node>& approx,
370
                                         bool isOverApprox)
371
{
372
312628
  NodeManager* nm = NodeManager::currentNM();
373
  // We do not handle PLUS here since this leads to exponential behavior.
374
  // Instead, this is managed, e.g. during checkApprox, where
375
  // PLUS terms are expanded "on-demand" during the reasoning.
376
625256
  Trace("strings-ent-approx-debug")
377
312628
      << "Get arith approximations " << a << std::endl;
378
312628
  Kind ak = a.getKind();
379
312628
  if (ak == MULT)
380
  {
381
972
    Node c;
382
972
    Node v;
383
486
    if (ArithMSum::getMonomial(a, c, v))
384
    {
385
486
      bool isNeg = c.getConst<Rational>().sgn() == -1;
386
486
      getArithApproximations(v, approx, isNeg ? !isOverApprox : isOverApprox);
387
519
      for (unsigned i = 0, size = approx.size(); i < size; i++)
388
      {
389
33
        approx[i] = nm->mkNode(MULT, c, approx[i]);
390
      }
391
    }
392
  }
393
312142
  else if (ak == STRING_LENGTH)
394
  {
395
278639
    Kind aak = a[0].getKind();
396
278639
    if (aak == STRING_SUBSTR)
397
    {
398
      // over,under-approximations for len( substr( x, n, m ) )
399
72706
      Node lenx = nm->mkNode(STRING_LENGTH, a[0][0]);
400
36353
      if (isOverApprox)
401
      {
402
        // m >= 0 implies
403
        //  m >= len( substr( x, n, m ) )
404
20935
        if (check(a[0][2]))
405
        {
406
12784
          approx.push_back(a[0][2]);
407
        }
408
20935
        if (check(lenx, a[0][1]))
409
        {
410
          // n <= len( x ) implies
411
          //   len( x ) - n >= len( substr( x, n, m ) )
412
8189
          approx.push_back(nm->mkNode(MINUS, lenx, a[0][1]));
413
        }
414
        else
415
        {
416
          // len( x ) >= len( substr( x, n, m ) )
417
12746
          approx.push_back(lenx);
418
        }
419
      }
420
      else
421
      {
422
        // 0 <= n and n+m <= len( x ) implies
423
        //   m <= len( substr( x, n, m ) )
424
30836
        Node npm = nm->mkNode(PLUS, a[0][1], a[0][2]);
425
15418
        if (check(a[0][1]) && check(lenx, npm))
426
        {
427
2659
          approx.push_back(a[0][2]);
428
        }
429
        // 0 <= n and n+m >= len( x ) implies
430
        //   len(x)-n <= len( substr( x, n, m ) )
431
15418
        if (check(a[0][1]) && check(npm, lenx))
432
        {
433
1614
          approx.push_back(nm->mkNode(MINUS, lenx, a[0][1]));
434
        }
435
      }
436
    }
437
242286
    else if (aak == STRING_STRREPL)
438
    {
439
      // over,under-approximations for len( replace( x, y, z ) )
440
      // notice this is either len( x ) or ( len( x ) + len( z ) - len( y ) )
441
9866
      Node lenx = nm->mkNode(STRING_LENGTH, a[0][0]);
442
9866
      Node leny = nm->mkNode(STRING_LENGTH, a[0][1]);
443
9866
      Node lenz = nm->mkNode(STRING_LENGTH, a[0][2]);
444
4933
      if (isOverApprox)
445
      {
446
2720
        if (check(leny, lenz))
447
        {
448
          // len( y ) >= len( z ) implies
449
          //   len( x ) >= len( replace( x, y, z ) )
450
915
          approx.push_back(lenx);
451
        }
452
        else
453
        {
454
          // len( x ) + len( z ) >= len( replace( x, y, z ) )
455
1805
          approx.push_back(nm->mkNode(PLUS, lenx, lenz));
456
        }
457
      }
458
      else
459
      {
460
2213
        if (check(lenz, leny) || check(lenz, lenx))
461
        {
462
          // len( y ) <= len( z ) or len( x ) <= len( z ) implies
463
          //   len( x ) <= len( replace( x, y, z ) )
464
1030
          approx.push_back(lenx);
465
        }
466
        else
467
        {
468
          // len( x ) - len( y ) <= len( replace( x, y, z ) )
469
1183
          approx.push_back(nm->mkNode(MINUS, lenx, leny));
470
        }
471
      }
472
    }
473
237353
    else if (aak == STRING_ITOS)
474
    {
475
      // over,under-approximations for len( int.to.str( x ) )
476
441
      if (isOverApprox)
477
      {
478
224
        if (check(a[0][0], false))
479
        {
480
116
          if (check(a[0][0], true))
481
          {
482
            // x > 0 implies
483
            //   x >= len( int.to.str( x ) )
484
7
            approx.push_back(a[0][0]);
485
          }
486
          else
487
          {
488
            // x >= 0 implies
489
            //   x+1 >= len( int.to.str( x ) )
490
109
            approx.push_back(
491
218
                nm->mkNode(PLUS, nm->mkConst(Rational(1)), a[0][0]));
492
          }
493
        }
494
      }
495
      else
496
      {
497
217
        if (check(a[0][0]))
498
        {
499
          // x >= 0 implies
500
          //   len( int.to.str( x ) ) >= 1
501
97
          approx.push_back(nm->mkConst(Rational(1)));
502
        }
503
        // other crazy things are possible here, e.g.
504
        // len( int.to.str( len( y ) + 10 ) ) >= 2
505
      }
506
    }
507
  }
508
33503
  else if (ak == STRING_STRIDOF)
509
  {
510
    // over,under-approximations for indexof( x, y, n )
511
4240
    if (isOverApprox)
512
    {
513
4452
      Node lenx = nm->mkNode(STRING_LENGTH, a[0]);
514
4452
      Node leny = nm->mkNode(STRING_LENGTH, a[1]);
515
2226
      if (check(lenx, leny))
516
      {
517
        // len( x ) >= len( y ) implies
518
        //   len( x ) - len( y ) >= indexof( x, y, n )
519
22
        approx.push_back(nm->mkNode(MINUS, lenx, leny));
520
      }
521
      else
522
      {
523
        // len( x ) >= indexof( x, y, n )
524
2204
        approx.push_back(lenx);
525
      }
526
    }
527
    else
528
    {
529
      // TODO?:
530
      // contains( substr( x, n, len( x ) ), y ) implies
531
      //   n <= indexof( x, y, n )
532
      // ...hard to test, runs risk of non-termination
533
534
      // -1 <= indexof( x, y, n )
535
2014
      approx.push_back(nm->mkConst(Rational(-1)));
536
    }
537
  }
538
29263
  else if (ak == STRING_STOI)
539
  {
540
    // over,under-approximations for str.to.int( x )
541
    if (isOverApprox)
542
    {
543
      // TODO?:
544
      // y >= 0 implies
545
      //   y >= str.to.int( int.to.str( y ) )
546
    }
547
    else
548
    {
549
      // -1 <= str.to.int( x )
550
      approx.push_back(nm->mkConst(Rational(-1)));
551
    }
552
  }
553
312628
  Trace("strings-ent-approx-debug") << "Return " << approx.size() << std::endl;
554
312628
}
555
556
50438
bool ArithEntail::checkWithEqAssumption(Node assumption, Node a, bool strict)
557
{
558
50438
  Assert(assumption.getKind() == kind::EQUAL);
559
50438
  Assert(Rewriter::rewrite(assumption) == assumption);
560
100876
  Trace("strings-entail") << "checkWithEqAssumption: " << assumption << " " << a
561
50438
                          << ", strict=" << strict << std::endl;
562
563
  // Find candidates variables to compute substitutions for
564
100876
  std::unordered_set<Node> candVars;
565
100876
  std::vector<Node> toVisit = {assumption};
566
624112
  while (!toVisit.empty())
567
  {
568
573674
    Node curr = toVisit.back();
569
286837
    toVisit.pop_back();
570
571
818728
    if (curr.getKind() == kind::PLUS || curr.getKind() == kind::MULT
572
515279
        || curr.getKind() == kind::MINUS || curr.getKind() == kind::EQUAL)
573
    {
574
345232
      for (const auto& currChild : curr)
575
      {
576
236399
        toVisit.push_back(currChild);
577
      }
578
    }
579
178004
    else if (curr.isVar() && Theory::theoryOf(curr) == THEORY_ARITH)
580
    {
581
8802
      candVars.insert(curr);
582
    }
583
169202
    else if (curr.getKind() == kind::STRING_LENGTH)
584
    {
585
105842
      candVars.insert(curr);
586
    }
587
  }
588
589
  // Check if any of the candidate variables are in n
590
100876
  Node v;
591
50438
  Assert(toVisit.empty());
592
50438
  toVisit.push_back(a);
593
503378
  while (!toVisit.empty())
594
  {
595
474386
    Node curr = toVisit.back();
596
247916
    toVisit.pop_back();
597
598
505619
    for (const auto& currChild : curr)
599
    {
600
257703
      toVisit.push_back(currChild);
601
    }
602
603
247916
    if (candVars.find(curr) != candVars.end())
604
    {
605
21446
      v = curr;
606
21446
      break;
607
    }
608
  }
609
610
50438
  if (v.isNull())
611
  {
612
    // No suitable candidate found
613
28992
    return false;
614
  }
615
616
42892
  Node solution = ArithMSum::solveEqualityFor(assumption, v);
617
21446
  if (solution.isNull())
618
  {
619
    // Could not solve for v
620
80
    return false;
621
  }
622
42732
  Trace("strings-entail") << "checkWithEqAssumption: subs " << v << " -> "
623
21366
                          << solution << std::endl;
624
625
  // use capture avoiding substitution
626
21366
  a = expr::substituteCaptureAvoiding(a, v, solution);
627
21366
  return check(a, strict);
628
}
629
630
74735
bool ArithEntail::checkWithAssumption(Node assumption,
631
                                      Node a,
632
                                      Node b,
633
                                      bool strict)
634
{
635
74735
  Assert(Rewriter::rewrite(assumption) == assumption);
636
637
74735
  NodeManager* nm = NodeManager::currentNM();
638
639
74735
  if (!assumption.isConst() && assumption.getKind() != kind::EQUAL)
640
  {
641
    // We rewrite inequality assumptions from x <= y to x + (str.len s) = y
642
    // where s is some fresh string variable. We use (str.len s) because
643
    // (str.len s) must be non-negative for the equation to hold.
644
100860
    Node x, y;
645
50430
    if (assumption.getKind() == kind::GEQ)
646
    {
647
41406
      x = assumption[0];
648
41406
      y = assumption[1];
649
    }
650
    else
651
    {
652
      // (not (>= s t)) --> (>= (t - 1) s)
653
9024
      Assert(assumption.getKind() == kind::NOT
654
             && assumption[0].getKind() == kind::GEQ);
655
9024
      x = nm->mkNode(kind::MINUS, assumption[0][1], nm->mkConst(Rational(1)));
656
9024
      y = assumption[0][0];
657
    }
658
659
100860
    Node s = nm->mkBoundVar("slackVal", nm->stringType());
660
100860
    Node slen = nm->mkNode(kind::STRING_LENGTH, s);
661
50430
    assumption = Rewriter::rewrite(
662
100860
        nm->mkNode(kind::EQUAL, x, nm->mkNode(kind::PLUS, y, slen)));
663
  }
664
665
149470
  Node diff = nm->mkNode(kind::MINUS, a, b);
666
74735
  bool res = false;
667
74735
  if (assumption.isConst())
668
  {
669
24297
    bool assumptionBool = assumption.getConst<bool>();
670
24297
    if (assumptionBool)
671
    {
672
24297
      res = check(diff, strict);
673
    }
674
    else
675
    {
676
      res = true;
677
    }
678
  }
679
  else
680
  {
681
50438
    res = checkWithEqAssumption(assumption, diff, strict);
682
  }
683
149470
  return res;
684
}
685
686
bool ArithEntail::checkWithAssumptions(std::vector<Node> assumptions,
687
                                       Node a,
688
                                       Node b,
689
                                       bool strict)
690
{
691
  // TODO: We currently try to show the entailment with each assumption
692
  // independently. In the future, we should make better use of multiple
693
  // assumptions.
694
  bool res = false;
695
  for (const auto& assumption : assumptions)
696
  {
697
    Assert(Rewriter::rewrite(assumption) == assumption);
698
699
    if (checkWithAssumption(assumption, a, b, strict))
700
    {
701
      res = true;
702
      break;
703
    }
704
  }
705
  return res;
706
}
707
708
25132
Node ArithEntail::getConstantBound(Node a, bool isLower)
709
{
710
25132
  Assert(Rewriter::rewrite(a) == a);
711
25132
  Node ret;
712
25132
  if (a.isConst())
713
  {
714
9915
    ret = a;
715
  }
716
15217
  else if (a.getKind() == kind::STRING_LENGTH)
717
  {
718
3008
    if (isLower)
719
    {
720
3008
      ret = NodeManager::currentNM()->mkConst(Rational(0));
721
    }
722
  }
723
12209
  else if (a.getKind() == kind::PLUS || a.getKind() == kind::MULT)
724
  {
725
17438
    std::vector<Node> children;
726
8719
    bool success = true;
727
16356
    for (unsigned i = 0; i < a.getNumChildren(); i++)
728
    {
729
22716
      Node ac = getConstantBound(a[i], isLower);
730
15079
      if (ac.isNull())
731
      {
732
4049
        ret = ac;
733
4049
        success = false;
734
4049
        break;
735
      }
736
      else
737
      {
738
11030
        if (ac.getConst<Rational>().sgn() == 0)
739
        {
740
2719
          if (a.getKind() == kind::MULT)
741
          {
742
4
            ret = ac;
743
4
            success = false;
744
4
            break;
745
          }
746
        }
747
        else
748
        {
749
8311
          if (a.getKind() == kind::MULT)
750
          {
751
3393
            if ((ac.getConst<Rational>().sgn() > 0) != isLower)
752
            {
753
3389
              ret = Node::null();
754
3389
              success = false;
755
3389
              break;
756
            }
757
          }
758
4922
          children.push_back(ac);
759
        }
760
      }
761
    }
762
8719
    if (success)
763
    {
764
1277
      if (children.empty())
765
      {
766
268
        ret = NodeManager::currentNM()->mkConst(Rational(0));
767
      }
768
1009
      else if (children.size() == 1)
769
      {
770
1009
        ret = children[0];
771
      }
772
      else
773
      {
774
        ret = NodeManager::currentNM()->mkNode(a.getKind(), children);
775
        ret = Rewriter::rewrite(ret);
776
      }
777
    }
778
  }
779
50264
  Trace("strings-rewrite-cbound")
780
25132
      << "Constant " << (isLower ? "lower" : "upper") << " bound for " << a
781
25132
      << " is " << ret << std::endl;
782
25132
  Assert(ret.isNull() || ret.isConst());
783
  // entailment check should be at least as powerful as computing a lower bound
784
25132
  Assert(!isLower || ret.isNull() || ret.getConst<Rational>().sgn() < 0
785
         || check(a, false));
786
25132
  Assert(!isLower || ret.isNull() || ret.getConst<Rational>().sgn() <= 0
787
         || check(a, true));
788
25132
  return ret;
789
}
790
791
393673
bool ArithEntail::checkInternal(Node a)
792
{
793
393673
  Assert(Rewriter::rewrite(a) == a);
794
  // check whether a >= 0
795
393673
  if (a.isConst())
796
  {
797
166364
    return a.getConst<Rational>().sgn() >= 0;
798
  }
799
227309
  else if (a.getKind() == kind::STRING_LENGTH)
800
  {
801
    // str.len( t ) >= 0
802
32719
    return true;
803
  }
804
194590
  else if (a.getKind() == kind::PLUS || a.getKind() == kind::MULT)
805
  {
806
244098
    for (unsigned i = 0; i < a.getNumChildren(); i++)
807
    {
808
242380
      if (!checkInternal(a[i]))
809
      {
810
190476
        return false;
811
      }
812
    }
813
    // t1 >= 0 ^ ... ^ tn >= 0 => t1 op ... op tn >= 0
814
1718
    return true;
815
  }
816
817
2396
  return false;
818
}
819
820
1357
bool ArithEntail::inferZerosInSumGeq(Node x,
821
                                     std::vector<Node>& ys,
822
                                     std::vector<Node>& zeroYs)
823
{
824
1357
  Assert(zeroYs.empty());
825
826
1357
  NodeManager* nm = NodeManager::currentNM();
827
828
  // Check if we can show that y1 + ... + yn >= x
829
2714
  Node sum = (ys.size() > 1) ? nm->mkNode(PLUS, ys) : ys[0];
830
1357
  if (!check(sum, x))
831
  {
832
438
    return false;
833
  }
834
835
  // Try to remove yi one-by-one and check if we can still show:
836
  //
837
  // y1 + ... + yi-1 +  yi+1 + ... + yn >= x
838
  //
839
  // If that's the case, we know that yi can be zero and the inequality still
840
  // holds.
841
919
  size_t i = 0;
842
5801
  while (i < ys.size())
843
  {
844
4882
    Node yi = ys[i];
845
2441
    std::vector<Node>::iterator pos = ys.erase(ys.begin() + i);
846
2441
    if (ys.size() > 1)
847
    {
848
915
      sum = nm->mkNode(PLUS, ys);
849
    }
850
    else
851
    {
852
1526
      sum = ys.size() == 1 ? ys[0] : nm->mkConst(Rational(0));
853
    }
854
855
2441
    if (check(sum, x))
856
    {
857
886
      zeroYs.push_back(yi);
858
    }
859
    else
860
    {
861
1555
      ys.insert(pos, yi);
862
1555
      i++;
863
    }
864
  }
865
919
  return true;
866
}
867
868
}  // namespace strings
869
}  // namespace theory
870
27735
}  // namespace cvc5