GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/nl/ext/factoring_check.cpp Lines: 109 110 99.1 %
Date: 2021-05-22 Branches: 224 412 54.4 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Gereon Kremer, Tim King
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of factoring check.
14
 */
15
16
#include "theory/arith/nl/ext/factoring_check.h"
17
18
#include "expr/node.h"
19
#include "expr/proof.h"
20
#include "expr/skolem_manager.h"
21
#include "theory/arith/arith_msum.h"
22
#include "theory/arith/inference_manager.h"
23
#include "theory/arith/nl/nl_model.h"
24
#include "theory/arith/nl/ext/ext_state.h"
25
#include "theory/rewriter.h"
26
27
namespace cvc5 {
28
namespace theory {
29
namespace arith {
30
namespace nl {
31
32
4914
FactoringCheck::FactoringCheck(ExtState* data) : d_data(data)
{
  // Build the rational constants 0 and 1 once. check() uses d_zero as the
  // right-hand side of the factored literal. It uses d_one as a placeholder
  // for a divided-out factor and as the default coefficient/monomial when
  // one is absent (null) in a monomial sum.
  NodeManager* nm = NodeManager::currentNM();
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
}
37
38
262
void FactoringCheck::check(const std::vector<Node>& asserts,
39
                           const std::vector<Node>& false_asserts)
40
{
41
262
  NodeManager* nm = NodeManager::currentNM();
42
262
  Trace("nl-ext") << "Get factoring lemmas..." << std::endl;
43
7551
  for (const Node& lit : asserts)
44
  {
45
7289
    bool polarity = lit.getKind() != Kind::NOT;
46
14578
    Node atom = lit.getKind() == Kind::NOT ? lit[0] : lit;
47
14578
    Node litv = d_data->d_model.computeConcreteModelValue(lit);
48
7289
    bool considerLit = false;
49
    // Only consider literals that are in false_asserts.
50
7289
    considerLit = std::find(false_asserts.begin(), false_asserts.end(), lit)
51
14578
                  != false_asserts.end();
52
53
7289
    if (considerLit)
54
    {
55
3152
      std::map<Node, Node> msum;
56
1576
      if (ArithMSum::getMonomialSumLit(atom, msum))
57
      {
58
3152
        Trace("nl-ext-factor") << "Factoring for literal " << lit
59
1576
                               << ", monomial sum is : " << std::endl;
60
1576
        if (Trace.isOn("nl-ext-factor"))
61
        {
62
          ArithMSum::debugPrintMonomialSum(msum, "nl-ext-factor");
63
        }
64
3152
        std::map<Node, std::vector<Node> > factor_to_mono;
65
3152
        std::map<Node, std::vector<Node> > factor_to_mono_orig;
66
5096
        for (std::map<Node, Node>::iterator itm = msum.begin();
67
5096
             itm != msum.end();
68
             ++itm)
69
        {
70
3520
          if (!itm->first.isNull())
71
          {
72
2525
            if (itm->first.getKind() == Kind::NONLINEAR_MULT)
73
            {
74
390
              std::vector<Node> children;
75
629
              for (unsigned i = 0; i < itm->first.getNumChildren(); i++)
76
              {
77
434
                children.push_back(itm->first[i]);
78
              }
79
390
              std::map<Node, bool> processed;
80
629
              for (unsigned i = 0; i < itm->first.getNumChildren(); i++)
81
              {
82
434
                if (processed.find(itm->first[i]) == processed.end())
83
                {
84
350
                  processed[itm->first[i]] = true;
85
350
                  children[i] = d_one;
86
350
                  if (!itm->second.isNull())
87
                  {
88
188
                    children.push_back(itm->second);
89
                  }
90
700
                  Node val = nm->mkNode(Kind::MULT, children);
91
350
                  if (!itm->second.isNull())
92
                  {
93
188
                    children.pop_back();
94
                  }
95
350
                  children[i] = itm->first[i];
96
350
                  val = Rewriter::rewrite(val);
97
350
                  factor_to_mono[itm->first[i]].push_back(val);
98
350
                  factor_to_mono_orig[itm->first[i]].push_back(itm->first);
99
                }
100
              }
101
            }
102
          }
103
        }
104
318
        for (std::map<Node, std::vector<Node> >::iterator itf =
105
1576
                 factor_to_mono.begin();
106
1894
             itf != factor_to_mono.end();
107
             ++itf)
108
        {
109
383
          Node x = itf->first;
110
318
          if (itf->second.size() == 1)
111
          {
112
292
            std::map<Node, Node>::iterator itm = msum.find(x);
113
292
            if (itm != msum.end())
114
            {
115
39
              itf->second.push_back(itm->second.isNull() ? d_one : itm->second);
116
39
              factor_to_mono_orig[x].push_back(x);
117
            }
118
          }
119
318
          if (itf->second.size() <= 1)
120
          {
121
253
            continue;
122
          }
123
130
          Node sum = nm->mkNode(Kind::PLUS, itf->second);
124
65
          sum = Rewriter::rewrite(sum);
125
130
          Trace("nl-ext-factor")
126
65
              << "* Factored sum for " << x << " : " << sum << std::endl;
127
128
65
          CDProof* proof = nullptr;
129
65
          if (d_data->isProofEnabled())
130
          {
131
23
            proof = d_data->getProof();
132
          }
133
130
          Node kf = getFactorSkolem(sum, proof);
134
130
          std::vector<Node> poly;
135
65
          poly.push_back(nm->mkNode(Kind::MULT, x, kf));
136
          std::map<Node, std::vector<Node> >::iterator itfo =
137
65
              factor_to_mono_orig.find(x);
138
65
          Assert(itfo != factor_to_mono_orig.end());
139
302
          for (std::map<Node, Node>::iterator itm = msum.begin();
140
302
               itm != msum.end();
141
               ++itm)
142
          {
143
711
            if (std::find(itfo->second.begin(), itfo->second.end(), itm->first)
144
711
                == itfo->second.end())
145
            {
146
269
              poly.push_back(ArithMSum::mkCoeffTerm(
147
269
                  itm->second, itm->first.isNull() ? d_one : itm->first));
148
            }
149
          }
150
          Node polyn =
151
130
              poly.size() == 1 ? poly[0] : nm->mkNode(Kind::PLUS, poly);
152
130
          Trace("nl-ext-factor")
153
65
              << "...factored polynomial : " << polyn << std::endl;
154
130
          Node conc_lit = nm->mkNode(atom.getKind(), polyn, d_zero);
155
65
          conc_lit = Rewriter::rewrite(conc_lit);
156
65
          if (!polarity)
157
          {
158
14
            conc_lit = conc_lit.negate();
159
          }
160
161
130
          std::vector<Node> lemma_disj;
162
65
          lemma_disj.push_back(conc_lit);
163
65
          lemma_disj.push_back(lit.negate());
164
130
          Node flem = nm->mkNode(Kind::OR, lemma_disj);
165
65
          Trace("nl-ext-factor") << "...lemma is " << flem << std::endl;
166
65
          if (d_data->isProofEnabled())
167
          {
168
46
            Node k_eq = kf.eqNode(sum);
169
46
            Node split = nm->mkNode(Kind::OR, lit, lit.notNode());
170
23
            proof->addStep(split, PfRule::SPLIT, {}, {lit});
171
92
            proof->addStep(
172
69
                flem, PfRule::MACRO_SR_PRED_TRANSFORM, {split, k_eq}, {flem});
173
          }
174
65
          d_data->d_im.addPendingLemma(
175
              flem, InferenceId::ARITH_NL_FACTOR, proof);
176
        }
177
      }
178
    }
179
  }
180
262
}
181
182
65
/**
 * Return the purification skolem k for the term n (named "kf").
 *
 * The first request for a given n creates k, queues the lemma k = n (with
 * `proof` as its proof generator) and caches k in d_factor_skolem. Later
 * requests reuse the cached skolem without re-sending the lemma. Whenever
 * proofs are enabled, a MACRO_SR_PRED_INTRO step for k = n is added to
 * `proof`, which must then be non-null.
 */
Node FactoringCheck::getFactorSkolem(Node n, CDProof* proof)
{
  Node skolem;
  auto cached = d_factor_skolem.find(n);
  if (cached != d_factor_skolem.end())
  {
    // Reuse the skolem introduced by an earlier call for the same term.
    skolem = cached->second;
  }
  else
  {
    NodeManager* nm = NodeManager::currentNM();
    skolem = nm->getSkolemManager()->mkPurifySkolem(n, "kf");
    Node definition = skolem.eqNode(n);
    Trace("nl-ext-factor") << "...adding factor skolem " << skolem << " == "
                           << n << std::endl;
    d_data->d_im.addPendingLemma(
        definition, InferenceId::ARITH_NL_FACTOR, proof);
    d_factor_skolem.emplace(n, skolem);
  }
  if (d_data->isProofEnabled())
  {
    Node definition = skolem.eqNode(n);
    proof->addStep(definition, PfRule::MACRO_SR_PRED_INTRO, {}, {definition});
  }
  return skolem;
}
207
208
}  // namespace nl
209
}  // namespace arith
210
}  // namespace theory
211
28191
}  // namespace cvc5