GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/arith/nl/ext/factoring_check.cpp Lines: 109 110 99.1 %
Date: 2021-09-10 Branches: 224 410 54.6 %

Line Exec Source
1
/******************************************************************************
2
 * Top contributors (to current version):
3
 *   Andrew Reynolds, Gereon Kremer, Tim King
4
 *
5
 * This file is part of the cvc5 project.
6
 *
7
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 * in the top-level source directory and their institutional affiliations.
9
 * All rights reserved.  See the file COPYING in the top-level source
10
 * directory for licensing information.
11
 * ****************************************************************************
12
 *
13
 * Implementation of factoring check.
14
 */
15
16
#include "theory/arith/nl/ext/factoring_check.h"
17
18
#include "expr/node.h"
19
#include "expr/skolem_manager.h"
20
#include "proof/proof.h"
21
#include "theory/arith/arith_msum.h"
22
#include "theory/arith/inference_manager.h"
23
#include "theory/arith/nl/ext/ext_state.h"
24
#include "theory/arith/nl/nl_model.h"
25
#include "theory/rewriter.h"
26
#include "util/rational.h"
27
28
namespace cvc5 {
29
namespace theory {
30
namespace arith {
31
namespace nl {
32
33
5193
/**
 * Construct the factoring check.
 *
 * Stores the given extension-state pointer in d_data and caches the
 * rational constants 0 and 1 as nodes (d_zero, d_one), which are used
 * when building terms in check().
 *
 * @param data The extension state stored in d_data.
 */
FactoringCheck::FactoringCheck(ExtState* data) : d_data(data)
34
{
35
5193
  d_zero = NodeManager::currentNM()->mkConst(Rational(0));
36
5193
  d_one = NodeManager::currentNM()->mkConst(Rational(1));
37
5193
}
38
39
426
/**
 * Generate factoring lemmas for the given assertions.
 *
 * For every literal in `asserts` that also occurs in `false_asserts`, the
 * monomial sum of its atom is computed. For each factor x that occurs in
 * at least two monomials of that sum (counting a plain occurrence of x
 * itself), the co-factors are summed, a skolem kf is obtained for that sum
 * via getFactorSkolem(), and the lemma
 *   (polynomial with x * kf in place of those monomials ~ 0) OR (not lit)
 * is queued with the inference manager, where ~ is the kind of the atom.
 * When proofs are enabled, matching proof steps are added as well.
 *
 * @param asserts The literals to inspect.
 * @param false_asserts The literals for which lemmas are generated; a
 * literal of `asserts` is skipped unless it occurs in this vector.
 */
void FactoringCheck::check(const std::vector<Node>& asserts,
40
                           const std::vector<Node>& false_asserts)
41
{
42
426
  NodeManager* nm = NodeManager::currentNM();
43
426
  Trace("nl-ext") << "Get factoring lemmas..." << std::endl;
44
16825
  for (const Node& lit : asserts)
45
  {
46
16399
    // Split the literal into its polarity and its (un-negated) atom.
    bool polarity = lit.getKind() != Kind::NOT;
47
32798
    Node atom = lit.getKind() == Kind::NOT ? lit[0] : lit;
48
32798
    // NOTE(review): litv is not read anywhere below in this function;
    // confirm whether this call is only kept for a side effect.
    Node litv = d_data->d_model.computeConcreteModelValue(lit);
49
16399
    bool considerLit = false;
50
    // Only consider literals that are in false_asserts.
51
16399
    considerLit = std::find(false_asserts.begin(), false_asserts.end(), lit)
52
32798
                  != false_asserts.end();
53
54
16399
    if (considerLit)
55
    {
56
5042
      // msum maps each monomial of the atom to its coefficient.
      std::map<Node, Node> msum;
57
2521
      if (ArithMSum::getMonomialSumLit(atom, msum))
58
      {
59
5042
        Trace("nl-ext-factor") << "Factoring for literal " << lit
60
2521
                               << ", monomial sum is : " << std::endl;
61
2521
        if (Trace.isOn("nl-ext-factor"))
62
        {
63
          ArithMSum::debugPrintMonomialSum(msum, "nl-ext-factor");
64
        }
65
5042
        // factor_to_mono: factor -> co-factor terms (monomial with that
        // factor replaced by one, times the coefficient, rewritten).
        std::map<Node, std::vector<Node> > factor_to_mono;
66
5042
        // factor_to_mono_orig: factor -> the original monomials (keys of
        // msum) that the co-factor terms above were derived from.
        std::map<Node, std::vector<Node> > factor_to_mono_orig;
67
8852
        for (std::map<Node, Node>::iterator itm = msum.begin();
68
8852
             itm != msum.end();
69
             ++itm)
70
        {
71
6331
          if (!itm->first.isNull())
72
          {
73
4710
            // Only non-linear products contribute factors here.
            if (itm->first.getKind() == Kind::NONLINEAR_MULT)
74
            {
75
418
              std::vector<Node> children;
76
629
              for (unsigned i = 0; i < itm->first.getNumChildren(); i++)
77
              {
78
420
                children.push_back(itm->first[i]);
79
              }
80
418
              // Tracks factors already handled for this monomial, so that a
              // repeated factor (e.g. in x * x) is processed only once.
              std::map<Node, bool> processed;
81
629
              for (unsigned i = 0; i < itm->first.getNumChildren(); i++)
82
              {
83
420
                if (processed.find(itm->first[i]) == processed.end())
84
                {
85
388
                  processed[itm->first[i]] = true;
86
388
                  // Temporarily replace the i-th factor by one (and append
                  // the coefficient, if any) to build the co-factor term;
                  // children is restored right after the node is built.
                  children[i] = d_one;
87
388
                  if (!itm->second.isNull())
88
                  {
89
208
                    children.push_back(itm->second);
90
                  }
91
776
                  Node val = nm->mkNode(Kind::MULT, children);
92
388
                  if (!itm->second.isNull())
93
                  {
94
208
                    children.pop_back();
95
                  }
96
388
                  children[i] = itm->first[i];
97
388
                  val = Rewriter::rewrite(val);
98
388
                  factor_to_mono[itm->first[i]].push_back(val);
99
388
                  factor_to_mono_orig[itm->first[i]].push_back(itm->first);
100
                }
101
              }
102
            }
103
          }
104
        }
105
354
        // Try to produce one lemma per factor collected above.
        for (std::map<Node, std::vector<Node> >::iterator itf =
106
2521
                 factor_to_mono.begin();
107
2875
             itf != factor_to_mono.end();
108
             ++itf)
109
        {
110
425
          Node x = itf->first;
111
354
          if (itf->second.size() == 1)
112
          {
113
320
            // The factor occurs in a single product only: check whether x
            // itself is also a monomial of the sum, and if so add its
            // coefficient (one if the coefficient is null) as a co-factor.
            std::map<Node, Node>::iterator itm = msum.find(x);
114
320
            if (itm != msum.end())
115
            {
116
37
              itf->second.push_back(itm->second.isNull() ? d_one : itm->second);
117
37
              factor_to_mono_orig[x].push_back(x);
118
            }
119
          }
120
354
          // Nothing to factor out unless x has at least two co-factors.
          if (itf->second.size() <= 1)
121
          {
122
283
            continue;
123
          }
124
142
          Node sum = nm->mkNode(Kind::PLUS, itf->second);
125
71
          sum = Rewriter::rewrite(sum);
126
142
          Trace("nl-ext-factor")
127
71
              << "* Factored sum for " << x << " : " << sum << std::endl;
128
129
71
          // proof stays null unless proofs are enabled.
          CDProof* proof = nullptr;
130
71
          if (d_data->isProofEnabled())
131
          {
132
21
            proof = d_data->getProof();
133
          }
134
142
          // kf is the skolem associated with the factored sum.
          Node kf = getFactorSkolem(sum, proof);
135
142
          // Build the factored polynomial: x * kf plus every monomial of
          // msum that was not folded into the factored sum.
          std::vector<Node> poly;
136
71
          poly.push_back(nm->mkNode(Kind::MULT, x, kf));
137
          std::map<Node, std::vector<Node> >::iterator itfo =
138
71
              factor_to_mono_orig.find(x);
139
71
          Assert(itfo != factor_to_mono_orig.end());
140
336
          for (std::map<Node, Node>::iterator itm = msum.begin();
141
336
               itm != msum.end();
142
               ++itm)
143
          {
144
795
            if (std::find(itfo->second.begin(), itfo->second.end(), itm->first)
145
795
                == itfo->second.end())
146
            {
147
330
              // A null monomial key is passed on as d_one (presumably the
              // constant term of the sum -- confirm against ArithMSum).
              poly.push_back(ArithMSum::mkCoeffTerm(
148
330
                  itm->second, itm->first.isNull() ? d_one : itm->first));
149
            }
150
          }
151
          Node polyn =
152
142
              poly.size() == 1 ? poly[0] : nm->mkNode(Kind::PLUS, poly);
153
142
          Trace("nl-ext-factor")
154
71
              << "...factored polynomial : " << polyn << std::endl;
155
142
          // The conclusion relates the factored polynomial to zero using
          // the kind of the original atom, with the original polarity.
          Node conc_lit = nm->mkNode(atom.getKind(), polyn, d_zero);
156
71
          conc_lit = Rewriter::rewrite(conc_lit);
157
71
          if (!polarity)
158
          {
159
18
            conc_lit = conc_lit.negate();
160
          }
161
162
142
          // Lemma: conc_lit OR (not lit).
          std::vector<Node> lemma_disj;
163
71
          lemma_disj.push_back(conc_lit);
164
71
          lemma_disj.push_back(lit.negate());
165
142
          Node flem = nm->mkNode(Kind::OR, lemma_disj);
166
71
          Trace("nl-ext-factor") << "...lemma is " << flem << std::endl;
167
71
          if (d_data->isProofEnabled())
168
          {
169
42
            // Proof steps: a SPLIT step for (lit OR not lit), then a
            // MACRO_SR_PRED_TRANSFORM step for the lemma with the split
            // and kf = sum as its children.
            Node k_eq = kf.eqNode(sum);
170
42
            Node split = nm->mkNode(Kind::OR, lit, lit.notNode());
171
21
            proof->addStep(split, PfRule::SPLIT, {}, {lit});
172
84
            proof->addStep(
173
63
                flem, PfRule::MACRO_SR_PRED_TRANSFORM, {split, k_eq}, {flem});
174
          }
175
71
          d_data->d_im.addPendingLemma(
176
              flem, InferenceId::ARITH_NL_FACTOR, proof);
177
        }
178
      }
179
    }
180
  }
181
426
}
182
183
71
/**
 * Return the skolem associated with the term n, creating it on first use.
 *
 * If n has no entry in d_factor_skolem, a purification skolem k is created
 * for n, the lemma k = n is queued with the inference manager, and k is
 * cached in d_factor_skolem. Otherwise the cached skolem is reused and no
 * lemma is queued. When proofs are enabled, a MACRO_SR_PRED_INTRO step for
 * k = n is added to `proof` on every call.
 *
 * @param n The term to obtain a skolem for.
 * @param proof The proof to add the step to. It is dereferenced whenever
 * d_data->isProofEnabled() holds; the caller in check() passes a non-null
 * pointer exactly in that case.
 * @return The skolem associated with n.
 */
Node FactoringCheck::getFactorSkolem(Node n, CDProof* proof)
184
{
185
71
  std::map<Node, Node>::iterator itf = d_factor_skolem.find(n);
186
71
  Node k;
187
71
  if (itf == d_factor_skolem.end())
188
  {
189
47
    // First request for n: create the skolem, queue k = n, and cache it.
    NodeManager* nm = NodeManager::currentNM();
190
47
    k = nm->getSkolemManager()->mkPurifySkolem(n, "kf");
191
94
    Node k_eq = k.eqNode(n);
192
94
    Trace("nl-ext-factor") << "...adding factor skolem " << k << " == " << n
193
47
                           << std::endl;
194
47
    d_data->d_im.addPendingLemma(k_eq, InferenceId::ARITH_NL_FACTOR, proof);
195
47
    d_factor_skolem[n] = k;
196
  }
197
  else
198
  {
199
24
    // Reuse the skolem created by an earlier call for the same term.
    k = itf->second;
200
  }
201
71
  if (d_data->isProofEnabled())
202
  {
203
42
    // Added on both the fresh and the cached path.
    Node k_eq = k.eqNode(n);
204
21
    proof->addStep(k_eq, PfRule::MACRO_SR_PRED_INTRO, {}, {k_eq});
205
  }
206
71
  return k;
207
}
208
209
}  // namespace nl
210
}  // namespace arith
211
}  // namespace theory
212
29502
}  // namespace cvc5