xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp (revision 242762c9a3313c8aea176ca76fb77adf8edf0907)
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "mlir/Dialect/PDL/IR/PDL.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 using namespace mlir::pdl_to_pdl_interp;
19 
20 //===----------------------------------------------------------------------===//
21 // Predicate List Building
22 //===----------------------------------------------------------------------===//
23 
24 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
25                               Value val, PredicateBuilder &builder,
26                               DenseMap<Value, Position *> &inputs,
27                               Position *pos);
28 
29 /// Compares the depths of two positions.
30 static bool comparePosDepth(Position *lhs, Position *rhs) {
31   return lhs->getIndex().size() < rhs->getIndex().size();
32 }
33 
34 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
35                               Value val, PredicateBuilder &builder,
36                               DenseMap<Value, Position *> &inputs,
37                               AttributePosition *pos) {
38   assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
39   pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
40   predList.emplace_back(pos, builder.getIsNotNull());
41 
42   // If the attribute has a type or value, add a constraint.
43   if (Value type = attr.type())
44     getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
45   else if (Attribute value = attr.valueAttr())
46     predList.emplace_back(pos, builder.getAttributeConstraint(value));
47 }
48 
49 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
50                               Value val, PredicateBuilder &builder,
51                               DenseMap<Value, Position *> &inputs,
52                               OperandPosition *pos) {
53   assert(val.getType().isa<pdl::ValueType>() && "expected value type");
54 
55   // Prevent traversal into a null value.
56   predList.emplace_back(pos, builder.getIsNotNull());
57 
58   // If this is a typed operand, add a type constraint.
59   if (auto in = val.getDefiningOp<pdl::OperandOp>()) {
60     if (Value type = in.type())
61       getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
62 
63     // Otherwise, recurse into a result node.
64   } else if (auto resultOp = val.getDefiningOp<pdl::ResultOp>()) {
65     OperationPosition *parentPos = builder.getParent(pos);
66     Position *resultPos = builder.getResult(parentPos, resultOp.index());
67     predList.emplace_back(parentPos, builder.getIsNotNull());
68     predList.emplace_back(resultPos, builder.getEqualTo(pos));
69     getTreePredicates(predList, resultOp.parent(), builder, inputs, parentPos);
70   }
71 }
72 
73 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
74                               Value val, PredicateBuilder &builder,
75                               DenseMap<Value, Position *> &inputs,
76                               OperationPosition *pos) {
77   assert(val.getType().isa<pdl::OperationType>() && "expected operation");
78   pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
79   OperationPosition *opPos = cast<OperationPosition>(pos);
80 
81   // Ensure getDefiningOp returns a non-null operation.
82   if (!opPos->isRoot())
83     predList.emplace_back(pos, builder.getIsNotNull());
84 
85   // Check that this is the correct root operation.
86   if (Optional<StringRef> opName = op.name())
87     predList.emplace_back(pos, builder.getOperationName(*opName));
88 
89   // Check that the operation has the proper number of operands and results.
90   OperandRange operands = op.operands();
91   OperandRange types = op.types();
92   predList.emplace_back(pos, builder.getOperandCount(operands.size()));
93   predList.emplace_back(pos, builder.getResultCount(types.size()));
94 
95   // Recurse into any attributes, operands, or results.
96   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
97     getTreePredicates(
98         predList, std::get<1>(it), builder, inputs,
99         builder.getAttribute(opPos,
100                              std::get<0>(it).cast<StringAttr>().getValue()));
101   }
102   for (auto operandIt : llvm::enumerate(operands)) {
103     getTreePredicates(predList, operandIt.value(), builder, inputs,
104                       builder.getOperand(opPos, operandIt.index()));
105   }
106   for (auto &resultIt : llvm::enumerate(types)) {
107     auto *resultPos = builder.getResult(pos, resultIt.index());
108     predList.emplace_back(resultPos, builder.getIsNotNull());
109     getTreePredicates(predList, resultIt.value(), builder, inputs,
110                       builder.getType(resultPos));
111   }
112 }
113 
114 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
115                               Value val, PredicateBuilder &builder,
116                               DenseMap<Value, Position *> &inputs,
117                               TypePosition *pos) {
118   assert(val.getType().isa<pdl::TypeType>() && "expected value type");
119   pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp());
120 
121   // Check for a constraint on a constant type.
122   if (Optional<Type> type = typeOp.type())
123     predList.emplace_back(pos, builder.getTypeConstraint(*type));
124 }
125 
126 /// Collect the tree predicates anchored at the given value.
127 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
128                               Value val, PredicateBuilder &builder,
129                               DenseMap<Value, Position *> &inputs,
130                               Position *pos) {
131   // Make sure this input value is accessible to the rewrite.
132   auto it = inputs.try_emplace(val, pos);
133   if (!it.second) {
134     // If this is an input value that has been visited in the tree, add a
135     // constraint to ensure that both instances refer to the same value.
136     if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperationOp, pdl::TypeOp>(
137             val.getDefiningOp())) {
138       auto minMaxPositions =
139           std::minmax(pos, it.first->second, comparePosDepth);
140       predList.emplace_back(minMaxPositions.second,
141                             builder.getEqualTo(minMaxPositions.first));
142     }
143     return;
144   }
145 
146   TypeSwitch<Position *>(pos)
147       .Case<AttributePosition, OperandPosition, OperationPosition,
148             TypePosition>([&](auto *derivedPos) {
149         getTreePredicates(predList, val, builder, inputs, derivedPos);
150       })
151       .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
152 }
153 
154 /// Collect all of the predicates related to constraints within the given
155 /// pattern operation.
156 static void getConstraintPredicates(pdl::ApplyConstraintOp op,
157                                     std::vector<PositionalPredicate> &predList,
158                                     PredicateBuilder &builder,
159                                     DenseMap<Value, Position *> &inputs) {
160   OperandRange arguments = op.args();
161   ArrayAttr parameters = op.constParamsAttr();
162 
163   std::vector<Position *> allPositions;
164   allPositions.reserve(arguments.size());
165   for (Value arg : arguments)
166     allPositions.push_back(inputs.lookup(arg));
167 
168   // Push the constraint to the furthest position.
169   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
170                                     comparePosDepth);
171   PredicateBuilder::Predicate pred =
172       builder.getConstraint(op.name(), std::move(allPositions), parameters);
173   predList.emplace_back(pos, pred);
174 }
175 
176 static void getResultPredicates(pdl::ResultOp op,
177                                 std::vector<PositionalPredicate> &predList,
178                                 PredicateBuilder &builder,
179                                 DenseMap<Value, Position *> &inputs) {
180   Position *&resultPos = inputs[op];
181   if (resultPos)
182     return;
183   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
184   resultPos = builder.getResult(parentPos, op.index());
185   predList.emplace_back(resultPos, builder.getIsNotNull());
186 }
187 
188 /// Collect all of the predicates that cannot be determined via walking the
189 /// tree.
190 static void getNonTreePredicates(pdl::PatternOp pattern,
191                                  std::vector<PositionalPredicate> &predList,
192                                  PredicateBuilder &builder,
193                                  DenseMap<Value, Position *> &inputs) {
194   for (Operation &op : pattern.body().getOps()) {
195     if (auto constraintOp = dyn_cast<pdl::ApplyConstraintOp>(&op))
196       getConstraintPredicates(constraintOp, predList, builder, inputs);
197     else if (auto resultOp = dyn_cast<pdl::ResultOp>(&op))
198       getResultPredicates(resultOp, predList, builder, inputs);
199   }
200 }
201 
202 /// Given a pattern operation, build the set of matcher predicates necessary to
203 /// match this pattern.
204 static void buildPredicateList(pdl::PatternOp pattern,
205                                PredicateBuilder &builder,
206                                std::vector<PositionalPredicate> &predList,
207                                DenseMap<Value, Position *> &valueToPosition) {
208   getTreePredicates(predList, pattern.getRewriter().root(), builder,
209                     valueToPosition, builder.getRoot());
210   getNonTreePredicates(pattern, predList, builder, valueToPosition);
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // Pattern Predicate Tree Merging
215 //===----------------------------------------------------------------------===//
216 
217 namespace {
218 
219 /// This class represents a specific predicate applied to a position, and
220 /// provides hashing and ordering operators. This class allows for computing a
221 /// frequence sum and ordering predicates based on a cost model.
222 struct OrderedPredicate {
223   OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
224       : position(ip.first), question(ip.second) {}
225   OrderedPredicate(const PositionalPredicate &ip)
226       : position(ip.position), question(ip.question) {}
227 
228   /// The position this predicate is applied to.
229   Position *position;
230 
231   /// The question that is applied by this predicate onto the position.
232   Qualifier *question;
233 
234   /// The first and second order benefit sums.
235   /// The primary sum is the number of occurrences of this predicate among all
236   /// of the patterns.
237   unsigned primary = 0;
238   /// The secondary sum is a squared summation of the primary sum of all of the
239   /// predicates within each pattern that contains this predicate. This allows
240   /// for favoring predicates that are more commonly shared within a pattern, as
241   /// opposed to those shared across patterns.
242   unsigned secondary = 0;
243 
244   /// A map between a pattern operation and the answer to the predicate question
245   /// within that pattern.
246   DenseMap<Operation *, Qualifier *> patternToAnswer;
247 
248   /// Returns true if this predicate is ordered before `rhs`, based on the cost
249   /// model.
250   bool operator<(const OrderedPredicate &rhs) const {
251     // Sort by:
252     // * higher first and secondary order sums
253     // * lower depth
254     // * lower position dependency
255     // * lower predicate dependency
256     auto *rhsPos = rhs.position;
257     return std::make_tuple(primary, secondary, rhsPos->getIndex().size(),
258                            rhsPos->getKind(), rhs.question->getKind()) >
259            std::make_tuple(rhs.primary, rhs.secondary,
260                            position->getIndex().size(), position->getKind(),
261                            question->getKind());
262   }
263 };
264 
265 /// A DenseMapInfo for OrderedPredicate based solely on the position and
266 /// question.
267 struct OrderedPredicateDenseInfo {
268   using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
269 
270   static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
271   static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
272   static bool isEqual(const OrderedPredicate &lhs,
273                       const OrderedPredicate &rhs) {
274     return lhs.position == rhs.position && lhs.question == rhs.question;
275   }
276   static unsigned getHashValue(const OrderedPredicate &p) {
277     return llvm::hash_combine(p.position, p.question);
278   }
279 };
280 
281 /// This class wraps a set of ordered predicates that are used within a specific
282 /// pattern operation.
283 struct OrderedPredicateList {
284   OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {}
285 
286   pdl::PatternOp pattern;
287   DenseSet<OrderedPredicate *> predicates;
288 };
289 } // end anonymous namespace
290 
291 /// Returns true if the given matcher refers to the same predicate as the given
292 /// ordered predicate. This means that the position and questions of the two
293 /// match.
294 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
295   return node->getPosition() == predicate->position &&
296          node->getQuestion() == predicate->question;
297 }
298 
299 /// Get or insert a child matcher for the given parent switch node, given a
300 /// predicate and parent pattern.
301 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
302                                                OrderedPredicate *predicate,
303                                                pdl::PatternOp pattern) {
304   assert(isSamePredicate(node, predicate) &&
305          "expected matcher to equal the given predicate");
306 
307   auto it = predicate->patternToAnswer.find(pattern);
308   assert(it != predicate->patternToAnswer.end() &&
309          "expected pattern to exist in predicate");
310   return node->getChildren().insert({it->second, nullptr}).first->second;
311 }
312 
313 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
314 /// order. A pattern will traverse as far as possible using common predicates
315 /// and then either diverge from the CFG or reach the end of a branch and start
316 /// creating new nodes.
317 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
318                              OrderedPredicateList &list,
319                              std::vector<OrderedPredicate *>::iterator current,
320                              std::vector<OrderedPredicate *>::iterator end) {
321   if (current == end) {
322     // We've hit the end of a pattern, so create a successful result node.
323     node = std::make_unique<SuccessNode>(list.pattern, std::move(node));
324 
325     // If the pattern doesn't contain this predicate, ignore it.
326   } else if (list.predicates.find(*current) == list.predicates.end()) {
327     propagatePattern(node, list, std::next(current), end);
328 
329     // If the current matcher node is invalid, create a new one for this
330     // position and continue propagation.
331   } else if (!node) {
332     // Create a new node at this position and continue
333     node = std::make_unique<SwitchNode>((*current)->position,
334                                         (*current)->question);
335     propagatePattern(
336         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
337         list, std::next(current), end);
338 
339     // If the matcher has already been created, and it is for this predicate we
340     // continue propagation to the child.
341   } else if (isSamePredicate(node.get(), *current)) {
342     propagatePattern(
343         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
344         list, std::next(current), end);
345 
346     // If the matcher doesn't match the current predicate, insert a branch as
347     // the common set of matchers has diverged.
348   } else {
349     propagatePattern(node->getFailureNode(), list, current, end);
350   }
351 }
352 
353 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
354 /// `node` is updated in-place if it is a switch.
355 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
356   if (!node)
357     return;
358 
359   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
360     SwitchNode::ChildMapT &children = switchNode->getChildren();
361     for (auto &it : children)
362       foldSwitchToBool(it.second);
363 
364     // If the node only contains one child, collapse it into a boolean predicate
365     // node.
366     if (children.size() == 1) {
367       auto childIt = children.begin();
368       node = std::make_unique<BoolNode>(
369           node->getPosition(), node->getQuestion(), childIt->first,
370           std::move(childIt->second), std::move(node->getFailureNode()));
371     }
372   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
373     foldSwitchToBool(boolNode->getSuccessNode());
374   }
375 
376   foldSwitchToBool(node->getFailureNode());
377 }
378 
379 /// Insert an exit node at the end of the failure path of the `root`.
380 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
381   while (*root)
382     root = &(*root)->getFailureNode();
383   *root = std::make_unique<ExitNode>();
384 }
385 
386 /// Given a module containing PDL pattern operations, generate a matcher tree
387 /// using the patterns within the given module and return the root matcher node.
388 std::unique_ptr<MatcherNode>
389 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
390                                  DenseMap<Value, Position *> &valueToPosition) {
391   // Collect the set of predicates contained within the pattern operations of
392   // the module.
393   SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16>
394       patternsAndPredicates;
395   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
396     std::vector<PositionalPredicate> predicateList;
397     buildPredicateList(pattern, builder, predicateList, valueToPosition);
398     patternsAndPredicates.emplace_back(pattern, std::move(predicateList));
399   }
400 
401   // Associate a pattern result with each unique predicate.
402   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
403   for (auto &patternAndPredList : patternsAndPredicates) {
404     for (auto &predicate : patternAndPredList.second) {
405       auto it = uniqued.insert(predicate);
406       it.first->patternToAnswer.try_emplace(patternAndPredList.first,
407                                             predicate.answer);
408     }
409   }
410 
411   // Associate each pattern to a set of its ordered predicates for later lookup.
412   std::vector<OrderedPredicateList> lists;
413   lists.reserve(patternsAndPredicates.size());
414   for (auto &patternAndPredList : patternsAndPredicates) {
415     OrderedPredicateList list(patternAndPredList.first);
416     for (auto &predicate : patternAndPredList.second) {
417       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
418       list.predicates.insert(orderedPredicate);
419 
420       // Increment the primary sum for each reference to a particular predicate.
421       ++orderedPredicate->primary;
422     }
423     lists.push_back(std::move(list));
424   }
425 
426   // For a particular pattern, get the total primary sum and add it to the
427   // secondary sum of each predicate. Square the primary sums to emphasize
428   // shared predicates within rather than across patterns.
429   for (auto &list : lists) {
430     unsigned total = 0;
431     for (auto *predicate : list.predicates)
432       total += predicate->primary * predicate->primary;
433     for (auto *predicate : list.predicates)
434       predicate->secondary += total;
435   }
436 
437   // Sort the set of predicates now that the cost primary and secondary sums
438   // have been computed.
439   std::vector<OrderedPredicate *> ordered;
440   ordered.reserve(uniqued.size());
441   for (auto &ip : uniqued)
442     ordered.push_back(&ip);
443   std::stable_sort(
444       ordered.begin(), ordered.end(),
445       [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
446 
447   // Build the matchers for each of the pattern predicate lists.
448   std::unique_ptr<MatcherNode> root;
449   for (OrderedPredicateList &list : lists)
450     propagatePattern(root, list, ordered.begin(), ordered.end());
451 
452   // Collapse the graph and insert the exit node.
453   foldSwitchToBool(root);
454   insertExitNode(&root);
455   return root;
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // MatcherNode
460 //===----------------------------------------------------------------------===//
461 
462 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
463                          std::unique_ptr<MatcherNode> failureNode)
464     : position(p), question(q), failureNode(std::move(failureNode)),
465       matcherTypeID(matcherTypeID) {}
466 
467 //===----------------------------------------------------------------------===//
468 // BoolNode
469 //===----------------------------------------------------------------------===//
470 
471 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
472                    std::unique_ptr<MatcherNode> successNode,
473                    std::unique_ptr<MatcherNode> failureNode)
474     : MatcherNode(TypeID::get<BoolNode>(), position, question,
475                   std::move(failureNode)),
476       answer(answer), successNode(std::move(successNode)) {}
477 
478 //===----------------------------------------------------------------------===//
479 // SuccessNode
480 //===----------------------------------------------------------------------===//
481 
482 SuccessNode::SuccessNode(pdl::PatternOp pattern,
483                          std::unique_ptr<MatcherNode> failureNode)
484     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
485                   /*question=*/nullptr, std::move(failureNode)),
486       pattern(pattern) {}
487 
488 //===----------------------------------------------------------------------===//
489 // SwitchNode
490 //===----------------------------------------------------------------------===//
491 
492 SwitchNode::SwitchNode(Position *position, Qualifier *question)
493     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
494