GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/nl/ext/factoring_check.cpp Lines: 109 110 99.1 %
Date: 2021-03-23 Branches: 224 412 54.4 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file factoring_check.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Gereon Kremer, Tim King
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Implementation of factoring check
13
 **/
14
15
#include "theory/arith/nl/ext/factoring_check.h"
16
17
#include "expr/node.h"
18
#include "expr/proof.h"
19
#include "expr/skolem_manager.h"
20
#include "theory/arith/arith_msum.h"
21
#include "theory/arith/inference_manager.h"
22
#include "theory/arith/nl/nl_model.h"
23
#include "theory/arith/nl/ext/ext_state.h"
24
#include "theory/rewriter.h"
25
26
namespace CVC4 {
27
namespace theory {
28
namespace arith {
29
namespace nl {
30
31
4391
/**
 * Creates the factoring check.
 *
 * @param data the extended-solver state; it is stored (not copied) and later
 *        used for the model, the inference manager and proof construction.
 *
 * The rational constants 1 and 0 are built once here and reused when
 * constructing cofactors and conclusion literals.
 */
FactoringCheck::FactoringCheck(ExtState* data) : d_data(data)
{
  NodeManager* nm = NodeManager::currentNM();
  d_one = nm->mkConst(Rational(1));
  d_zero = nm->mkConst(Rational(0));
}
36
37
255
void FactoringCheck::check(const std::vector<Node>& asserts,
38
                           const std::vector<Node>& false_asserts)
39
{
40
255
  NodeManager* nm = NodeManager::currentNM();
41
255
  Trace("nl-ext") << "Get factoring lemmas..." << std::endl;
42
8561
  for (const Node& lit : asserts)
43
  {
44
8306
    bool polarity = lit.getKind() != Kind::NOT;
45
16612
    Node atom = lit.getKind() == Kind::NOT ? lit[0] : lit;
46
16612
    Node litv = d_data->d_model.computeConcreteModelValue(lit);
47
8306
    bool considerLit = false;
48
    // Only consider literals that are in false_asserts.
49
8306
    considerLit = std::find(false_asserts.begin(), false_asserts.end(), lit)
50
16612
                  != false_asserts.end();
51
52
8306
    if (considerLit)
53
    {
54
3142
      std::map<Node, Node> msum;
55
1571
      if (ArithMSum::getMonomialSumLit(atom, msum))
56
      {
57
3142
        Trace("nl-ext-factor") << "Factoring for literal " << lit
58
1571
                               << ", monomial sum is : " << std::endl;
59
1571
        if (Trace.isOn("nl-ext-factor"))
60
        {
61
          ArithMSum::debugPrintMonomialSum(msum, "nl-ext-factor");
62
        }
63
3142
        std::map<Node, std::vector<Node> > factor_to_mono;
64
3142
        std::map<Node, std::vector<Node> > factor_to_mono_orig;
65
5095
        for (std::map<Node, Node>::iterator itm = msum.begin();
66
5095
             itm != msum.end();
67
             ++itm)
68
        {
69
3524
          if (!itm->first.isNull())
70
          {
71
2535
            if (itm->first.getKind() == Kind::NONLINEAR_MULT)
72
            {
73
450
              std::vector<Node> children;
74
759
              for (unsigned i = 0; i < itm->first.getNumChildren(); i++)
75
              {
76
534
                children.push_back(itm->first[i]);
77
              }
78
450
              std::map<Node, bool> processed;
79
759
              for (unsigned i = 0; i < itm->first.getNumChildren(); i++)
80
              {
81
534
                if (processed.find(itm->first[i]) == processed.end())
82
                {
83
371
                  processed[itm->first[i]] = true;
84
371
                  children[i] = d_one;
85
371
                  if (!itm->second.isNull())
86
                  {
87
200
                    children.push_back(itm->second);
88
                  }
89
742
                  Node val = nm->mkNode(Kind::MULT, children);
90
371
                  if (!itm->second.isNull())
91
                  {
92
200
                    children.pop_back();
93
                  }
94
371
                  children[i] = itm->first[i];
95
371
                  val = Rewriter::rewrite(val);
96
371
                  factor_to_mono[itm->first[i]].push_back(val);
97
371
                  factor_to_mono_orig[itm->first[i]].push_back(itm->first);
98
                }
99
              }
100
            }
101
          }
102
        }
103
327
        for (std::map<Node, std::vector<Node> >::iterator itf =
104
1571
                 factor_to_mono.begin();
105
1898
             itf != factor_to_mono.end();
106
             ++itf)
107
        {
108
395
          Node x = itf->first;
109
327
          if (itf->second.size() == 1)
110
          {
111
295
            std::map<Node, Node>::iterator itm = msum.find(x);
112
295
            if (itm != msum.end())
113
            {
114
36
              itf->second.push_back(itm->second.isNull() ? d_one : itm->second);
115
36
              factor_to_mono_orig[x].push_back(x);
116
            }
117
          }
118
327
          if (itf->second.size() <= 1)
119
          {
120
259
            continue;
121
          }
122
136
          Node sum = nm->mkNode(Kind::PLUS, itf->second);
123
68
          sum = Rewriter::rewrite(sum);
124
136
          Trace("nl-ext-factor")
125
68
              << "* Factored sum for " << x << " : " << sum << std::endl;
126
127
68
          CDProof* proof = nullptr;
128
68
          if (d_data->isProofEnabled())
129
          {
130
18
            proof = d_data->getProof();
131
          }
132
136
          Node kf = getFactorSkolem(sum, proof);
133
136
          std::vector<Node> poly;
134
68
          poly.push_back(nm->mkNode(Kind::MULT, x, kf));
135
          std::map<Node, std::vector<Node> >::iterator itfo =
136
68
              factor_to_mono_orig.find(x);
137
68
          Assert(itfo != factor_to_mono_orig.end());
138
335
          for (std::map<Node, Node>::iterator itm = msum.begin();
139
335
               itm != msum.end();
140
               ++itm)
141
          {
142
801
            if (std::find(itfo->second.begin(), itfo->second.end(), itm->first)
143
801
                == itfo->second.end())
144
            {
145
331
              poly.push_back(ArithMSum::mkCoeffTerm(
146
331
                  itm->second, itm->first.isNull() ? d_one : itm->first));
147
            }
148
          }
149
          Node polyn =
150
136
              poly.size() == 1 ? poly[0] : nm->mkNode(Kind::PLUS, poly);
151
136
          Trace("nl-ext-factor")
152
68
              << "...factored polynomial : " << polyn << std::endl;
153
136
          Node conc_lit = nm->mkNode(atom.getKind(), polyn, d_zero);
154
68
          conc_lit = Rewriter::rewrite(conc_lit);
155
68
          if (!polarity)
156
          {
157
17
            conc_lit = conc_lit.negate();
158
          }
159
160
136
          std::vector<Node> lemma_disj;
161
68
          lemma_disj.push_back(conc_lit);
162
68
          lemma_disj.push_back(lit.negate());
163
136
          Node flem = nm->mkNode(Kind::OR, lemma_disj);
164
68
          Trace("nl-ext-factor") << "...lemma is " << flem << std::endl;
165
68
          if (d_data->isProofEnabled())
166
          {
167
36
            Node k_eq = kf.eqNode(sum);
168
36
            Node split = nm->mkNode(Kind::OR, lit, lit.notNode());
169
18
            proof->addStep(split, PfRule::SPLIT, {}, {lit});
170
72
            proof->addStep(
171
54
                flem, PfRule::MACRO_SR_PRED_TRANSFORM, {split, k_eq}, {flem});
172
          }
173
68
          d_data->d_im.addPendingLemma(
174
              flem, InferenceId::ARITH_NL_FACTOR, proof);
175
        }
176
      }
177
    }
178
  }
179
255
}
180
181
68
/**
 * Returns the factor skolem kf for the (rewritten) sum n.
 *
 * The first time n is seen, a purification skolem named "kf" is created for
 * it and cached in d_factor_skolem. A pending lemma kf = n is also sent with
 * inference ARITH_NL_FACTOR. Later calls for the same n reuse the cached
 * skolem and send no lemma.
 *
 * When proofs are enabled, every call records a MACRO_SR_PRED_INTRO step
 * for kf = n in proof.
 *
 * @param n the sum to purify
 * @param proof the proof object to record steps in; only dereferenced when
 *        proofs are enabled
 * @return the skolem associated with n
 */
Node FactoringCheck::getFactorSkolem(Node n, CDProof* proof)
{
  Node skolem;
  std::map<Node, Node>::const_iterator cached = d_factor_skolem.find(n);
  if (cached != d_factor_skolem.end())
  {
    skolem = cached->second;
  }
  else
  {
    SkolemManager* sm = NodeManager::currentNM()->getSkolemManager();
    skolem = sm->mkPurifySkolem(n, "kf");
    Node defEq = skolem.eqNode(n);
    Trace("nl-ext-factor") << "...adding factor skolem " << skolem << " == "
                           << n << std::endl;
    d_data->d_im.addPendingLemma(defEq, InferenceId::ARITH_NL_FACTOR, proof);
    d_factor_skolem[n] = skolem;
  }
  if (d_data->isProofEnabled())
  {
    Node defEq = skolem.eqNode(n);
    proof->addStep(defEq, PfRule::MACRO_SR_PRED_INTRO, {}, {defEq});
  }
  return skolem;
}
206
207
}  // namespace nl
208
}  // namespace arith
209
}  // namespace theory
210
26685
}  // namespace CVC4