GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/nl/ext/factoring_check.cpp Lines: 109 110 99.1 %
Date: 2021-08-14 Branches: 224 410 54.6 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Gereon Kremer, Tim King
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of factoring check.
14
 */
15
16
#include "theory/arith/nl/ext/factoring_check.h"
17
18
#include "expr/node.h"
19
#include "expr/skolem_manager.h"
20
#include "proof/proof.h"
21
#include "theory/arith/arith_msum.h"
22
#include "theory/arith/inference_manager.h"
23
#include "theory/arith/nl/ext/ext_state.h"
24
#include "theory/arith/nl/nl_model.h"
25
#include "theory/rewriter.h"
26
#include "util/rational.h"
27
28
namespace cvc5 {
29
namespace theory {
30
namespace arith {
31
namespace nl {
32
33
5145
/**
 * Construct the factoring check over the shared extension state `data`.
 * Caches the rational constants 0 and 1, which check() uses as the
 * right-hand side of factored literals and as a placeholder factor.
 */
FactoringCheck::FactoringCheck(ExtState* data) : d_data(data)
{
  NodeManager* manager = NodeManager::currentNM();
  d_zero = manager->mkConst(Rational(0));
  d_one = manager->mkConst(Rational(1));
}
38
39
276
void FactoringCheck::check(const std::vector<Node>& asserts,
40
                           const std::vector<Node>& false_asserts)
41
{
42
276
  NodeManager* nm = NodeManager::currentNM();
43
276
  Trace("nl-ext") << "Get factoring lemmas..." << std::endl;
44
7173
  for (const Node& lit : asserts)
45
  {
46
6897
    bool polarity = lit.getKind() != Kind::NOT;
47
13794
    Node atom = lit.getKind() == Kind::NOT ? lit[0] : lit;
48
13794
    Node litv = d_data->d_model.computeConcreteModelValue(lit);
49
6897
    bool considerLit = false;
50
    // Only consider literals that are in false_asserts.
51
6897
    considerLit = std::find(false_asserts.begin(), false_asserts.end(), lit)
52
13794
                  != false_asserts.end();
53
54
6897
    if (considerLit)
55
    {
56
3176
      std::map<Node, Node> msum;
57
1588
      if (ArithMSum::getMonomialSumLit(atom, msum))
58
      {
59
3176
        Trace("nl-ext-factor") << "Factoring for literal " << lit
60
1588
                               << ", monomial sum is : " << std::endl;
61
1588
        if (Trace.isOn("nl-ext-factor"))
62
        {
63
          ArithMSum::debugPrintMonomialSum(msum, "nl-ext-factor");
64
        }
65
3176
        std::map<Node, std::vector<Node> > factor_to_mono;
66
3176
        std::map<Node, std::vector<Node> > factor_to_mono_orig;
67
5100
        for (std::map<Node, Node>::iterator itm = msum.begin();
68
5100
             itm != msum.end();
69
             ++itm)
70
        {
71
3512
          if (!itm->first.isNull())
72
          {
73
2520
            if (itm->first.getKind() == Kind::NONLINEAR_MULT)
74
            {
75
324
              std::vector<Node> children;
76
489
              for (unsigned i = 0; i < itm->first.getNumChildren(); i++)
77
              {
78
327
                children.push_back(itm->first[i]);
79
              }
80
324
              std::map<Node, bool> processed;
81
489
              for (unsigned i = 0; i < itm->first.getNumChildren(); i++)
82
              {
83
327
                if (processed.find(itm->first[i]) == processed.end())
84
                {
85
294
                  processed[itm->first[i]] = true;
86
294
                  children[i] = d_one;
87
294
                  if (!itm->second.isNull())
88
                  {
89
154
                    children.push_back(itm->second);
90
                  }
91
588
                  Node val = nm->mkNode(Kind::MULT, children);
92
294
                  if (!itm->second.isNull())
93
                  {
94
154
                    children.pop_back();
95
                  }
96
294
                  children[i] = itm->first[i];
97
294
                  val = Rewriter::rewrite(val);
98
294
                  factor_to_mono[itm->first[i]].push_back(val);
99
294
                  factor_to_mono_orig[itm->first[i]].push_back(itm->first);
100
                }
101
              }
102
            }
103
          }
104
        }
105
274
        for (std::map<Node, std::vector<Node> >::iterator itf =
106
1588
                 factor_to_mono.begin();
107
1862
             itf != factor_to_mono.end();
108
             ++itf)
109
        {
110
326
          Node x = itf->first;
111
274
          if (itf->second.size() == 1)
112
          {
113
254
            std::map<Node, Node>::iterator itm = msum.find(x);
114
254
            if (itm != msum.end())
115
            {
116
32
              itf->second.push_back(itm->second.isNull() ? d_one : itm->second);
117
32
              factor_to_mono_orig[x].push_back(x);
118
            }
119
          }
120
274
          if (itf->second.size() <= 1)
121
          {
122
222
            continue;
123
          }
124
104
          Node sum = nm->mkNode(Kind::PLUS, itf->second);
125
52
          sum = Rewriter::rewrite(sum);
126
104
          Trace("nl-ext-factor")
127
52
              << "* Factored sum for " << x << " : " << sum << std::endl;
128
129
52
          CDProof* proof = nullptr;
130
52
          if (d_data->isProofEnabled())
131
          {
132
20
            proof = d_data->getProof();
133
          }
134
104
          Node kf = getFactorSkolem(sum, proof);
135
104
          std::vector<Node> poly;
136
52
          poly.push_back(nm->mkNode(Kind::MULT, x, kf));
137
          std::map<Node, std::vector<Node> >::iterator itfo =
138
52
              factor_to_mono_orig.find(x);
139
52
          Assert(itfo != factor_to_mono_orig.end());
140
236
          for (std::map<Node, Node>::iterator itm = msum.begin();
141
236
               itm != msum.end();
142
               ++itm)
143
          {
144
552
            if (std::find(itfo->second.begin(), itfo->second.end(), itm->first)
145
552
                == itfo->second.end())
146
            {
147
211
              poly.push_back(ArithMSum::mkCoeffTerm(
148
211
                  itm->second, itm->first.isNull() ? d_one : itm->first));
149
            }
150
          }
151
          Node polyn =
152
104
              poly.size() == 1 ? poly[0] : nm->mkNode(Kind::PLUS, poly);
153
104
          Trace("nl-ext-factor")
154
52
              << "...factored polynomial : " << polyn << std::endl;
155
104
          Node conc_lit = nm->mkNode(atom.getKind(), polyn, d_zero);
156
52
          conc_lit = Rewriter::rewrite(conc_lit);
157
52
          if (!polarity)
158
          {
159
10
            conc_lit = conc_lit.negate();
160
          }
161
162
104
          std::vector<Node> lemma_disj;
163
52
          lemma_disj.push_back(conc_lit);
164
52
          lemma_disj.push_back(lit.negate());
165
104
          Node flem = nm->mkNode(Kind::OR, lemma_disj);
166
52
          Trace("nl-ext-factor") << "...lemma is " << flem << std::endl;
167
52
          if (d_data->isProofEnabled())
168
          {
169
40
            Node k_eq = kf.eqNode(sum);
170
40
            Node split = nm->mkNode(Kind::OR, lit, lit.notNode());
171
20
            proof->addStep(split, PfRule::SPLIT, {}, {lit});
172
80
            proof->addStep(
173
60
                flem, PfRule::MACRO_SR_PRED_TRANSFORM, {split, k_eq}, {flem});
174
          }
175
52
          d_data->d_im.addPendingLemma(
176
              flem, InferenceId::ARITH_NL_FACTOR, proof);
177
        }
178
      }
179
    }
180
  }
181
276
}
182
183
52
/**
 * Return the purification skolem for the term `n`, creating it on first use.
 *
 * On creation, the skolem is made with prefix "kf" and cached in
 * d_factor_skolem. The equality `k = n` is also queued as a pending lemma
 * (InferenceId::ARITH_NL_FACTOR) with `proof` as its proof generator.
 * Whenever proofs are enabled, a MACRO_SR_PRED_INTRO step for `k = n` is
 * added to `proof`, on every call and not only on creation. The caller
 * must therefore pass a non-null `proof` in that case.
 *
 * @param n the term to purify
 * @param proof proof object; only dereferenced when proofs are enabled
 * @return the skolem k associated with n
 */
Node FactoringCheck::getFactorSkolem(Node n, CDProof* proof)
{
  Node skolem;
  std::map<Node, Node>::iterator cached = d_factor_skolem.find(n);
  if (cached != d_factor_skolem.end())
  {
    skolem = cached->second;
  }
  else
  {
    SkolemManager* sm = NodeManager::currentNM()->getSkolemManager();
    skolem = sm->mkPurifySkolem(n, "kf");
    Node definition = skolem.eqNode(n);
    Trace("nl-ext-factor") << "...adding factor skolem " << skolem << " == "
                           << n << std::endl;
    d_data->d_im.addPendingLemma(
        definition, InferenceId::ARITH_NL_FACTOR, proof);
    // n is known to be absent here, so emplace always inserts.
    d_factor_skolem.emplace(n, skolem);
  }
  if (d_data->isProofEnabled())
  {
    Node definition = skolem.eqNode(n);
    proof->addStep(definition, PfRule::MACRO_SR_PRED_INTRO, {}, {definition});
  }
  return skolem;
}
208
209
}  // namespace nl
210
}  // namespace arith
211
}  // namespace theory
212
29340
}  // namespace cvc5