xref: /llvm-project/mlir/lib/TableGen/Predicate.cpp (revision d8399d5dd6a5a7025621eddd97fc0fa1f494bad8)
1 //===- Predicate.cpp - Predicate class ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Wrapper around predicates defined in TableGen.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Predicate.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/Record.h"
20 
21 using namespace mlir;
22 using namespace tblgen;
23 using llvm::Init;
24 using llvm::Record;
25 using llvm::SpecificBumpPtrAllocator;
26 
27 // Construct a Predicate from a record.
28 Pred::Pred(const Record *record) : def(record) {
29   assert(def->isSubClassOf("Pred") &&
30          "must be a subclass of TableGen 'Pred' class");
31 }
32 
33 // Construct a Predicate from an initializer.
34 Pred::Pred(const Init *init) {
35   if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
36     def = defInit->getDef();
37 }
38 
39 std::string Pred::getCondition() const {
40   // Static dispatch to subclasses.
41   if (def->isSubClassOf("CombinedPred"))
42     return static_cast<const CombinedPred *>(this)->getConditionImpl();
43   if (def->isSubClassOf("CPred"))
44     return static_cast<const CPred *>(this)->getConditionImpl();
45   llvm_unreachable("Pred::getCondition must be overridden in subclasses");
46 }
47 
48 bool Pred::isCombined() const {
49   return def && def->isSubClassOf("CombinedPred");
50 }
51 
52 ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); }
53 
54 CPred::CPred(const Record *record) : Pred(record) {
55   assert(def->isSubClassOf("CPred") &&
56          "must be a subclass of Tablegen 'CPred' class");
57 }
58 
59 CPred::CPred(const Init *init) : Pred(init) {
60   assert((!def || def->isSubClassOf("CPred")) &&
61          "must be a subclass of Tablegen 'CPred' class");
62 }
63 
64 // Get condition of the C Predicate.
65 std::string CPred::getConditionImpl() const {
66   assert(!isNull() && "null predicate does not have a condition");
67   return std::string(def->getValueAsString("predExpr"));
68 }
69 
70 CombinedPred::CombinedPred(const Record *record) : Pred(record) {
71   assert(def->isSubClassOf("CombinedPred") &&
72          "must be a subclass of Tablegen 'CombinedPred' class");
73 }
74 
75 CombinedPred::CombinedPred(const Init *init) : Pred(init) {
76   assert((!def || def->isSubClassOf("CombinedPred")) &&
77          "must be a subclass of Tablegen 'CombinedPred' class");
78 }
79 
80 const Record *CombinedPred::getCombinerDef() const {
81   assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
82   return def->getValueAsDef("kind");
83 }
84 
85 std::vector<const Record *> CombinedPred::getChildren() const {
86   assert(def->getValue("children") &&
87          "CombinedPred must have a value 'children'");
88   return def->getValueAsListOfDefs("children");
89 }
90 
91 namespace {
92 // Kinds of nodes in a logical predicate tree.
93 enum class PredCombinerKind {
94   Leaf,
95   And,
96   Or,
97   Not,
98   SubstLeaves,
99   Concat,
100   // Special kinds that are used in simplification.
101   False,
102   True
103 };
104 
105 // A node in a logical predicate tree.
106 struct PredNode {
107   PredCombinerKind kind;
108   const Pred *predicate;
109   SmallVector<PredNode *, 4> children;
110   std::string expr;
111 
112   // Prefix and suffix are used by ConcatPred.
113   std::string prefix;
114   std::string suffix;
115 };
116 } // namespace
117 
118 // Get a predicate tree node kind based on the kind used in the predicate
119 // TableGen record.
120 static PredCombinerKind getPredCombinerKind(const Pred &pred) {
121   if (!pred.isCombined())
122     return PredCombinerKind::Leaf;
123 
124   const auto &combinedPred = static_cast<const CombinedPred &>(pred);
125   return StringSwitch<PredCombinerKind>(
126              combinedPred.getCombinerDef()->getName())
127       .Case("PredCombinerAnd", PredCombinerKind::And)
128       .Case("PredCombinerOr", PredCombinerKind::Or)
129       .Case("PredCombinerNot", PredCombinerKind::Not)
130       .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
131       .Case("PredCombinerConcat", PredCombinerKind::Concat);
132 }
133 
134 namespace {
135 // Substitution<pattern, replacement>.
136 using Subst = std::pair<StringRef, StringRef>;
137 } // namespace
138 
139 /// Perform the given substitutions on 'str' in-place.
140 static void performSubstitutions(std::string &str,
141                                  ArrayRef<Subst> substitutions) {
142   // Apply all parent substitutions from innermost to outermost.
143   for (const auto &subst : llvm::reverse(substitutions)) {
144     auto pos = str.find(std::string(subst.first));
145     while (pos != std::string::npos) {
146       str.replace(pos, subst.first.size(), std::string(subst.second));
147       // Skip the newly inserted substring, which itself may consider the
148       // pattern to match.
149       pos += subst.second.size();
150       // Find the next possible match position.
151       pos = str.find(std::string(subst.first), pos);
152     }
153   }
154 }
155 
156 // Build the predicate tree starting from the top-level predicate, which may
157 // have children, and perform leaf substitutions inplace.  Note that after
158 // substitution, nodes are still pointing to the original TableGen record.
159 // All nodes are created within "allocator".
160 static PredNode *
161 buildPredicateTree(const Pred &root,
162                    SpecificBumpPtrAllocator<PredNode> &allocator,
163                    ArrayRef<Subst> substitutions) {
164   auto *rootNode = allocator.Allocate();
165   new (rootNode) PredNode;
166   rootNode->kind = getPredCombinerKind(root);
167   rootNode->predicate = &root;
168   if (!root.isCombined()) {
169     rootNode->expr = root.getCondition();
170     performSubstitutions(rootNode->expr, substitutions);
171     return rootNode;
172   }
173 
174   // If the current combined predicate is a leaf substitution, append it to the
175   // list before continuing.
176   auto allSubstitutions = llvm::to_vector<4>(substitutions);
177   if (rootNode->kind == PredCombinerKind::SubstLeaves) {
178     const auto &substPred = static_cast<const SubstLeavesPred &>(root);
179     allSubstitutions.push_back(
180         {substPred.getPattern(), substPred.getReplacement()});
181 
182     // If the current predicate is a ConcatPred, record the prefix and suffix.
183   } else if (rootNode->kind == PredCombinerKind::Concat) {
184     const auto &concatPred = static_cast<const ConcatPred &>(root);
185     rootNode->prefix = std::string(concatPred.getPrefix());
186     performSubstitutions(rootNode->prefix, substitutions);
187     rootNode->suffix = std::string(concatPred.getSuffix());
188     performSubstitutions(rootNode->suffix, substitutions);
189   }
190 
191   // Build child subtrees.
192   auto combined = static_cast<const CombinedPred &>(root);
193   for (const auto *record : combined.getChildren()) {
194     auto *childTree =
195         buildPredicateTree(Pred(record), allocator, allSubstitutions);
196     rootNode->children.push_back(childTree);
197   }
198   return rootNode;
199 }
200 
201 // Simplify a predicate tree rooted at "node" using the predicates that are
202 // known to be true(false).  For AND(OR) combined predicates, if any of the
203 // children is known to be false(true), the result is also false(true).
204 // Furthermore, for AND(OR) combined predicates, children that are known to be
205 // true(false) don't have to be checked dynamically.
206 static PredNode *
207 propagateGroundTruth(PredNode *node,
208                      const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
209                      const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
210   // If the current predicate is known to be true or false, change the kind of
211   // the node and return immediately.
212   if (knownTruePreds.count(node->predicate) != 0) {
213     node->kind = PredCombinerKind::True;
214     node->children.clear();
215     return node;
216   }
217   if (knownFalsePreds.count(node->predicate) != 0) {
218     node->kind = PredCombinerKind::False;
219     node->children.clear();
220     return node;
221   }
222 
223   // If the current node is a substitution, stop recursion now.
224   // The expressions in the leaves below this node were rewritten, but the nodes
225   // still point to the original predicate records.  While the original
226   // predicate may be known to be true or false, it is not necessarily the case
227   // after rewriting.
228   // TODO: we can support ground truth for rewritten
229   // predicates by either (a) having our own unique'ing of the predicates
230   // instead of relying on TableGen record pointers or (b) taking ground truth
231   // values optionally prefixed with a list of substitutions to apply, e.g.
232   // "predX is true by itself as well as predSubY leaf substitution had been
233   // applied to it".
234   if (node->kind == PredCombinerKind::SubstLeaves) {
235     return node;
236   }
237 
238   if (node->kind == PredCombinerKind::And && node->children.empty()) {
239     node->kind = PredCombinerKind::True;
240     return node;
241   }
242 
243   if (node->kind == PredCombinerKind::Or && node->children.empty()) {
244     node->kind = PredCombinerKind::False;
245     return node;
246   }
247 
248   // Otherwise, look at child nodes.
249 
250   // Move child nodes into some local variable so that they can be optimized
251   // separately and re-added if necessary.
252   llvm::SmallVector<PredNode *, 4> children;
253   std::swap(node->children, children);
254 
255   for (auto &child : children) {
256     // First, simplify the child.  This maintains the predicate as it was.
257     auto *simplifiedChild =
258         propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
259 
260     // Just add the child if we don't know how to simplify the current node.
261     if (node->kind != PredCombinerKind::And &&
262         node->kind != PredCombinerKind::Or) {
263       node->children.push_back(simplifiedChild);
264       continue;
265     }
266 
267     // Second, based on the type define which known values of child predicates
268     // immediately collapse this predicate to a known value, and which others
269     // may be safely ignored.
270     //   OR(..., True, ...) = True
271     //   OR(..., False, ...) = OR(..., ...)
272     //   AND(..., False, ...) = False
273     //   AND(..., True, ...) = AND(..., ...)
274     auto collapseKind = node->kind == PredCombinerKind::And
275                             ? PredCombinerKind::False
276                             : PredCombinerKind::True;
277     auto eraseKind = node->kind == PredCombinerKind::And
278                          ? PredCombinerKind::True
279                          : PredCombinerKind::False;
280     const auto &collapseList =
281         node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
282     const auto &eraseList =
283         node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
284     if (simplifiedChild->kind == collapseKind ||
285         collapseList.count(simplifiedChild->predicate) != 0) {
286       node->kind = collapseKind;
287       node->children.clear();
288       return node;
289     }
290     if (simplifiedChild->kind == eraseKind ||
291         eraseList.count(simplifiedChild->predicate) != 0) {
292       continue;
293     }
294     node->children.push_back(simplifiedChild);
295   }
296   return node;
297 }
298 
299 // Combine a list of predicate expressions using a binary combiner.  If a list
300 // is empty, return "init".
301 static std::string combineBinary(ArrayRef<std::string> children,
302                                  const std::string &combiner,
303                                  std::string init) {
304   if (children.empty())
305     return init;
306 
307   auto size = children.size();
308   if (size == 1)
309     return children.front();
310 
311   std::string str;
312   llvm::raw_string_ostream os(str);
313   os << '(' << children.front() << ')';
314   for (unsigned i = 1; i < size; ++i) {
315     os << ' ' << combiner << " (" << children[i] << ')';
316   }
317   return str;
318 }
319 
320 // Prepend negation to the only condition in the predicate expression list.
321 static std::string combineNot(ArrayRef<std::string> children) {
322   assert(children.size() == 1 && "expected exactly one child predicate of Neg");
323   return (Twine("!(") + children.front() + Twine(')')).str();
324 }
325 
326 // Recursively traverse the predicate tree in depth-first post-order and build
327 // the final expression.
328 static std::string getCombinedCondition(const PredNode &root) {
329   // Immediately return for non-combiner predicates that don't have children.
330   if (root.kind == PredCombinerKind::Leaf)
331     return root.expr;
332   if (root.kind == PredCombinerKind::True)
333     return "true";
334   if (root.kind == PredCombinerKind::False)
335     return "false";
336 
337   // Recurse into children.
338   llvm::SmallVector<std::string, 4> childExpressions;
339   childExpressions.reserve(root.children.size());
340   for (const auto &child : root.children)
341     childExpressions.push_back(getCombinedCondition(*child));
342 
343   // Combine the expressions based on the predicate node kind.
344   if (root.kind == PredCombinerKind::And)
345     return combineBinary(childExpressions, "&&", "true");
346   if (root.kind == PredCombinerKind::Or)
347     return combineBinary(childExpressions, "||", "false");
348   if (root.kind == PredCombinerKind::Not)
349     return combineNot(childExpressions);
350   if (root.kind == PredCombinerKind::Concat) {
351     assert(childExpressions.size() == 1 &&
352            "ConcatPred should only have one child");
353     return root.prefix + childExpressions.front() + root.suffix;
354   }
355 
356   // Substitutions were applied before so just ignore them.
357   if (root.kind == PredCombinerKind::SubstLeaves) {
358     assert(childExpressions.size() == 1 &&
359            "substitution predicate must have one child");
360     return childExpressions[0];
361   }
362 
363   llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
364 }
365 
366 std::string CombinedPred::getConditionImpl() const {
367   SpecificBumpPtrAllocator<PredNode> allocator;
368   auto *predicateTree = buildPredicateTree(*this, allocator, {});
369   predicateTree =
370       propagateGroundTruth(predicateTree,
371                            /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
372                            /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
373 
374   return getCombinedCondition(*predicateTree);
375 }
376 
377 StringRef SubstLeavesPred::getPattern() const {
378   return def->getValueAsString("pattern");
379 }
380 
381 StringRef SubstLeavesPred::getReplacement() const {
382   return def->getValueAsString("replacement");
383 }
384 
385 StringRef ConcatPred::getPrefix() const {
386   return def->getValueAsString("prefix");
387 }
388 
389 StringRef ConcatPred::getSuffix() const {
390   return def->getValueAsString("suffix");
391 }
392