GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/strings/arith_entail.cpp Lines: 385 427 90.2 %
Date: 2021-11-07 Branches: 941 2036 46.2 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Andres Noetzli, Aina Niemetz
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of arithmetic entailment computation for string terms.
14
 */
15
16
#include "theory/strings/arith_entail.h"
17
18
#include "expr/attribute.h"
19
#include "expr/node_algorithm.h"
20
#include "theory/arith/arith_msum.h"
21
#include "theory/rewriter.h"
22
#include "theory/strings/theory_strings_utils.h"
23
#include "theory/strings/word.h"
24
#include "theory/theory.h"
25
#include "util/rational.h"
26
27
using namespace cvc5::kind;
28
29
namespace cvc5 {
30
namespace theory {
31
namespace strings {
32
33
33191
/**
 * Construct the arithmetic entailment utility.
 *
 * @param r The rewriter this utility uses to normalize terms (kept in d_rr).
 */
ArithEntail::ArithEntail(Rewriter* r) : d_rr(r)
{
  // Cache the rational constant 0; it is reused as a default sum/bound.
  NodeManager* nm = NodeManager::currentNM();
  d_zero = nm->mkConst(Rational(0));
}
37
38
117736
/**
 * Rewrite a term with the rewriter held by this class.
 *
 * @param a The term to rewrite.
 * @return The result of d_rr->rewrite(a).
 */
Node ArithEntail::rewrite(Node a)
{
  Node rewritten = d_rr->rewrite(a);
  return rewritten;
}
39
40
1849
/**
 * Check whether two terms are equal, either syntactically or after rewriting.
 *
 * @param a The first term.
 * @param b The second term.
 * @return true if a and b are identical, or if their rewritten forms are
 * identical.
 */
bool ArithEntail::checkEq(Node a, Node b)
{
  if (a != b)
  {
    // not syntactically equal: compare the rewritten forms
    const Node lhs = d_rr->rewrite(a);
    const Node rhs = d_rr->rewrite(b);
    return lhs == rhs;
  }
  return true;
}
50
51
455734
/**
 * Check whether a >= b (or a > b when strict is true) is entailed.
 *
 * @param a The left-hand side term.
 * @param b The right-hand side term.
 * @param strict Whether the strict inequality a > b is to be shown.
 * @return true if the entailment could be shown.
 */
bool ArithEntail::check(Node a, Node b, bool strict)
{
  if (a != b)
  {
    // reduce to an entailment check on the difference a - b
    NodeManager* nm = NodeManager::currentNM();
    Node delta = nm->mkNode(kind::MINUS, a, b);
    return check(delta, strict);
  }
  // identical terms: a >= a holds, a > a does not
  return !strict;
}
60
61
struct StrCheckEntailArithTag
62
{
63
};
64
struct StrCheckEntailArithComputedTag
65
{
66
};
67
/** Attribute true for expressions for which check returned true */
68
typedef expr::Attribute<StrCheckEntailArithTag, bool> StrCheckEntailArithAttr;
69
typedef expr::Attribute<StrCheckEntailArithComputedTag, bool>
70
    StrCheckEntailArithComputedAttr;
71
72
683998
/**
 * Check whether a >= 0 (or a > 0 when strict is true) is entailed.
 *
 * Non-constant terms are normalized by the rewriter and the answer is cached
 * on the normalized term via StrCheckEntailArithAttr and
 * StrCheckEntailArithComputedAttr.
 *
 * @param a The term to check.
 * @param strict Whether the strict inequality a > 0 is to be shown.
 * @return true if the entailment could be shown.
 */
bool ArithEntail::check(Node a, bool strict)
{
  if (a.isConst())
  {
    // constants are decided directly from their sign
    const int sign = a.getConst<Rational>().sgn();
    return strict ? sign >= 1 : sign >= 0;
  }

  // a > 0 is checked as a - 1 >= 0
  Node norm = a;
  if (strict)
  {
    norm = NodeManager::currentNM()->mkNode(
        kind::MINUS, a, NodeManager::currentNM()->mkConst(Rational(1)));
  }
  norm = d_rr->rewrite(norm);

  // reuse a previously computed answer if there is one
  if (norm.getAttribute(StrCheckEntailArithComputedAttr()))
  {
    return norm.getAttribute(StrCheckEntailArithAttr());
  }

  // try the basic check first and fall back to approximations if it fails
  const bool holds = checkInternal(norm) || checkApprox(norm);

  // cache the result
  norm.setAttribute(StrCheckEntailArithAttr(), holds);
  norm.setAttribute(StrCheckEntailArithComputedAttr(), true);
  return holds;
}
100
101
56973
/**
 * Try to show ar >= 0 by replacing monomials of ar with sound
 * approximations and checking the entailment on the approximated term.
 *
 * The monomial sum of ar is computed; each non-constant monomial variable is
 * expanded through getArithApproximations until no further approximation
 * applies. Monomials with exactly one resulting approximation are moved into
 * a "fixed" sum directly. For monomials with several candidates, one
 * candidate is chosen per monomial by a scoring heuristic and added to the
 * fixed sum. Finally check() is called on the resulting term.
 *
 * @param ar The term to check; must already be in rewritten form.
 * @return true if ar >= 0 could be shown via approximation, false if no
 * approximation applied, a monomial sum could not be computed, or the
 * approximated term could not be shown non-negative.
 */
bool ArithEntail::checkApprox(Node ar)
{
  Assert(d_rr->rewrite(ar) == ar);
  NodeManager* nm = NodeManager::currentNM();
  std::map<Node, Node> msum;
  Trace("strings-ent-approx-debug")
      << "Setup arithmetic approximations for " << ar << std::endl;
  if (!ArithMSum::getMonomialSum(ar, msum))
  {
    Trace("strings-ent-approx-debug")
        << "...failed to get monomial sum!" << std::endl;
    return false;
  }
  // for each monomial v*c, mApprox[v] a list of
  // possibilities for how the term can be soundly approximated, that is,
  // if mApprox[v] contains av, then v*c >= av*c. Notice that if c
  // is positive, then v >= av, otherwise if c is negative, then v <= av.
  // In other words, av is an under-approximation if c is positive, and an
  // over-approximation if c is negative.
  bool changed = false;
  std::map<Node, std::vector<Node> > mApprox;
  // map from approximations to their monomial sums
  std::map<Node, std::map<Node, Node> > approxMsums;
  // aarSum stores each monomial that does not have multiple approximations
  std::vector<Node> aarSum;
  for (std::pair<const Node, Node>& m : msum)
  {
    Node v = m.first;
    Node c = m.second;
    Trace("strings-ent-approx-debug")
        << "Get approximations " << v << "..." << std::endl;
    if (v.isNull())
    {
      // the constant monomial: a null coefficient stands for 1
      Node mn = c.isNull() ? nm->mkConst(Rational(1)) : c;
      aarSum.push_back(mn);
    }
    else
    {
      // c.isNull() means c = 1
      bool isOverApprox = !c.isNull() && c.getConst<Rational>().sgn() == -1;
      std::vector<Node>& approx = mApprox[v];
      std::unordered_set<Node> visited;
      std::vector<Node> toProcess;
      toProcess.push_back(v);
      // expand approximations transitively until a fixed point is reached
      do
      {
        Node curr = toProcess.back();
        Trace("strings-ent-approx-debug") << "  process " << curr << std::endl;
        curr = d_rr->rewrite(curr);
        toProcess.pop_back();
        if (visited.find(curr) == visited.end())
        {
          visited.insert(curr);
          std::vector<Node> currApprox;
          getArithApproximations(curr, currApprox, isOverApprox);
          if (currApprox.empty())
          {
            Trace("strings-ent-approx-debug")
                << "...approximation: " << curr << std::endl;
            // no approximations, thus curr is a possibility
            approx.push_back(curr);
          }
          else
          {
            toProcess.insert(
                toProcess.end(), currApprox.begin(), currApprox.end());
          }
        }
      } while (!toProcess.empty());
      Assert(!approx.empty());
      // if we have only one approximation, move it to final
      if (approx.size() == 1)
      {
        // The flag must be sticky: a monomial that is left unchanged must not
        // reset the fact that an earlier monomial was approximated.
        if (v != approx[0])
        {
          changed = true;
        }
        Node mn = ArithMSum::mkCoeffTerm(c, approx[0]);
        aarSum.push_back(mn);
        mApprox.erase(v);
      }
      else
      {
        // compute monomial sum form for each approximation, used below
        for (const Node& aa : approx)
        {
          if (approxMsums.find(aa) == approxMsums.end())
          {
            CVC5_UNUSED bool ret =
                ArithMSum::getMonomialSum(aa, approxMsums[aa]);
            Assert(ret);
          }
        }
        changed = true;
      }
    }
  }
  if (!changed)
  {
    // approximations had no effect, return
    Trace("strings-ent-approx-debug") << "...no approximations" << std::endl;
    return false;
  }
  // get the current "fixed" sum for the abstraction of ar
  Node aar = aarSum.empty()
                 ? d_zero
                 : (aarSum.size() == 1 ? aarSum[0] : nm->mkNode(PLUS, aarSum));
  aar = d_rr->rewrite(aar);
  Trace("strings-ent-approx-debug")
      << "...processed fixed sum " << aar << " with " << mApprox.size()
      << " approximated monomials." << std::endl;
  // if we have a choice of how to approximate
  if (!mApprox.empty())
  {
    // convert aar back to monomial sum
    std::map<Node, Node> msumAar;
    if (!ArithMSum::getMonomialSum(aar, msumAar))
    {
      return false;
    }
    if (Trace.isOn("strings-ent-approx"))
    {
      Trace("strings-ent-approx")
          << "---- Check arithmetic entailment by under-approximation " << ar
          << " >= 0" << std::endl;
      Trace("strings-ent-approx") << "FIXED:" << std::endl;
      ArithMSum::debugPrintMonomialSum(msumAar, "strings-ent-approx");
      Trace("strings-ent-approx") << "APPROX:" << std::endl;
      for (std::pair<const Node, std::vector<Node> >& a : mApprox)
      {
        Node c = msum[a.first];
        Trace("strings-ent-approx") << "  ";
        if (!c.isNull())
        {
          Trace("strings-ent-approx") << c << " * ";
        }
        Trace("strings-ent-approx")
            << a.second << " ...from " << a.first << std::endl;
      }
      Trace("strings-ent-approx") << std::endl;
    }
    Rational one(1);
    // incorporate monomials one at a time that have a choice of approximations
    while (!mApprox.empty())
    {
      Node v;
      Node vapprox;
      int maxScore = -1;
      // Look at each approximation, take the one with the best score.
      // Notice that we are in the process of trying to prove
      // ( c1*t1 + .. + cn*tn ) + ( approx_1 | ... | approx_m ) >= 0,
      // where c1*t1 + .. + cn*tn is the "fixed" component of our sum (aar)
      // and approx_1 ... approx_m are possible approximations. The
      // intuition here is that we want coefficients c1...cn to be positive.
      // This is because arithmetic string terms t1...tn (which may be
      // applications of len, indexof, str.to.int) are never entailed to be
      // negative. Hence, we add the approx_i that contributes the "most"
      // towards making all constants c1...cn positive and cancelling negative
      // monomials in approx_i itself.
      for (std::pair<const Node, std::vector<Node> >& nam : mApprox)
      {
        Node cr = msum[nam.first];
        for (const Node& aa : nam.second)
        {
          unsigned helpsCancelCount = 0;
          unsigned addsObligationCount = 0;
          std::map<Node, Node>::iterator it;
          // we are processing an approximation cr*( c1*t1 + ... + cn*tn )
          for (std::pair<const Node, Node>& aam : approxMsums[aa])
          {
            // Say aar is of the form t + c*ti, and aam is the monomial ci*ti
            // where ci != 0. We say aam:
            // (1) helps cancel if c != 0 and c>0 != ci>0
            // (2) adds obligation if c>=0 and c+ci<0
            Node ti = aam.first;
            Node ci = aam.second;
            if (!cr.isNull())
            {
              // scale the coefficient by the coefficient of the monomial
              ci = ci.isNull() ? cr : d_rr->rewrite(nm->mkNode(MULT, ci, cr));
            }
            Trace("strings-ent-approx-debug") << ci << "*" << ti << " ";
            int ciSgn = ci.isNull() ? 1 : ci.getConst<Rational>().sgn();
            it = msumAar.find(ti);
            if (it != msumAar.end())
            {
              Node c = it->second;
              int cSgn = c.isNull() ? 1 : c.getConst<Rational>().sgn();
              if (cSgn == 0)
              {
                addsObligationCount += (ciSgn == -1 ? 1 : 0);
              }
              else if (cSgn != ciSgn)
              {
                helpsCancelCount++;
                Rational r1 = c.isNull() ? one : c.getConst<Rational>();
                Rational r2 = ci.isNull() ? one : ci.getConst<Rational>();
                Rational r12 = r1 + r2;
                if (r12.sgn() == -1)
                {
                  addsObligationCount++;
                }
              }
            }
            else
            {
              addsObligationCount += (ciSgn == -1 ? 1 : 0);
            }
          }
          Trace("strings-ent-approx-debug")
              << "counts=" << helpsCancelCount << "," << addsObligationCount
              << " for " << aa << " into " << aar << std::endl;
          int score = (addsObligationCount > 0 ? 0 : 2)
                      + (helpsCancelCount > 0 ? 1 : 0);
          // if its the best, update v and vapprox
          if (v.isNull() || score > maxScore)
          {
            v = nam.first;
            vapprox = aa;
            maxScore = score;
          }
        }
        // only the first monomial with candidates is decided per iteration
        if (!v.isNull())
        {
          break;
        }
      }
      Trace("strings-ent-approx")
          << "- Decide " << v << " = " << vapprox << std::endl;
      // we incorporate v approximated by vapprox into the overall approximation
      // for ar
      Assert(!v.isNull() && !vapprox.isNull());
      Assert(msum.find(v) != msum.end());
      Node mn = ArithMSum::mkCoeffTerm(msum[v], vapprox);
      aar = nm->mkNode(PLUS, aar, mn);
      // update the msumAar map
      aar = d_rr->rewrite(aar);
      msumAar.clear();
      if (!ArithMSum::getMonomialSum(aar, msumAar))
      {
        Assert(false);
        Trace("strings-ent-approx")
            << "...failed to get monomial sum!" << std::endl;
        return false;
      }
      // we have processed the approximation for v
      mApprox.erase(v);
    }
    Trace("strings-ent-approx") << "-----------------" << std::endl;
  }
  if (aar == ar)
  {
    Trace("strings-ent-approx-debug")
        << "...approximation had no effect" << std::endl;
    // this should never happen, but we avoid the infinite loop for sanity here
    Assert(false);
    return false;
  }
  // Check entailment on the approximation of ar.
  // Notice that this may trigger further reasoning by approximation. For
  // example, len( replace( x ++ y, substr( x, 0, n ), z ) ) may be
  // under-approximated as len( x ) + len( y ) - len( substr( x, 0, n ) ) on
  // this call, where in the recursive call we may over-approximate
  // len( substr( x, 0, n ) ) as len( x ). In this example, we can infer
  // that len( replace( x ++ y, substr( x, 0, n ), z ) ) >= len( y ) in two
  // steps.
  if (check(aar))
  {
    Trace("strings-ent-approx")
        << "*** StrArithApprox: showed " << ar
        << " >= 0 using under-approximation!" << std::endl;
    Trace("strings-ent-approx")
        << "*** StrArithApprox: under-approximation was " << aar << std::endl;
    return true;
  }
  return false;
}
374
375
118956
void ArithEntail::getArithApproximations(Node a,
376
                                         std::vector<Node>& approx,
377
                                         bool isOverApprox)
378
{
379
118956
  NodeManager* nm = NodeManager::currentNM();
380
  // We do not handle PLUS here since this leads to exponential behavior.
381
  // Instead, this is managed, e.g. during checkApprox, where
382
  // PLUS terms are expanded "on-demand" during the reasoning.
383
237912
  Trace("strings-ent-approx-debug")
384
118956
      << "Get arith approximations " << a << std::endl;
385
118956
  Kind ak = a.getKind();
386
118956
  if (ak == MULT)
387
  {
388
896
    Node c;
389
896
    Node v;
390
448
    if (ArithMSum::getMonomial(a, c, v))
391
    {
392
448
      bool isNeg = c.getConst<Rational>().sgn() == -1;
393
448
      getArithApproximations(v, approx, isNeg ? !isOverApprox : isOverApprox);
394
508
      for (unsigned i = 0, size = approx.size(); i < size; i++)
395
      {
396
60
        approx[i] = nm->mkNode(MULT, c, approx[i]);
397
      }
398
    }
399
  }
400
118508
  else if (ak == STRING_LENGTH)
401
  {
402
86250
    Kind aak = a[0].getKind();
403
86250
    if (aak == STRING_SUBSTR)
404
    {
405
      // over,under-approximations for len( substr( x, n, m ) )
406
52768
      Node lenx = nm->mkNode(STRING_LENGTH, a[0][0]);
407
26384
      if (isOverApprox)
408
      {
409
        // m >= 0 implies
410
        //  m >= len( substr( x, n, m ) )
411
14953
        if (check(a[0][2]))
412
        {
413
9148
          approx.push_back(a[0][2]);
414
        }
415
14953
        if (check(lenx, a[0][1]))
416
        {
417
          // n <= len( x ) implies
418
          //   len( x ) - n >= len( substr( x, n, m ) )
419
4966
          approx.push_back(nm->mkNode(MINUS, lenx, a[0][1]));
420
        }
421
        else
422
        {
423
          // len( x ) >= len( substr( x, n, m ) )
424
9987
          approx.push_back(lenx);
425
        }
426
      }
427
      else
428
      {
429
        // 0 <= n and n+m <= len( x ) implies
430
        //   m <= len( substr( x, n, m ) )
431
22862
        Node npm = nm->mkNode(PLUS, a[0][1], a[0][2]);
432
11431
        if (check(a[0][1]) && check(lenx, npm))
433
        {
434
1034
          approx.push_back(a[0][2]);
435
        }
436
        // 0 <= n and n+m >= len( x ) implies
437
        //   len(x)-n <= len( substr( x, n, m ) )
438
11431
        if (check(a[0][1]) && check(npm, lenx))
439
        {
440
684
          approx.push_back(nm->mkNode(MINUS, lenx, a[0][1]));
441
        }
442
      }
443
    }
444
59866
    else if (aak == STRING_REPLACE)
445
    {
446
      // over,under-approximations for len( replace( x, y, z ) )
447
      // notice this is either len( x ) or ( len( x ) + len( z ) - len( y ) )
448
7462
      Node lenx = nm->mkNode(STRING_LENGTH, a[0][0]);
449
7462
      Node leny = nm->mkNode(STRING_LENGTH, a[0][1]);
450
7462
      Node lenz = nm->mkNode(STRING_LENGTH, a[0][2]);
451
3731
      if (isOverApprox)
452
      {
453
2027
        if (check(leny, lenz))
454
        {
455
          // len( y ) >= len( z ) implies
456
          //   len( x ) >= len( replace( x, y, z ) )
457
510
          approx.push_back(lenx);
458
        }
459
        else
460
        {
461
          // len( x ) + len( z ) >= len( replace( x, y, z ) )
462
1517
          approx.push_back(nm->mkNode(PLUS, lenx, lenz));
463
        }
464
      }
465
      else
466
      {
467
1704
        if (check(lenz, leny) || check(lenz, lenx))
468
        {
469
          // len( y ) <= len( z ) or len( x ) <= len( z ) implies
470
          //   len( x ) <= len( replace( x, y, z ) )
471
923
          approx.push_back(lenx);
472
        }
473
        else
474
        {
475
          // len( x ) - len( y ) <= len( replace( x, y, z ) )
476
781
          approx.push_back(nm->mkNode(MINUS, lenx, leny));
477
        }
478
      }
479
    }
480
56135
    else if (aak == STRING_ITOS)
481
    {
482
      // over,under-approximations for len( int.to.str( x ) )
483
310
      if (isOverApprox)
484
      {
485
161
        if (check(a[0][0], false))
486
        {
487
54
          if (check(a[0][0], true))
488
          {
489
            // x > 0 implies
490
            //   x >= len( int.to.str( x ) )
491
7
            approx.push_back(a[0][0]);
492
          }
493
          else
494
          {
495
            // x >= 0 implies
496
            //   x+1 >= len( int.to.str( x ) )
497
47
            approx.push_back(
498
94
                nm->mkNode(PLUS, nm->mkConst(Rational(1)), a[0][0]));
499
          }
500
        }
501
      }
502
      else
503
      {
504
149
        if (check(a[0][0]))
505
        {
506
          // x >= 0 implies
507
          //   len( int.to.str( x ) ) >= 1
508
38
          approx.push_back(nm->mkConst(Rational(1)));
509
        }
510
        // other crazy things are possible here, e.g.
511
        // len( int.to.str( len( y ) + 10 ) ) >= 2
512
      }
513
    }
514
  }
515
32258
  else if (ak == STRING_INDEXOF)
516
  {
517
    // over,under-approximations for indexof( x, y, n )
518
2720
    if (isOverApprox)
519
    {
520
2920
      Node lenx = nm->mkNode(STRING_LENGTH, a[0]);
521
2920
      Node leny = nm->mkNode(STRING_LENGTH, a[1]);
522
1460
      if (check(lenx, leny))
523
      {
524
        // len( x ) >= len( y ) implies
525
        //   len( x ) - len( y ) >= indexof( x, y, n )
526
20
        approx.push_back(nm->mkNode(MINUS, lenx, leny));
527
      }
528
      else
529
      {
530
        // len( x ) >= indexof( x, y, n )
531
1440
        approx.push_back(lenx);
532
      }
533
    }
534
    else
535
    {
536
      // TODO?:
537
      // contains( substr( x, n, len( x ) ), y ) implies
538
      //   n <= indexof( x, y, n )
539
      // ...hard to test, runs risk of non-termination
540
541
      // -1 <= indexof( x, y, n )
542
1260
      approx.push_back(nm->mkConst(Rational(-1)));
543
    }
544
  }
545
29538
  else if (ak == STRING_STOI)
546
  {
547
    // over,under-approximations for str.to.int( x )
548
    if (isOverApprox)
549
    {
550
      // TODO?:
551
      // y >= 0 implies
552
      //   y >= str.to.int( int.to.str( y ) )
553
    }
554
    else
555
    {
556
      // -1 <= str.to.int( x )
557
      approx.push_back(nm->mkConst(Rational(-1)));
558
    }
559
  }
560
118956
  Trace("strings-ent-approx-debug") << "Return " << approx.size() << std::endl;
561
118956
}
562
563
3632
/**
 * Check whether a >= 0 (or a > 0 when strict is true) is entailed under an
 * equality assumption.
 *
 * A candidate term (an arithmetic variable or a str.len application) that
 * occurs both in the assumption and in a is selected, the assumption is
 * solved for it, and the solution is substituted into a before calling
 * check.
 *
 * @param assumption The assumed equality; must be of kind EQUAL and in
 * rewritten form.
 * @param a The term to check.
 * @param strict Whether the strict inequality is to be shown.
 * @return true if the entailment could be shown; false also when no suitable
 * candidate exists or the assumption cannot be solved for it.
 */
bool ArithEntail::checkWithEqAssumption(Node assumption, Node a, bool strict)
{
  Assert(assumption.getKind() == kind::EQUAL);
  Assert(d_rr->rewrite(assumption) == assumption);
  Trace("strings-entail") << "checkWithEqAssumption: " << assumption << " " << a
                          << ", strict=" << strict << std::endl;

  // Collect the candidate terms of the assumption that we may solve for.
  // Only PLUS, MULT, MINUS and EQUAL are traversed.
  std::unordered_set<Node> candidates;
  std::vector<Node> worklist;
  worklist.push_back(assumption);
  while (!worklist.empty())
  {
    Node cur = worklist.back();
    worklist.pop_back();

    const Kind ck = cur.getKind();
    const bool descend = ck == kind::PLUS || ck == kind::MULT
                         || ck == kind::MINUS || ck == kind::EQUAL;
    if (descend)
    {
      for (const auto& child : cur)
      {
        worklist.push_back(child);
      }
    }
    else if ((cur.isVar() && Theory::theoryOf(cur) == THEORY_ARITH)
             || ck == kind::STRING_LENGTH)
    {
      candidates.insert(cur);
    }
  }

  // Look for a subterm of a that is one of the candidates
  Node target;
  Assert(worklist.empty());
  worklist.push_back(a);
  while (!worklist.empty())
  {
    Node cur = worklist.back();
    worklist.pop_back();

    for (const auto& child : cur)
    {
      worklist.push_back(child);
    }

    if (candidates.find(cur) != candidates.end())
    {
      target = cur;
      break;
    }
  }

  if (target.isNull())
  {
    // No suitable candidate found
    return false;
  }

  Node solution = ArithMSum::solveEqualityFor(assumption, target);
  if (solution.isNull())
  {
    // Could not solve for the candidate
    return false;
  }
  Trace("strings-entail") << "checkWithEqAssumption: subs " << target << " -> "
                          << solution << std::endl;

  // use capture avoiding substitution
  a = expr::substituteCaptureAvoiding(a, target, solution);
  return check(a, strict);
}
636
637
6739
bool ArithEntail::checkWithAssumption(Node assumption,
638
                                      Node a,
639
                                      Node b,
640
                                      bool strict)
641
{
642
6739
  Assert(d_rr->rewrite(assumption) == assumption);
643
644
6739
  NodeManager* nm = NodeManager::currentNM();
645
646
6739
  if (!assumption.isConst() && assumption.getKind() != kind::EQUAL)
647
  {
648
    // We rewrite inequality assumptions from x <= y to x + (str.len s) = y
649
    // where s is some fresh string variable. We use (str.len s) because
650
    // (str.len s) must be non-negative for the equation to hold.
651
7248
    Node x, y;
652
3624
    if (assumption.getKind() == kind::GEQ)
653
    {
654
2459
      x = assumption[0];
655
2459
      y = assumption[1];
656
    }
657
    else
658
    {
659
      // (not (>= s t)) --> (>= (t - 1) s)
660
1165
      Assert(assumption.getKind() == kind::NOT
661
             && assumption[0].getKind() == kind::GEQ);
662
1165
      x = nm->mkNode(kind::MINUS, assumption[0][1], nm->mkConst(Rational(1)));
663
1165
      y = assumption[0][0];
664
    }
665
666
7248
    Node s = nm->mkBoundVar("slackVal", nm->stringType());
667
7248
    Node slen = nm->mkNode(kind::STRING_LENGTH, s);
668
3624
    assumption = d_rr->rewrite(
669
7248
        nm->mkNode(kind::EQUAL, x, nm->mkNode(kind::PLUS, y, slen)));
670
  }
671
672
13478
  Node diff = nm->mkNode(kind::MINUS, a, b);
673
6739
  bool res = false;
674
6739
  if (assumption.isConst())
675
  {
676
3107
    bool assumptionBool = assumption.getConst<bool>();
677
3107
    if (assumptionBool)
678
    {
679
3107
      res = check(diff, strict);
680
    }
681
    else
682
    {
683
      res = true;
684
    }
685
  }
686
  else
687
  {
688
3632
    res = checkWithEqAssumption(assumption, diff, strict);
689
  }
690
13478
  return res;
691
}
692
693
bool ArithEntail::checkWithAssumptions(std::vector<Node> assumptions,
694
                                       Node a,
695
                                       Node b,
696
                                       bool strict)
697
{
698
  // TODO: We currently try to show the entailment with each assumption
699
  // independently. In the future, we should make better use of multiple
700
  // assumptions.
701
  bool res = false;
702
  for (const auto& assumption : assumptions)
703
  {
704
    Assert(d_rr->rewrite(assumption) == assumption);
705
706
    if (checkWithAssumption(assumption, a, b, strict))
707
    {
708
      res = true;
709
      break;
710
    }
711
  }
712
  return res;
713
}
714
715
struct ArithEntailConstantBoundLowerId
716
{
717
};
718
typedef expr::Attribute<ArithEntailConstantBoundLowerId, Node>
719
    ArithEntailConstantBoundLower;
720
721
struct ArithEntailConstantBoundUpperId
722
{
723
};
724
typedef expr::Attribute<ArithEntailConstantBoundUpperId, Node>
725
    ArithEntailConstantBoundUpper;
726
727
85903
/**
 * Store a constant bound for n in the attribute cache.
 *
 * @param n The node the bound belongs to.
 * @param ret The bound to store (the callers in this file may pass a null
 * node).
 * @param isLower If true the lower-bound attribute is written, otherwise the
 * upper-bound attribute.
 */
void ArithEntail::setConstantBoundCache(Node n, Node ret, bool isLower)
{
  if (!isLower)
  {
    n.setAttribute(ArithEntailConstantBoundUpper(), ret);
    return;
  }
  n.setAttribute(ArithEntailConstantBoundLower(), ret);
}
740
741
597102
/**
 * Look up a cached constant bound for n.
 *
 * @param n The node whose bound is requested.
 * @param isLower If true the lower-bound attribute is read, otherwise the
 * upper-bound attribute.
 * @return The stored attribute value if the attribute is set on n, and the
 * null node otherwise.
 */
Node ArithEntail::getConstantBoundCache(Node n, bool isLower)
{
  if (isLower)
  {
    const ArithEntailConstantBoundLower lowerAttr;
    return n.hasAttribute(lowerAttr) ? n.getAttribute(lowerAttr)
                                     : Node::null();
  }
  const ArithEntailConstantBoundUpper upperAttr;
  return n.hasAttribute(upperAttr) ? n.getAttribute(upperAttr) : Node::null();
}
761
762
45159
/**
 * Compute a constant lower or upper bound for an arithmetic term.
 *
 * Handles constants, str.len applications (lower bound 0 only), and PLUS /
 * MULT terms whose children all have constant bounds. The result is stored
 * with setConstantBoundCache.
 *
 * @param a The term, in rewritten form.
 * @param isLower Whether a lower bound (true) or upper bound (false) is
 * requested.
 * @return A constant node, or the null node if no bound was inferred.
 */
Node ArithEntail::getConstantBound(Node a, bool isLower)
{
  Assert(d_rr->rewrite(a) == a);
  Node ret = getConstantBoundCache(a, isLower);
  if (!ret.isNull())
  {
    return ret;
  }
  const Kind ak = a.getKind();
  if (a.isConst())
  {
    // a constant is its own bound
    ret = a;
  }
  else if (ak == kind::STRING_LENGTH)
  {
    // lengths are bounded below by zero; no upper bound is inferred
    if (isLower)
    {
      ret = d_zero;
    }
  }
  else if (ak == kind::PLUS || ak == kind::MULT)
  {
    const bool isMult = (ak == kind::MULT);
    std::vector<Node> bounds;
    bool ok = true;
    for (const Node& child : a)
    {
      Node cb = getConstantBound(child, isLower);
      if (cb.isNull())
      {
        // no bound for this child, hence none for a
        ret = cb;
        ok = false;
        break;
      }
      const int sign = cb.getConst<Rational>().sgn();
      if (sign == 0)
      {
        if (isMult)
        {
          // a zero bound on a factor: the result is that zero bound
          ret = cb;
          ok = false;
          break;
        }
        // a zero summand contributes nothing
        continue;
      }
      if (isMult && (sign > 0) != isLower)
      {
        // the sign of the factor bound does not match the requested direction
        ret = Node::null();
        ok = false;
        break;
      }
      bounds.push_back(cb);
    }
    if (ok)
    {
      if (bounds.empty())
      {
        ret = d_zero;
      }
      else if (bounds.size() == 1)
      {
        ret = bounds[0];
      }
      else
      {
        // combine the child bounds with the same operator and simplify
        ret = NodeManager::currentNM()->mkNode(a.getKind(), bounds);
        ret = d_rr->rewrite(ret);
      }
    }
  }
  Trace("strings-rewrite-cbound")
      << "Constant " << (isLower ? "lower" : "upper") << " bound for " << a
      << " is " << ret << std::endl;
  Assert(ret.isNull() || ret.isConst());
  // entailment check should be at least as powerful as computing a lower bound
  Assert(!isLower || ret.isNull() || ret.getConst<Rational>().sgn() < 0
         || check(a, false));
  Assert(!isLower || ret.isNull() || ret.getConst<Rational>().sgn() <= 0
         || check(a, true));
  // cache
  setConstantBoundCache(a, ret, isLower);
  return ret;
}
850
851
551943
/**
 * Compute a constant lower or upper bound on the length of a string-like
 * term.
 *
 * Constants yield their exact length, concatenations the sum of the bounds
 * of their children, and any other term the lower bound 0. The result is
 * stored with setConstantBoundCache.
 *
 * @param s The string-like term.
 * @param isLower Whether a lower bound (true) or upper bound (false) is
 * requested.
 * @return A constant rational node, or the null node if no bound was
 * inferred.
 */
Node ArithEntail::getConstantBoundLength(Node s, bool isLower)
{
  Assert(s.getType().isStringLike());
  Node ret = getConstantBoundCache(s, isLower);
  if (!ret.isNull())
  {
    return ret;
  }
  NodeManager* nm = NodeManager::currentNM();
  if (s.isConst())
  {
    // the length of a constant word is known exactly
    ret = nm->mkConst(Rational(Word::getLength(s)));
  }
  else if (s.getKind() == STRING_CONCAT)
  {
    Rational total(0);
    bool allBounded = true;
    for (const Node& piece : s)
    {
      Node pieceBound = getConstantBoundLength(piece, isLower);
      if (!pieceBound.isNull())
      {
        Assert(pieceBound.getKind() == CONST_RATIONAL);
        total = total + pieceBound.getConst<Rational>();
        continue;
      }
      if (isLower)
      {
        // assume zero and continue
        continue;
      }
      // an unbounded component means there is no upper bound overall
      allBounded = false;
      break;
    }
    if (allBounded)
    {
      ret = nm->mkConst(total);
    }
  }
  else if (isLower)
  {
    // zero is used as the lower bound for every other term
    ret = d_zero;
  }
  // cache
  setConstantBoundCache(s, ret, isLower);
  return ret;
}
897
898
164141
/**
 * Basic syntactic check for a >= 0, without approximations.
 *
 * @param a The term to check, in rewritten form.
 * @return true if a is a non-negative constant, a str.len application, or a
 * PLUS / MULT term all of whose children pass this check; false otherwise.
 */
bool ArithEntail::checkInternal(Node a)
{
  Assert(d_rr->rewrite(a) == a);
  // check whether a >= 0
  if (a.isConst())
  {
    return a.getConst<Rational>().sgn() >= 0;
  }
  const Kind ak = a.getKind();
  if (ak == kind::STRING_LENGTH)
  {
    // str.len( t ) >= 0
    return true;
  }
  if (ak != kind::PLUS && ak != kind::MULT)
  {
    return false;
  }
  // t1 >= 0 ^ ... ^ tn >= 0 => t1 op ... op tn >= 0
  for (const Node& child : a)
  {
    if (!checkInternal(child))
    {
      return false;
    }
  }
  return true;
}
926
927
934
bool ArithEntail::inferZerosInSumGeq(Node x,
928
                                     std::vector<Node>& ys,
929
                                     std::vector<Node>& zeroYs)
930
{
931
934
  Assert(zeroYs.empty());
932
933
934
  NodeManager* nm = NodeManager::currentNM();
934
935
  // Check if we can show that y1 + ... + yn >= x
936
1868
  Node sum = (ys.size() > 1) ? nm->mkNode(PLUS, ys) : ys[0];
937
934
  if (!check(sum, x))
938
  {
939
830
    return false;
940
  }
941
942
  // Try to remove yi one-by-one and check if we can still show:
943
  //
944
  // y1 + ... + yi-1 +  yi+1 + ... + yn >= x
945
  //
946
  // If that's the case, we know that yi can be zero and the inequality still
947
  // holds.
948
104
  size_t i = 0;
949
548
  while (i < ys.size())
950
  {
951
444
    Node yi = ys[i];
952
222
    std::vector<Node>::iterator pos = ys.erase(ys.begin() + i);
953
222
    if (ys.size() > 1)
954
    {
955
66
      sum = nm->mkNode(PLUS, ys);
956
    }
957
    else
958
    {
959
156
      sum = ys.size() == 1 ? ys[0] : d_zero;
960
    }
961
962
222
    if (check(sum, x))
963
    {
964
42
      zeroYs.push_back(yi);
965
    }
966
    else
967
    {
968
180
      ys.insert(pos, yi);
969
180
      i++;
970
    }
971
  }
972
104
  return true;
973
}
974
975
}  // namespace strings
976
}  // namespace theory
977
31137
}  // namespace cvc5