GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/nl/ext/factoring_check.cpp Lines: 110 111 99.1 %
Date: 2021-11-07 Branches: 224 410 54.6 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Gereon Kremer, Tim King
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of factoring check.
14
 */
15
16
#include "theory/arith/nl/ext/factoring_check.h"
17
18
#include "expr/node.h"
19
#include "expr/skolem_manager.h"
20
#include "proof/proof.h"
21
#include "theory/arith/arith_msum.h"
22
#include "theory/arith/inference_manager.h"
23
#include "theory/arith/nl/ext/ext_state.h"
24
#include "theory/arith/nl/nl_model.h"
25
#include "theory/rewriter.h"
26
#include "util/rational.h"
27
28
namespace cvc5 {
29
namespace theory {
30
namespace arith {
31
namespace nl {
32
33
9696
/**
 * Constructs the factoring check.
 *
 * @param env the environment (forwarded to EnvObj)
 * @param data the shared extended-nonlinear state; stored as d_data and
 *        used for the model, inference manager and proof settings
 */
FactoringCheck::FactoringCheck(Env& env, ExtState* data)
    : EnvObj(env), d_data(data)
{
  // Cache the rational constants 0 and 1, which are used repeatedly when
  // building cofactors and factored polynomials.
  NodeManager* nodeMgr = NodeManager::currentNM();
  d_zero = nodeMgr->mkConst(Rational(0));
  d_one = nodeMgr->mkConst(Rational(1));
}
39
40
427
void FactoringCheck::check(const std::vector<Node>& asserts,
41
                           const std::vector<Node>& false_asserts)
42
{
43
427
  NodeManager* nm = NodeManager::currentNM();
44
427
  Trace("nl-ext") << "Get factoring lemmas..." << std::endl;
45
29371
  for (const Node& lit : asserts)
46
  {
47
28944
    bool polarity = lit.getKind() != Kind::NOT;
48
57888
    Node atom = lit.getKind() == Kind::NOT ? lit[0] : lit;
49
57888
    Node litv = d_data->d_model.computeConcreteModelValue(lit);
50
28944
    bool considerLit = false;
51
    // Only consider literals that are in false_asserts.
52
28944
    considerLit = std::find(false_asserts.begin(), false_asserts.end(), lit)
53
57888
                  != false_asserts.end();
54
55
28944
    if (considerLit)
56
    {
57
5136
      std::map<Node, Node> msum;
58
2568
      if (ArithMSum::getMonomialSumLit(atom, msum))
59
      {
60
5136
        Trace("nl-ext-factor") << "Factoring for literal " << lit
61
2568
                               << ", monomial sum is : " << std::endl;
62
2568
        if (Trace.isOn("nl-ext-factor"))
63
        {
64
          ArithMSum::debugPrintMonomialSum(msum, "nl-ext-factor");
65
        }
66
5136
        std::map<Node, std::vector<Node> > factor_to_mono;
67
5136
        std::map<Node, std::vector<Node> > factor_to_mono_orig;
68
8960
        for (std::map<Node, Node>::iterator itm = msum.begin();
69
8960
             itm != msum.end();
70
             ++itm)
71
        {
72
6392
          if (!itm->first.isNull())
73
          {
74
4754
            if (itm->first.getKind() == Kind::NONLINEAR_MULT)
75
            {
76
748
              std::vector<Node> children;
77
1124
              for (unsigned i = 0; i < itm->first.getNumChildren(); i++)
78
              {
79
750
                children.push_back(itm->first[i]);
80
              }
81
748
              std::map<Node, bool> processed;
82
1124
              for (unsigned i = 0; i < itm->first.getNumChildren(); i++)
83
              {
84
750
                if (processed.find(itm->first[i]) == processed.end())
85
                {
86
703
                  processed[itm->first[i]] = true;
87
703
                  children[i] = d_one;
88
703
                  if (!itm->second.isNull())
89
                  {
90
407
                    children.push_back(itm->second);
91
                  }
92
1406
                  Node val = nm->mkNode(Kind::MULT, children);
93
703
                  if (!itm->second.isNull())
94
                  {
95
407
                    children.pop_back();
96
                  }
97
703
                  children[i] = itm->first[i];
98
703
                  val = rewrite(val);
99
703
                  factor_to_mono[itm->first[i]].push_back(val);
100
703
                  factor_to_mono_orig[itm->first[i]].push_back(itm->first);
101
                }
102
              }
103
            }
104
          }
105
        }
106
679
        for (std::map<Node, std::vector<Node> >::iterator itf =
107
2568
                 factor_to_mono.begin();
108
3247
             itf != factor_to_mono.end();
109
             ++itf)
110
        {
111
771
          Node x = itf->first;
112
679
          if (itf->second.size() == 1)
113
          {
114
655
            std::map<Node, Node>::iterator itm = msum.find(x);
115
655
            if (itm != msum.end())
116
            {
117
68
              itf->second.push_back(itm->second.isNull() ? d_one : itm->second);
118
68
              factor_to_mono_orig[x].push_back(x);
119
            }
120
          }
121
679
          if (itf->second.size() <= 1)
122
          {
123
587
            continue;
124
          }
125
184
          Node sum = nm->mkNode(Kind::PLUS, itf->second);
126
92
          sum = rewrite(sum);
127
184
          Trace("nl-ext-factor")
128
92
              << "* Factored sum for " << x << " : " << sum << std::endl;
129
130
92
          CDProof* proof = nullptr;
131
92
          if (d_data->isProofEnabled())
132
          {
133
19
            proof = d_data->getProof();
134
          }
135
184
          Node kf = getFactorSkolem(sum, proof);
136
184
          std::vector<Node> poly;
137
92
          poly.push_back(nm->mkNode(Kind::MULT, x, kf));
138
          std::map<Node, std::vector<Node> >::iterator itfo =
139
92
              factor_to_mono_orig.find(x);
140
92
          Assert(itfo != factor_to_mono_orig.end());
141
423
          for (std::map<Node, Node>::iterator itm = msum.begin();
142
423
               itm != msum.end();
143
               ++itm)
144
          {
145
993
            if (std::find(itfo->second.begin(), itfo->second.end(), itm->first)
146
993
                == itfo->second.end())
147
            {
148
380
              poly.push_back(ArithMSum::mkCoeffTerm(
149
380
                  itm->second, itm->first.isNull() ? d_one : itm->first));
150
            }
151
          }
152
          Node polyn =
153
184
              poly.size() == 1 ? poly[0] : nm->mkNode(Kind::PLUS, poly);
154
184
          Trace("nl-ext-factor")
155
92
              << "...factored polynomial : " << polyn << std::endl;
156
184
          Node conc_lit = nm->mkNode(atom.getKind(), polyn, d_zero);
157
92
          conc_lit = rewrite(conc_lit);
158
92
          if (!polarity)
159
          {
160
10
            conc_lit = conc_lit.negate();
161
          }
162
163
184
          std::vector<Node> lemma_disj;
164
92
          lemma_disj.push_back(conc_lit);
165
92
          lemma_disj.push_back(lit.negate());
166
184
          Node flem = nm->mkNode(Kind::OR, lemma_disj);
167
92
          Trace("nl-ext-factor") << "...lemma is " << flem << std::endl;
168
92
          if (d_data->isProofEnabled())
169
          {
170
38
            Node k_eq = kf.eqNode(sum);
171
38
            Node split = nm->mkNode(Kind::OR, lit, lit.notNode());
172
19
            proof->addStep(split, PfRule::SPLIT, {}, {lit});
173
76
            proof->addStep(
174
57
                flem, PfRule::MACRO_SR_PRED_TRANSFORM, {split, k_eq}, {flem});
175
          }
176
92
          d_data->d_im.addPendingLemma(
177
              flem, InferenceId::ARITH_NL_FACTOR, proof);
178
        }
179
      }
180
    }
181
  }
182
427
}
183
184
92
/**
 * Returns the factor skolem standing for the term `n`.
 *
 * The skolem is created with mkPurifySkolem on the first request for `n` and
 * cached in d_factor_skolem; creating it also queues the lemma (k = n) as an
 * ARITH_NL_FACTOR inference. Whenever proofs are enabled, a
 * MACRO_SR_PRED_INTRO step for (k = n) is added to `proof` on every call.
 * NOTE(review): `proof` is dereferenced when proofs are enabled, so callers
 * must pass a non-null proof in that case.
 *
 * @param n the term to be represented by the skolem
 * @param proof the proof to record steps in (may be null if proofs are off)
 * @return the (possibly cached) skolem for `n`
 */
Node FactoringCheck::getFactorSkolem(Node n, CDProof* proof)
{
  Node skolem;
  auto cached = d_factor_skolem.find(n);
  if (cached != d_factor_skolem.end())
  {
    skolem = cached->second;
  }
  else
  {
    SkolemManager* skm = NodeManager::currentNM()->getSkolemManager();
    skolem = skm->mkPurifySkolem(n, "kf");
    Node definition = skolem.eqNode(n);
    Trace("nl-ext-factor") << "...adding factor skolem " << skolem << " == "
                           << n << std::endl;
    d_data->d_im.addPendingLemma(
        definition, InferenceId::ARITH_NL_FACTOR, proof);
    d_factor_skolem[n] = skolem;
  }
  if (d_data->isProofEnabled())
  {
    Node eq = skolem.eqNode(n);
    proof->addStep(eq, PfRule::MACRO_SR_PRED_INTRO, {}, {eq});
  }
  return skolem;
}
209
210
}  // namespace nl
211
}  // namespace arith
212
}  // namespace theory
213
31137
}  // namespace cvc5