GCC Code Coverage Report
Directory: . Exec Total Coverage
File: src/theory/quantifiers/sygus/sygus_unif_strat.cpp Lines: 474 570 83.2 %
Date: 2021-03-22 Branches: 928 2166 42.8 %

Line Exec Source
1
/*********************                                                        */
2
/*! \file sygus_unif_strat.cpp
3
 ** \verbatim
4
 ** Top contributors (to current version):
5
 **   Andrew Reynolds, Haniel Barbosa, Mathias Preiner
6
 ** This file is part of the CVC4 project.
7
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8
 ** in the top-level source directory and their institutional affiliations.
9
 ** All rights reserved.  See the file COPYING in the top-level source
10
 ** directory for licensing information.\endverbatim
11
 **
12
 ** \brief Implementation of sygus_unif_strat
13
 **/
14
15
#include "theory/quantifiers/sygus/sygus_unif_strat.h"
16
17
#include "expr/dtype.h"
18
#include "expr/dtype_cons.h"
19
#include "theory/datatypes/theory_datatypes_utils.h"
20
#include "theory/quantifiers/sygus/sygus_eval_unfold.h"
21
#include "theory/quantifiers/sygus/sygus_unif.h"
22
#include "theory/quantifiers/sygus/term_database_sygus.h"
23
#include "theory/quantifiers/term_util.h"
24
#include "theory/quantifiers_engine.h"
25
#include "theory/rewriter.h"
26
27
using namespace std;
28
using namespace CVC4::kind;
29
30
namespace CVC4 {
31
namespace theory {
32
namespace quantifiers {
33
34
std::ostream& operator<<(std::ostream& os, EnumRole r)
35
{
36
  switch (r)
37
  {
38
    case enum_invalid: os << "INVALID"; break;
39
    case enum_io: os << "IO"; break;
40
    case enum_ite_condition: os << "CONDITION"; break;
41
    case enum_concat_term: os << "CTERM"; break;
42
    default: os << "enum_" << static_cast<unsigned>(r); break;
43
  }
44
  return os;
45
}
46
47
std::ostream& operator<<(std::ostream& os, NodeRole r)
48
{
49
  switch (r)
50
  {
51
    case role_equal: os << "equal"; break;
52
    case role_string_prefix: os << "string_prefix"; break;
53
    case role_string_suffix: os << "string_suffix"; break;
54
    case role_ite_condition: os << "ite_condition"; break;
55
    default: os << "role_" << static_cast<unsigned>(r); break;
56
  }
57
  return os;
58
}
59
60
647
EnumRole getEnumeratorRoleForNodeRole(NodeRole r)
61
{
62
647
  switch (r)
63
  {
64
297
    case role_equal: return enum_io; break;
65
112
    case role_string_prefix: return enum_concat_term; break;
66
118
    case role_string_suffix: return enum_concat_term; break;
67
120
    case role_ite_condition: return enum_ite_condition; break;
68
    default: break;
69
  }
70
  return enum_invalid;
71
}
72
73
std::ostream& operator<<(std::ostream& os, StrategyType st)
74
{
75
  switch (st)
76
  {
77
    case strat_ITE: os << "ITE"; break;
78
    case strat_CONCAT_PREFIX: os << "CONCAT_PREFIX"; break;
79
    case strat_CONCAT_SUFFIX: os << "CONCAT_SUFFIX"; break;
80
    case strat_ID: os << "ID"; break;
81
    default: os << "strat_" << static_cast<unsigned>(st); break;
82
  }
83
  return os;
84
}
85
86
107
void SygusUnifStrategy::initialize(QuantifiersEngine* qe,
87
                                   Node f,
88
                                   std::vector<Node>& enums)
89
{
90
107
  Assert(d_candidate.isNull());
91
107
  d_candidate = f;
92
107
  d_root = f.getType();
93
107
  d_qe = qe;
94
95
  // collect the enumerator types and form the strategy
96
107
  buildStrategyGraph(d_root, role_equal);
97
  // add the enumerators
98
107
  enums.insert(enums.end(), d_esym_list.begin(), d_esym_list.end());
99
  // finish the initialization of the strategy
100
  // this computes if each node is conditional
101
214
  std::map<Node, std::map<NodeRole, bool> > visited;
102
107
  finishInit(getRootEnumerator(), role_equal, visited, false);
103
107
}
104
105
153
void SygusUnifStrategy::initializeType(TypeNode tn)
106
{
107
306
  Trace("sygus-unif") << "SygusUnifStrategy: initialize : " << tn << " for "
108
153
                      << d_candidate << std::endl;
109
153
  d_tinfo[tn].d_this_type = tn;
110
153
}
111
112
2514
Node SygusUnifStrategy::getRootEnumerator() const
113
{
114
2514
  std::map<TypeNode, EnumTypeInfo>::const_iterator itt = d_tinfo.find(d_root);
115
2514
  Assert(itt != d_tinfo.end());
116
  std::map<EnumRole, Node>::const_iterator it =
117
2514
      itt->second.d_enum.find(enum_io);
118
2514
  Assert(it != itt->second.d_enum.end());
119
2514
  return it->second;
120
}
121
122
12317
EnumInfo& SygusUnifStrategy::getEnumInfo(Node e)
123
{
124
12317
  std::map<Node, EnumInfo>::iterator it = d_einfo.find(e);
125
12317
  Assert(it != d_einfo.end());
126
12317
  return it->second;
127
}
128
129
8519
EnumTypeInfo& SygusUnifStrategy::getEnumTypeInfo(TypeNode tn)
130
{
131
17038
  Trace("sygus-unif") << "SygusUnifStrategy: get : " << tn << " for "
132
8519
                      << d_candidate << std::endl;
133
8519
  std::map<TypeNode, EnumTypeInfo>::iterator it = d_tinfo.find(tn);
134
8519
  Assert(it != d_tinfo.end());
135
8519
  return it->second;
136
}
137
// ----------------------------- establishing enumeration types
138
139
287
void SygusUnifStrategy::registerStrategyPoint(Node et,
140
                                           TypeNode tn,
141
                                           EnumRole enum_role,
142
                                           bool inSearch)
143
{
144
287
  if (d_einfo.find(et) == d_einfo.end())
145
  {
146
514
    Trace("sygus-unif-debug")
147
257
        << "...register " << et << " for " << tn.getDType().getName();
148
514
    Trace("sygus-unif-debug") << ", role = " << enum_role
149
257
                              << ", in search = " << inSearch << std::endl;
150
257
    d_einfo[et].initialize(enum_role);
151
    // if we are actually enumerating this (could be a compound node in the
152
    // strategy)
153
257
    if (inSearch)
154
    {
155
250
      std::map<TypeNode, Node>::iterator itn = d_master_enum.find(tn);
156
250
      if (itn == d_master_enum.end())
157
      {
158
        // use this for the search
159
146
        d_master_enum[tn] = et;
160
146
        d_esym_list.push_back(et);
161
146
        d_einfo[et].d_enum_slave.push_back(et);
162
      }
163
      else
164
      {
165
208
        Trace("sygus-unif-debug") << "Make " << et << " a slave of "
166
104
                                  << itn->second << std::endl;
167
104
        d_einfo[itn->second].d_enum_slave.push_back(et);
168
      }
169
    }
170
  }
171
287
}
172
173
510
void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole)
174
{
175
510
  NodeManager* nm = NodeManager::currentNM();
176
510
  if (d_tinfo.find(tn) == d_tinfo.end())
177
  {
178
    // register type
179
153
    Trace("sygus-unif") << "Register enumerating type : " << tn << std::endl;
180
153
    initializeType(tn);
181
  }
182
510
  EnumTypeInfo& eti = d_tinfo[tn];
183
510
  std::map<NodeRole, StrategyNode>::iterator itsn = eti.d_snodes.find(nrole);
184
510
  if (itsn != eti.d_snodes.end())
185
  {
186
    // already initialized
187
590
    return;
188
  }
189
244
  StrategyNode& snode = eti.d_snodes[nrole];
190
191
  // get the enumerator for this
192
244
  EnumRole erole = getEnumeratorRoleForNodeRole(nrole);
193
194
430
  Node ee;
195
244
  std::map<EnumRole, Node>::iterator iten = eti.d_enum.find(erole);
196
244
  if (iten == eti.d_enum.end())
197
  {
198
214
    ee = nm->mkSkolem("ee", tn);
199
214
    eti.d_enum[erole] = ee;
200
428
    Trace("sygus-unif-debug")
201
428
        << "...enumerator " << ee << " for " << tn.getDType().getName()
202
214
        << ", role = " << erole << std::endl;
203
  }
204
  else
205
  {
206
30
    ee = iten->second;
207
  }
208
209
  // roles that we do not recurse on
210
244
  if (nrole == role_ite_condition)
211
  {
212
58
    Trace("sygus-unif-debug") << "...this register (non-io)" << std::endl;
213
58
    registerStrategyPoint(ee, tn, erole, true);
214
58
    return;
215
  }
216
217
  // look at information on how we will construct solutions for this type
218
  // we know this is a sygus datatype since it is either the top-level type
219
  // in the strategy graph, or was recursed by a strategy we inferred.
220
186
  Assert(tn.isDatatype());
221
186
  const DType& dt = tn.getDType();
222
186
  Assert(dt.isSygus());
223
224
372
  std::map<Node, std::vector<StrategyType> > cop_to_strat;
225
372
  std::map<Node, unsigned> cop_to_cindex;
226
372
  std::map<Node, std::map<unsigned, Node> > cop_to_child_templ;
227
372
  std::map<Node, std::map<unsigned, Node> > cop_to_child_templ_arg;
228
372
  std::map<Node, std::vector<unsigned> > cop_to_carg_list;
229
372
  std::map<Node, std::vector<TypeNode> > cop_to_child_types;
230
372
  std::map<Node, std::vector<Node> > cop_to_sks;
231
232
  // whether we will enumerate the current type
233
186
  bool search_this = false;
234
1359
  for (unsigned j = 0, ncons = dt.getNumConstructors(); j < ncons; j++)
235
  {
236
2346
    Node cop = dt[j].getConstructor();
237
2346
    Node op = dt[j].getSygusOp();
238
2346
    Trace("sygus-unif-debug") << "--- Infer strategy from " << cop
239
1173
                              << " with sygus op " << op << "..." << std::endl;
240
241
    // expand the evaluation to see if this constuctor induces a strategy
242
2346
    std::vector<Node> utchildren;
243
1173
    utchildren.push_back(cop);
244
2346
    std::vector<Node> sks;
245
2346
    std::vector<TypeNode> sktns;
246
2204
    for (unsigned k = 0, nargs = dt[j].getNumArgs(); k < nargs; k++)
247
    {
248
2062
      TypeNode ttn = dt[j][k].getRangeType();
249
2062
      Node kv = nm->mkSkolem("ut", ttn);
250
1031
      sks.push_back(kv);
251
1031
      cop_to_sks[cop].push_back(kv);
252
1031
      sktns.push_back(ttn);
253
1031
      utchildren.push_back(kv);
254
    }
255
2346
    Node ut = nm->mkNode(APPLY_CONSTRUCTOR, utchildren);
256
2346
    std::vector<Node> echildren;
257
1173
    echildren.push_back(ut);
258
2346
    Node sbvl = dt.getSygusVarList();
259
6259
    for (const Node& sbv : sbvl)
260
    {
261
5086
      echildren.push_back(sbv);
262
    }
263
2346
    Node eut = nm->mkNode(DT_SYGUS_EVAL, echildren);
264
2346
    Trace("sygus-unif-debug2") << "  Test evaluation of " << eut << "..."
265
1173
                               << std::endl;
266
1173
    eut = d_qe->getTermDatabaseSygus()->getEvalUnfold()->unfold(eut);
267
1173
    Trace("sygus-unif-debug2") << "  ...got " << eut;
268
1173
    Trace("sygus-unif-debug2") << ", type : " << eut.getType() << std::endl;
269
270
    // candidate strategy
271
1173
    if (eut.getKind() == ITE)
272
    {
273
62
      cop_to_strat[cop].push_back(strat_ITE);
274
    }
275
1111
    else if (eut.getKind() == STRING_CONCAT)
276
    {
277
76
      if (nrole != role_string_suffix)
278
      {
279
50
        cop_to_strat[cop].push_back(strat_CONCAT_PREFIX);
280
      }
281
76
      if (nrole != role_string_prefix)
282
      {
283
52
        cop_to_strat[cop].push_back(strat_CONCAT_SUFFIX);
284
      }
285
    }
286
1035
    else if (dt[j].isSygusIdFunc())
287
    {
288
5
      cop_to_strat[cop].push_back(strat_ID);
289
    }
290
291
    // the kinds for which there is a strategy
292
1173
    if (cop_to_strat.find(cop) != cop_to_strat.end())
293
    {
294
      // infer an injection from the arguments of the datatype
295
286
      std::map<unsigned, unsigned> templ_injection;
296
286
      std::vector<Node> vs;
297
286
      std::vector<Node> ss;
298
286
      std::map<Node, unsigned> templ_var_index;
299
510
      for (unsigned k = 0, sksize = sks.size(); k < sksize; k++)
300
      {
301
367
        Assert(sks[k].getType().isDatatype());
302
367
        echildren[0] = sks[k];
303
734
        Trace("sygus-unif-debug2") << "...set eval dt to " << sks[k]
304
367
                                   << std::endl;
305
734
        Node esk = nm->mkNode(DT_SYGUS_EVAL, echildren);
306
367
        vs.push_back(esk);
307
734
        Node tvar = nm->mkSkolem("templ", esk.getType());
308
367
        templ_var_index[tvar] = k;
309
734
        Trace("sygus-unif-debug2") << "* template inference : looking for "
310
367
                                   << tvar << " for arg " << k << std::endl;
311
367
        ss.push_back(tvar);
312
734
        Trace("sygus-unif-debug2") << "* substitute : " << esk << " -> " << tvar
313
367
                                   << std::endl;
314
      }
315
143
      eut = eut.substitute(vs.begin(), vs.end(), ss.begin(), ss.end());
316
286
      Trace("sygus-unif-debug2") << "Constructor " << j << ", base term is "
317
143
                                 << eut << std::endl;
318
286
      std::map<unsigned, Node> test_args;
319
143
      if (dt[j].isSygusIdFunc())
320
      {
321
5
        test_args[0] = eut;
322
      }
323
      else
324
      {
325
500
        for (unsigned k = 0, size = eut.getNumChildren(); k < size; k++)
326
        {
327
362
          test_args[k] = eut[k];
328
        }
329
      }
330
331
      // TODO : prefix grouping prefix/suffix
332
143
      bool isAssoc = TermUtil::isAssoc(eut.getKind());
333
286
      Trace("sygus-unif-debug2") << eut.getKind() << " isAssoc = " << isAssoc
334
143
                                 << std::endl;
335
286
      std::map<unsigned, std::vector<unsigned> > assoc_combine;
336
286
      std::vector<unsigned> assoc_waiting;
337
143
      int assoc_last_valid_index = -1;
338
510
      for (std::pair<const unsigned, Node>& ta : test_args)
339
      {
340
367
        unsigned k = ta.first;
341
734
        Node eut_c = ta.second;
342
        // success if we can find a injection from args to sygus args
343
367
        if (!inferTemplate(k, eut_c, templ_var_index, templ_injection))
344
        {
345
          Trace("sygus-unif-debug")
346
              << "...fail: could not find injection (range)." << std::endl;
347
          cop_to_strat.erase(cop);
348
          break;
349
        }
350
367
        std::map<unsigned, unsigned>::iterator itti = templ_injection.find(k);
351
367
        if (itti != templ_injection.end())
352
        {
353
          // if associative, combine arguments if it is the same variable
354
507
          if (isAssoc && assoc_last_valid_index >= 0
355
431
              && itti->second == templ_injection[assoc_last_valid_index])
356
          {
357
            templ_injection.erase(k);
358
            assoc_combine[assoc_last_valid_index].push_back(k);
359
          }
360
          else
361
          {
362
349
            assoc_last_valid_index = (int)k;
363
349
            if (!assoc_waiting.empty())
364
            {
365
              assoc_combine[k].insert(assoc_combine[k].end(),
366
                                      assoc_waiting.begin(),
367
                                      assoc_waiting.end());
368
              assoc_waiting.clear();
369
            }
370
349
            assoc_combine[k].push_back(k);
371
          }
372
        }
373
        else
374
        {
375
          // a ground argument
376
18
          if (!isAssoc)
377
          {
378
            Trace("sygus-unif-debug")
379
                << "...fail: could not find injection (functional)."
380
                << std::endl;
381
            cop_to_strat.erase(cop);
382
            break;
383
          }
384
          else
385
          {
386
18
            if (assoc_last_valid_index >= 0)
387
            {
388
18
              assoc_combine[assoc_last_valid_index].push_back(k);
389
            }
390
            else
391
            {
392
              assoc_waiting.push_back(k);
393
            }
394
          }
395
        }
396
      }
397
143
      if (cop_to_strat.find(cop) != cop_to_strat.end())
398
      {
399
        // construct the templates
400
143
        if (!assoc_waiting.empty())
401
        {
402
          // could not find a way to fit some arguments into injection
403
          cop_to_strat.erase(cop);
404
        }
405
        else
406
        {
407
510
          for (std::pair<const unsigned, Node>& ta : test_args)
408
          {
409
367
            unsigned k = ta.first;
410
734
            Trace("sygus-unif-debug2") << "- processing argument " << k << "..."
411
367
                                       << std::endl;
412
367
            if (templ_injection.find(k) != templ_injection.end())
413
            {
414
349
              unsigned sk_index = templ_injection[k];
415
1396
              if (std::find(cop_to_carg_list[cop].begin(),
416
349
                            cop_to_carg_list[cop].end(),
417
698
                            sk_index)
418
1047
                  == cop_to_carg_list[cop].end())
419
              {
420
349
                cop_to_carg_list[cop].push_back(sk_index);
421
              }
422
              else
423
              {
424
                Trace("sygus-unif-debug") << "...fail: duplicate argument used"
425
                                          << std::endl;
426
                cop_to_strat.erase(cop);
427
                break;
428
              }
429
              // also store the template information, if necessary
430
698
              Node teut;
431
349
              if (isAssoc)
432
              {
433
158
                std::vector<unsigned>& ac = assoc_combine[k];
434
158
                Assert(!ac.empty());
435
316
                std::vector<Node> children;
436
334
                for (unsigned ack = 0, size_ac = ac.size(); ack < size_ac;
437
                     ack++)
438
                {
439
176
                  children.push_back(eut[ac[ack]]);
440
                }
441
316
                teut = children.size() == 1
442
316
                           ? children[0]
443
                           : nm->mkNode(eut.getKind(), children);
444
158
                teut = Rewriter::rewrite(teut);
445
              }
446
              else
447
              {
448
191
                teut = ta.second;
449
              }
450
451
349
              if (!teut.isVar())
452
              {
453
37
                cop_to_child_templ[cop][k] = teut;
454
37
                cop_to_child_templ_arg[cop][k] = ss[sk_index];
455
74
                Trace("sygus-unif-debug")
456
37
                    << "  Arg " << k << " (template : " << teut << " arg "
457
37
                    << ss[sk_index] << "), index " << sk_index << std::endl;
458
              }
459
              else
460
              {
461
624
                Trace("sygus-unif-debug") << "  Arg " << k << ", index "
462
312
                                          << sk_index << std::endl;
463
312
                Assert(teut == ss[sk_index]);
464
              }
465
            }
466
            else
467
            {
468
18
              Assert(isAssoc);
469
            }
470
          }
471
        }
472
      }
473
    }
474
475
1173
    std::map<Node, std::vector<StrategyType> >::iterator itcs = cop_to_strat.find(cop);
476
1173
    if (itcs != cop_to_strat.end())
477
    {
478
286
      Trace("sygus-unif") << "-> constructor " << cop
479
286
                          << " matches strategy for " << eut.getKind() << "..."
480
143
                          << std::endl;
481
      // collect children types
482
492
      for (unsigned k = 0, size = cop_to_carg_list[cop].size(); k < size; k++)
483
      {
484
698
        TypeNode ctn = sktns[cop_to_carg_list[cop][k]];
485
698
        Trace("sygus-unif-debug") << "   Child type " << k << " : "
486
349
                                  << ctn.getDType().getName() << std::endl;
487
349
        cop_to_child_types[cop].push_back(ctn);
488
      }
489
      // if there are checks on the consistency of child types wrt strategies,
490
      // these should be enforced here. We currently have none.
491
    }
492
1173
    if (cop_to_strat.find(cop) == cop_to_strat.end())
493
    {
494
2060
      Trace("sygus-unif") << "...constructor " << cop
495
1030
                          << " does not correspond to a strategy." << std::endl;
496
1030
      search_this = true;
497
    }
498
  }
499
500
  // check whether we should also enumerate the current type
501
186
  Trace("sygus-unif-debug2") << "  register this strategy ..." << std::endl;
502
186
  registerStrategyPoint(ee, tn, erole, search_this);
503
504
186
  if (cop_to_strat.empty())
505
  {
506
108
    Trace("sygus-unif") << "...consider " << dt.getName() << " a basic type"
507
54
                        << std::endl;
508
  }
509
  else
510
  {
511
275
    for (std::pair<const Node, std::vector<StrategyType> >& cstr : cop_to_strat)
512
    {
513
286
      Node cop = cstr.first;
514
286
      Trace("sygus-unif-debug") << "Constructor " << cop << " has "
515
286
                                << cstr.second.size() << " strategies..."
516
143
                                << std::endl;
517
312
      for (unsigned s = 0, ssize = cstr.second.size(); s < ssize; s++)
518
      {
519
169
        EnumTypeInfoStrat* cons_strat = new EnumTypeInfoStrat;
520
169
        StrategyType strat = cstr.second[s];
521
522
169
        cons_strat->d_this = strat;
523
169
        cons_strat->d_cons = cop;
524
338
        Trace("sygus-unif-debug") << "Process strategy #" << s
525
169
                                  << " for operator : " << cop << " : " << strat
526
169
                                  << std::endl;
527
169
        Assert(cop_to_child_types.find(cop) != cop_to_child_types.end());
528
169
        std::vector<TypeNode>& childTypes = cop_to_child_types[cop];
529
169
        Assert(cop_to_carg_list.find(cop) != cop_to_carg_list.end());
530
169
        std::vector<unsigned>& cargList = cop_to_carg_list[cop];
531
532
338
        std::vector<Node> sol_templ_children;
533
169
        sol_templ_children.resize(cop_to_sks[cop].size());
534
535
572
        for (unsigned j = 0, csize = childTypes.size(); j < csize; j++)
536
        {
537
          // calculate if we should allocate a new enumerator : should be true
538
          // if we have a new role
539
403
          NodeRole nrole_c = nrole;
540
403
          if (strat == strat_ITE)
541
          {
542
186
            if (j == 0)
543
            {
544
62
              nrole_c = role_ite_condition;
545
            }
546
          }
547
217
          else if (strat == strat_CONCAT_PREFIX)
548
          {
549
104
            if ((j + 1) < childTypes.size())
550
            {
551
54
              nrole_c = role_string_prefix;
552
            }
553
          }
554
113
          else if (strat == strat_CONCAT_SUFFIX)
555
          {
556
108
            if (j > 0)
557
            {
558
56
              nrole_c = role_string_suffix;
559
            }
560
          }
561
          // in all other cases, role is same as parent
562
563
          // register the child type
564
806
          TypeNode ct = childTypes[j];
565
806
          Node csk = cop_to_sks[cop][cargList[j]];
566
403
          cons_strat->d_sol_templ_args.push_back(csk);
567
403
          sol_templ_children[cargList[j]] = csk;
568
569
403
          EnumRole erole_c = getEnumeratorRoleForNodeRole(nrole_c);
570
          // make the enumerator
571
806
          Node et;
572
          // Build the strategy recursively, regardless of whether the
573
          // enumerator is templated.
574
403
          buildStrategyGraph(ct, nrole_c);
575
403
          if (cop_to_child_templ[cop].find(j) != cop_to_child_templ[cop].end())
576
          {
577
            // it is templated, allocate a fresh variable
578
43
            et = nm->mkSkolem("et", ct);
579
86
            Trace("sygus-unif-debug") << "...enumerate " << et << " of type "
580
43
                                      << ct.getDType().getName();
581
86
            Trace("sygus-unif-debug") << " for arg " << j << " of "
582
43
                                      << tn.getDType().getName() << std::endl;
583
43
            registerStrategyPoint(et, ct, erole_c, true);
584
43
            d_einfo[et].d_template = cop_to_child_templ[cop][j];
585
43
            d_einfo[et].d_template_arg = cop_to_child_templ_arg[cop][j];
586
43
            Assert(!d_einfo[et].d_template.isNull());
587
43
            Assert(!d_einfo[et].d_template_arg.isNull());
588
          }
589
          else
590
          {
591
720
            Trace("sygus-unif-debug")
592
720
                << "...child type enumerate " << ct.getDType().getName()
593
360
                << ", node role = " << nrole_c << std::endl;
594
            // otherwise use the previous
595
360
            Assert(d_tinfo[ct].d_enum.find(erole_c)
596
                   != d_tinfo[ct].d_enum.end());
597
360
            et = d_tinfo[ct].d_enum[erole_c];
598
          }
599
806
          Trace("sygus-unif-debug") << "Register child enumerator " << et
600
403
                                    << ", arg " << j << " of " << cop
601
403
                                    << ", role = " << erole_c << std::endl;
602
403
          Assert(!et.isNull());
603
403
          cons_strat->d_cenum.push_back(std::pair<Node, NodeRole>(et, nrole_c));
604
        }
605
        // children that are unused in the strategy can be arbitrary
606
596
        for (unsigned j = 0, stsize = sol_templ_children.size(); j < stsize;
607
             j++)
608
        {
609
427
          if (sol_templ_children[j].isNull())
610
          {
611
24
            sol_templ_children[j] = cop_to_sks[cop][j].getType().mkGroundTerm();
612
          }
613
        }
614
169
        sol_templ_children.insert(sol_templ_children.begin(), cop);
615
338
        cons_strat->d_sol_templ =
616
507
            nm->mkNode(APPLY_CONSTRUCTOR, sol_templ_children);
617
169
        if (strat == strat_CONCAT_SUFFIX)
618
        {
619
52
          std::reverse(cons_strat->d_cenum.begin(), cons_strat->d_cenum.end());
620
52
          std::reverse(cons_strat->d_sol_templ_args.begin(),
621
52
                       cons_strat->d_sol_templ_args.end());
622
        }
623
169
        if (Trace.isOn("sygus-unif"))
624
        {
625
          Trace("sygus-unif") << "Initialized strategy " << strat;
626
          Trace("sygus-unif")
627
              << " for " << tn.getDType().getName() << ", operator " << cop;
628
          Trace("sygus-unif") << ", #children = " << cons_strat->d_cenum.size()
629
                              << ", solution template = (lambda ( ";
630
          for (const Node& targ : cons_strat->d_sol_templ_args)
631
          {
632
            Trace("sygus-unif") << targ << " ";
633
          }
634
          Trace("sygus-unif") << ") " << cons_strat->d_sol_templ << ")";
635
          Trace("sygus-unif") << std::endl;
636
        }
637
        // make the strategy
638
169
        snode.d_strats.push_back(cons_strat);
639
      }
640
    }
641
  }
642
}
643
644
413
bool SygusUnifStrategy::inferTemplate(
645
    unsigned k,
646
    Node n,
647
    std::map<Node, unsigned>& templ_var_index,
648
    std::map<unsigned, unsigned>& templ_injection)
649
{
650
413
  if (n.getNumChildren() == 0)
651
  {
652
390
    std::map<Node, unsigned>::iterator itt = templ_var_index.find(n);
653
390
    if (itt != templ_var_index.end())
654
    {
655
349
      unsigned kk = itt->second;
656
349
      std::map<unsigned, unsigned>::iterator itti = templ_injection.find(k);
657
349
      if (itti == templ_injection.end())
658
      {
659
698
        Trace("sygus-unif-debug") << "...set template injection " << k << " -> "
660
349
                                  << kk << std::endl;
661
349
        templ_injection[k] = kk;
662
      }
663
      else if (itti->second != kk)
664
      {
665
        // two distinct variables in this term, we fail
666
        return false;
667
      }
668
    }
669
390
    return true;
670
  }
671
  else
672
  {
673
69
    for (unsigned i = 0; i < n.getNumChildren(); i++)
674
    {
675
46
      if (!inferTemplate(k, n[i], templ_var_index, templ_injection))
676
      {
677
        return false;
678
      }
679
    }
680
  }
681
23
  return true;
682
}
683
684
84
void SygusUnifStrategy::staticLearnRedundantOps(
685
    std::map<Node, std::vector<Node>>& strategy_lemmas)
686
{
687
168
  StrategyRestrictions restrictions;
688
84
  staticLearnRedundantOps(strategy_lemmas, restrictions);
689
84
}
690
691
107
void SygusUnifStrategy::staticLearnRedundantOps(
692
    std::map<Node, std::vector<Node>>& strategy_lemmas,
693
    StrategyRestrictions& restrictions)
694
{
695
253
  for (unsigned i = 0; i < d_esym_list.size(); i++)
696
  {
697
292
    Node e = d_esym_list[i];
698
146
    std::map<Node, EnumInfo>::iterator itn = d_einfo.find(e);
699
146
    Assert(itn != d_einfo.end());
700
    // see if there is anything we can eliminate
701
292
    Trace("sygus-unif") << "* Search enumerator #" << i << " : type "
702
146
                        << e.getType().getDType().getName() << " : ";
703
292
    Trace("sygus-unif") << e << " has " << itn->second.d_enum_slave.size()
704
146
                        << " slaves:" << std::endl;
705
396
    for (unsigned j = 0; j < itn->second.d_enum_slave.size(); j++)
706
    {
707
500
      Node es = itn->second.d_enum_slave[j];
708
250
      std::map<Node, EnumInfo>::iterator itns = d_einfo.find(es);
709
250
      Assert(itns != d_einfo.end());
710
500
      Trace("sygus-unif") << "  " << es << ", role = " << itns->second.getRole()
711
250
                          << std::endl;
712
    }
713
  }
714
107
  Trace("sygus-unif") << std::endl;
715
214
  Trace("sygus-unif") << "Strategy for candidate " << d_candidate
716
107
                      << " is : " << std::endl;
717
107
  debugPrint("sygus-unif");
718
214
  std::map<Node, std::map<NodeRole, bool> > visited;
719
214
  std::map<Node, std::map<unsigned, bool> > needs_cons;
720
107
  staticLearnRedundantOps(
721
214
      getRootEnumerator(), role_equal, visited, needs_cons, restrictions);
722
  // now, check the needs_cons map
723
299
  for (std::pair<const Node, std::map<unsigned, bool> >& nce : needs_cons)
724
  {
725
384
    Node em = nce.first;
726
192
    const DType& dt = em.getType().getDType();
727
384
    std::vector<Node> lemmas;
728
1482
    for (std::pair<const unsigned, bool>& nc : nce.second)
729
    {
730
1290
      Assert(nc.first < dt.getNumConstructors());
731
1290
      if (!nc.second)
732
      {
733
690
        Node tst = datatypes::utils::mkTester(em, nc.first, dt).negate();
734
735
345
        if (std::find(lemmas.begin(), lemmas.end(), tst) == lemmas.end())
736
        {
737
690
          Trace("sygus-unif") << "...can exclude based on  : " << tst
738
345
                              << std::endl;
739
345
          lemmas.push_back(tst);
740
        }
741
      }
742
    }
743
192
    if (!lemmas.empty())
744
    {
745
128
      strategy_lemmas[em] = lemmas;
746
    }
747
  }
748
107
}
749
750
107
void SygusUnifStrategy::debugPrint(const char* c)
751
{
752
107
  if (Trace.isOn(c))
753
  {
754
    std::map<Node, std::map<NodeRole, bool> > visited;
755
    debugPrint(c, getRootEnumerator(), role_equal, visited, 0);
756
  }
757
107
}
758
759
495
void SygusUnifStrategy::staticLearnRedundantOps(
760
    Node e,
761
    NodeRole nrole,
762
    std::map<Node, std::map<NodeRole, bool>>& visited,
763
    std::map<Node, std::map<unsigned, bool>>& needs_cons,
764
    StrategyRestrictions& restrictions)
765
{
766
495
  if (visited[e].find(nrole) != visited[e].end())
767
  {
768
521
    return;
769
  }
770
506
  Trace("sygus-strat-slearn") << "Learn redundant operators " << e << " "
771
253
                              << nrole << "..." << std::endl;
772
253
  visited[e][nrole] = true;
773
253
  EnumInfo& ei = getEnumInfo(e);
774
253
  if (ei.isTemplated())
775
  {
776
37
    return;
777
  }
778
432
  TypeNode etn = e.getType();
779
216
  EnumTypeInfo& tinfo = getEnumTypeInfo(etn);
780
216
  StrategyNode& snode = tinfo.getStrategyNode(nrole);
781
  // the constructors of the current strategy point we need
782
432
  std::map<unsigned, bool> needs_cons_curr;
783
  // get the unused strategies
784
  std::map<Node, std::unordered_set<unsigned>>::iterator itus =
785
216
      restrictions.d_unused_strategies.find(e);
786
432
  std::unordered_set<unsigned> unused_strats;
787
216
  if (itus != restrictions.d_unused_strategies.end())
788
  {
789
3
    unused_strats.insert(itus->second.begin(), itus->second.end());
790
  }
791
379
  for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
792
  {
793
    // if we are not using this strategy, there is nothing to do
794
163
    if (unused_strats.find(j) != unused_strats.end())
795
    {
796
3
      continue;
797
    }
798
160
    EnumTypeInfoStrat* etis = snode.d_strats[j];
799
160
    unsigned cindex = datatypes::utils::indexOf(etis->d_cons);
800
    // constructors that correspond to strategies are not needed
801
    // the intuition is that the strategy itself is responsible for constructing
802
    // all terms that use the given constructor
803
320
    Trace("sygus-strat-slearn") << "...by strategy, can exclude operator "
804
160
                                << etis->d_cons << std::endl;
805
160
    needs_cons_curr[cindex] = false;
806
    // try to eliminate from etn's datatype all operators except TRUE/FALSE if
807
    // arguments of ITE are the same BOOL type
808
160
    if (restrictions.d_iteReturnBoolConst)
809
    {
810
17
      const DType& dt = etn.getDType();
811
34
      Node op = dt[cindex].getSygusOp();
812
34
      TypeNode sygus_tn = dt.getSygusType();
813
34
      if (op.getKind() == kind::BUILTIN
814
27
          && NodeManager::operatorToKind(op) == ITE && sygus_tn.isBoolean()
815
42
          && (dt[cindex].getArgType(1) == dt[cindex].getArgType(2)))
816
      {
817
8
        unsigned ncons = dt.getNumConstructors(), indexT = ncons,
818
8
                 indexF = ncons;
819
183
        for (unsigned k = 0; k < ncons; ++k)
820
        {
821
191
          Node op_arg = dt[k].getSygusOp();
822
175
          if (dt[k].getNumArgs() > 0 || !op_arg.isConst())
823
          {
824
159
            continue;
825
          }
826
16
          if (op_arg.getConst<bool>())
827
          {
828
8
            indexT = k;
829
          }
830
          else
831
          {
832
8
            indexF = k;
833
          }
834
        }
835
8
        if (indexT < ncons && indexF < ncons)
836
        {
837
16
          Trace("sygus-strat-slearn")
838
8
              << "...for ite boolean arg, can exclude all operators but T/F\n";
839
183
          for (unsigned k = 0; k < ncons; ++k)
840
          {
841
175
            needs_cons_curr[k] = false;
842
          }
843
8
          needs_cons_curr[indexT] = true;
844
8
          needs_cons_curr[indexF] = true;
845
        }
846
      }
847
    }
848
548
    for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
849
    {
850
388
      staticLearnRedundantOps(
851
          cec.first, cec.second, visited, needs_cons, restrictions);
852
    }
853
  }
854
  // get the current datatype
855
216
  const DType& dt = etn.getDType();
856
  // do not use recursive Boolean connectives for conditions of ITEs
857
216
  if (nrole == role_ite_condition && restrictions.d_iteCondOnlyAtoms)
858
  {
859
78
    TypeNode sygus_tn = dt.getSygusType();
860
323
    for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++)
861
    {
862
528
      Node op = dt[j].getSygusOp();
863
568
      Trace("sygus-strat-slearn")
864
284
          << "...for ite condition, look at operator : " << op << std::endl;
865
324
      if (op.isConst() && dt[j].getNumArgs() == 0)
866
      {
867
80
        Trace("sygus-strat-slearn")
868
40
            << "...for ite condition, can exclude Boolean constant " << op
869
40
            << std::endl;
870
40
        needs_cons_curr[j] = false;
871
40
        continue;
872
      }
873
244
      if (op.getKind() == kind::BUILTIN)
874
      {
875
60
        Kind kind = NodeManager::operatorToKind(op);
876
60
        if (kind == NOT || kind == OR || kind == AND || kind == ITE)
877
        {
878
          // can eliminate if their argument types are simple loops to this type
879
40
          bool type_ok = true;
880
120
          for (unsigned k = 0, nargs = dt[j].getNumArgs(); k < nargs; k++)
881
          {
882
160
            TypeNode tn = dt[j].getArgType(k);
883
80
            if (tn != etn)
884
            {
885
              type_ok = false;
886
              break;
887
            }
888
          }
889
40
          if (type_ok)
890
          {
891
80
            Trace("sygus-strat-slearn")
892
40
                << "...for ite condition, can exclude Boolean connective : "
893
40
                << op << std::endl;
894
40
            needs_cons_curr[j] = false;
895
          }
896
        }
897
      }
898
    }
899
  }
900
  // all other constructors are needed
901
1628
  for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++)
902
  {
903
1412
    if (needs_cons_curr.find(j) == needs_cons_curr.end())
904
    {
905
1031
      needs_cons_curr[j] = true;
906
    }
907
  }
908
  // update the constructors that the master enumerator needs
909
216
  if (needs_cons.find(e) == needs_cons.end())
910
  {
911
192
    needs_cons[e] = needs_cons_curr;
912
  }
913
  else
914
  {
915
146
    for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++)
916
    {
917
122
      needs_cons[e][j] = needs_cons[e][j] || needs_cons_curr[j];
918
    }
919
  }
920
}
921
922
687
void SygusUnifStrategy::finishInit(
923
    Node e,
924
    NodeRole nrole,
925
    std::map<Node, std::map<NodeRole, bool> >& visited,
926
    bool isCond)
927
{
928
687
  EnumInfo& ei = getEnumInfo(e);
929
1374
  if (visited[e].find(nrole) != visited[e].end()
930
687
      && (!isCond || ei.isConditional()))
931
  {
932
781
    return;
933
  }
934
315
  visited[e][nrole] = true;
935
  // set conditional
936
315
  if (isCond)
937
  {
938
121
    ei.setConditional();
939
  }
940
315
  if (ei.isTemplated())
941
  {
942
37
    return;
943
  }
944
556
  TypeNode etn = e.getType();
945
278
  EnumTypeInfo& tinfo = getEnumTypeInfo(etn);
946
278
  StrategyNode& snode = tinfo.getStrategyNode(nrole);
947
508
  for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
948
  {
949
230
    EnumTypeInfoStrat* etis = snode.d_strats[j];
950
230
    StrategyType strat = etis->d_this;
951
230
    bool newIsCond = isCond || strat == strat_ITE;
952
810
    for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
953
    {
954
580
      finishInit(cec.first, cec.second, visited, newIsCond);
955
    }
956
  }
957
}
958
959
void SygusUnifStrategy::debugPrint(
960
    const char* c,
961
    Node e,
962
    NodeRole nrole,
963
    std::map<Node, std::map<NodeRole, bool> >& visited,
964
    int ind)
965
{
966
  if (visited[e].find(nrole) != visited[e].end())
967
  {
968
    indent(c, ind);
969
    Trace(c) << e << " :: node role : " << nrole << std::endl;
970
    return;
971
  }
972
  visited[e][nrole] = true;
973
  EnumInfo& ei = getEnumInfo(e);
974
975
  TypeNode etn = e.getType();
976
977
  indent(c, ind);
978
  Trace(c) << e << " :: node role : " << nrole;
979
  Trace(c) << ", type : " << etn.getDType().getName();
980
  if (ei.isConditional())
981
  {
982
    Trace(c) << ", conditional";
983
  }
984
  Trace(c) << ", enum role : " << ei.getRole();
985
986
  if (ei.isTemplated())
987
  {
988
    Trace(c) << ", templated : (lambda " << ei.d_template_arg << " "
989
             << ei.d_template << ")" << std::endl;
990
    return;
991
  }
992
  Trace(c) << std::endl;
993
994
  EnumTypeInfo& tinfo = getEnumTypeInfo(etn);
995
  StrategyNode& snode = tinfo.getStrategyNode(nrole);
996
  for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
997
  {
998
    EnumTypeInfoStrat* etis = snode.d_strats[j];
999
    StrategyType strat = etis->d_this;
1000
    indent(c, ind + 1);
1001
    Trace(c) << "Strategy : " << strat << ", from cons : " << etis->d_cons
1002
             << std::endl;
1003
    for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
1004
    {
1005
      // recurse
1006
      debugPrint(c, cec.first, cec.second, visited, ind + 2);
1007
    }
1008
  }
1009
}
1010
1011
257
/** Initialize this enumerator information by storing its role. */
void EnumInfo::initialize(EnumRole role)
{
  d_role = role;
}
1012
1013
1865
/**
 * Get the strategy node registered for node role nrole.
 *
 * An entry for nrole must exist in d_snodes; this is asserted.
 */
StrategyNode& EnumTypeInfo::getStrategyNode(NodeRole nrole)
{
  auto pos = d_snodes.find(nrole);
  Assert(pos != d_snodes.end());
  StrategyNode& found = pos->second;
  return found;
}
1019
1020
5773
/**
 * Return whether this strategy can be used in unification context x.
 *
 * The strategy is rejected when the current role of x is
 * role_string_prefix and this is a strat_CONCAT_SUFFIX strategy, or when
 * the current role is role_string_suffix and this is a
 * strat_CONCAT_PREFIX strategy. It is accepted in every other case.
 */
bool EnumTypeInfoStrat::isValid(UnifContext& x)
{
  // solving for a prefix: a suffix-concatenation strategy does not apply
  if (x.getCurrentRole() == role_string_prefix
      && d_this == strat_CONCAT_SUFFIX)
  {
    return false;
  }
  // solving for a suffix: a prefix-concatenation strategy does not apply
  if (x.getCurrentRole() == role_string_suffix
      && d_this == strat_CONCAT_PREFIX)
  {
    return false;
  }
  return true;
}
1031
1032
488
/** Delete every strategy stored in d_strats, then empty the vector. */
StrategyNode::~StrategyNode()
{
  for (EnumTypeInfoStrat* strategy : d_strats)
  {
    delete strategy;
  }
  d_strats.clear();
}
1040
1041
void SygusUnifStrategy::indent(const char* c, int ind)
1042
{
1043
  if (Trace.isOn(c))
1044
  {
1045
    for (int i = 0; i < ind; i++)
1046
    {
1047
      Trace(c) << "  ";
1048
    }
1049
  }
1050
}
1051
1052
} /* CVC4::theory::quantifiers namespace */
1053
} /* CVC4::theory namespace */
1054
26676
} /* CVC4 namespace */