xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp (revision 8c258fda1f9b33fc847cfff9c2ee33ce4615ff0b)
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11 
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Interfaces/InferTypeOpInterface.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
20 #include <queue>
21 
22 #define DEBUG_TYPE "pdl-predicate-tree"
23 
24 using namespace mlir;
25 using namespace mlir::pdl_to_pdl_interp;
26 
27 //===----------------------------------------------------------------------===//
28 // Predicate List Building
29 //===----------------------------------------------------------------------===//
30 
/// Forward declaration of the generic entry point that collects the tree
/// predicates anchored at `val` for position `pos`. Its definition appears
/// further below, after the position-specific overloads; it is declared here
/// because those overloads recurse back into it.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              Position *pos);
35 
36 /// Compares the depths of two positions.
37 static bool comparePosDepth(Position *lhs, Position *rhs) {
38   return lhs->getOperationDepth() < rhs->getOperationDepth();
39 }
40 
41 /// Returns the number of non-range elements within `values`.
42 static unsigned getNumNonRangeValues(ValueRange values) {
43   return llvm::count_if(values.getTypes(),
44                         [](Type type) { return !type.isa<pdl::RangeType>(); });
45 }
46 
47 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
48                               Value val, PredicateBuilder &builder,
49                               DenseMap<Value, Position *> &inputs,
50                               AttributePosition *pos) {
51   assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
52   pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
53   predList.emplace_back(pos, builder.getIsNotNull());
54 
55   // If the attribute has a type or value, add a constraint.
56   if (Value type = attr.getValueType())
57     getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
58   else if (Attribute value = attr.getValueAttr())
59     predList.emplace_back(pos, builder.getAttributeConstraint(value));
60 }
61 
/// Collect all of the predicates for the given operand position.
///
/// `val` is the operand (or operand range) found at `pos`. Only values defined
/// by `pdl.operand`, `pdl.operands`, `pdl.result` or `pdl.results` produce
/// predicates here; any other defining op is ignored by the TypeSwitch below.
///  - `pdl.operand` / `pdl.operands`: the operand must be non-null (this check
///    is skipped for an operand group that has no group number), and its type
///    is recursed into when the op specifies one.
///  - `pdl.result` / `pdl.results`: the operand is matched by stepping to the
///    operation that defines it. That operation must be non-null, the
///    corresponding result (group) must equal the operand, and the defining
///    operation's own tree predicates are then collected.
static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
                                     Value val, PredicateBuilder &builder,
                                     DenseMap<Value, Position *> &inputs,
                                     Position *pos) {
  Type valueType = val.getType();
  // Whether `val` is a range of values; only needed for result groups below.
  bool isVariadic = valueType.isa<pdl::RangeType>();

  // If this is a typed operand, add a type constraint.
  TypeSwitch<Operation *>(val.getDefiningOp())
      .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
        // Prevent traversal into a null value if the operand has a proper
        // index.
        // NOTE: for `pdl.operands` this assumes `pos` is an operand group
        // position; `cast` asserts otherwise.
        if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
            cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
          predList.emplace_back(pos, builder.getIsNotNull());

        if (Value type = op.getValueType())
          getTreePredicates(predList, type, builder, inputs,
                            builder.getType(pos));
      })
      .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
        std::optional<unsigned> index = op.getIndex();

        // Prevent traversal into a null value if the result has a proper index.
        if (index)
          predList.emplace_back(pos, builder.getIsNotNull());

        // Get the parent operation of this operand.
        OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
        predList.emplace_back(parentPos, builder.getIsNotNull());

        // Ensure that the operands match the corresponding results of the
        // parent operation.
        Position *resultPos = nullptr;
        if (std::is_same<pdl::ResultOp, decltype(op)>::value)
          resultPos = builder.getResult(parentPos, *index);
        else
          resultPos = builder.getResultGroup(parentPos, index, isVariadic);
        predList.emplace_back(resultPos, builder.getEqualTo(pos));

        // Collect the predicates of the parent operation. This goes through
        // the generic `Position *` overload, which records the parent in
        // `inputs` (or adds an equality predicate if it was already visited)
        // before dispatching.
        getTreePredicates(predList, op.getParent(), builder, inputs,
                          (Position *)parentPos);
      });
}
108 
109 static void
110 getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
111                   PredicateBuilder &builder,
112                   DenseMap<Value, Position *> &inputs, OperationPosition *pos,
113                   std::optional<unsigned> ignoreOperand = std::nullopt) {
114   assert(val.getType().isa<pdl::OperationType>() && "expected operation");
115   pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
116   OperationPosition *opPos = cast<OperationPosition>(pos);
117 
118   // Ensure getDefiningOp returns a non-null operation.
119   if (!opPos->isRoot())
120     predList.emplace_back(pos, builder.getIsNotNull());
121 
122   // Check that this is the correct root operation.
123   if (std::optional<StringRef> opName = op.getOpName())
124     predList.emplace_back(pos, builder.getOperationName(*opName));
125 
126   // Check that the operation has the proper number of operands. If there are
127   // any variable length operands, we check a minimum instead of an exact count.
128   OperandRange operands = op.getOperandValues();
129   unsigned minOperands = getNumNonRangeValues(operands);
130   if (minOperands != operands.size()) {
131     if (minOperands)
132       predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
133   } else {
134     predList.emplace_back(pos, builder.getOperandCount(minOperands));
135   }
136 
137   // Check that the operation has the proper number of results. If there are
138   // any variable length results, we check a minimum instead of an exact count.
139   OperandRange types = op.getTypeValues();
140   unsigned minResults = getNumNonRangeValues(types);
141   if (minResults == types.size())
142     predList.emplace_back(pos, builder.getResultCount(types.size()));
143   else if (minResults)
144     predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
145 
146   // Recurse into any attributes, operands, or results.
147   for (auto [attrName, attr] :
148        llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
149     getTreePredicates(
150         predList, attr, builder, inputs,
151         builder.getAttribute(opPos, attrName.cast<StringAttr>().getValue()));
152   }
153 
154   // Process the operands and results of the operation. For all values up to
155   // the first variable length value, we use the concrete operand/result
156   // number. After that, we use the "group" given that we can't know the
157   // concrete indices until runtime. If there is only one variadic operand
158   // group, we treat it as all of the operands/results of the operation.
159   /// Operands.
160   if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
161     // Ignore the operands if we are performing an upward traversal (in that
162     // case, they have already been visited).
163     if (opPos->isRoot() || opPos->isOperandDefiningOp())
164       getTreePredicates(predList, operands.front(), builder, inputs,
165                         builder.getAllOperands(opPos));
166   } else {
167     bool foundVariableLength = false;
168     for (const auto &operandIt : llvm::enumerate(operands)) {
169       bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
170       foundVariableLength |= isVariadic;
171 
172       // Ignore the specified operand, usually because this position was
173       // visited in an upward traversal via an iterative choice.
174       if (ignoreOperand && *ignoreOperand == operandIt.index())
175         continue;
176 
177       Position *pos =
178           foundVariableLength
179               ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
180               : builder.getOperand(opPos, operandIt.index());
181       getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
182     }
183   }
184   /// Results.
185   if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
186     getTreePredicates(predList, types.front(), builder, inputs,
187                       builder.getType(builder.getAllResults(opPos)));
188     return;
189   }
190 
191   bool foundVariableLength = false;
192   for (auto [idx, typeValue] : llvm::enumerate(types)) {
193     bool isVariadic = typeValue.getType().isa<pdl::RangeType>();
194     foundVariableLength |= isVariadic;
195 
196     auto *resultPos = foundVariableLength
197                           ? builder.getResultGroup(pos, idx, isVariadic)
198                           : builder.getResult(pos, idx);
199     predList.emplace_back(resultPos, builder.getIsNotNull());
200     getTreePredicates(predList, typeValue, builder, inputs,
201                       builder.getType(resultPos));
202   }
203 }
204 
205 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
206                               Value val, PredicateBuilder &builder,
207                               DenseMap<Value, Position *> &inputs,
208                               TypePosition *pos) {
209   // Check for a constraint on a constant type.
210   if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
211     if (Attribute type = typeOp.getConstantTypeAttr())
212       predList.emplace_back(pos, builder.getTypeConstraint(type));
213   } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
214     if (Attribute typeAttr = typeOp.getConstantTypesAttr())
215       predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
216   }
217 }
218 
219 /// Collect the tree predicates anchored at the given value.
220 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
221                               Value val, PredicateBuilder &builder,
222                               DenseMap<Value, Position *> &inputs,
223                               Position *pos) {
224   // Make sure this input value is accessible to the rewrite.
225   auto it = inputs.try_emplace(val, pos);
226   if (!it.second) {
227     // If this is an input value that has been visited in the tree, add a
228     // constraint to ensure that both instances refer to the same value.
229     if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
230             pdl::TypeOp>(val.getDefiningOp())) {
231       auto minMaxPositions =
232           std::minmax(pos, it.first->second, comparePosDepth);
233       predList.emplace_back(minMaxPositions.second,
234                             builder.getEqualTo(minMaxPositions.first));
235     }
236     return;
237   }
238 
239   TypeSwitch<Position *>(pos)
240       .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
241         getTreePredicates(predList, val, builder, inputs, pos);
242       })
243       .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
244         getOperandTreePredicates(predList, val, builder, inputs, pos);
245       })
246       .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
247 }
248 
249 static void getAttributePredicates(pdl::AttributeOp op,
250                                    std::vector<PositionalPredicate> &predList,
251                                    PredicateBuilder &builder,
252                                    DenseMap<Value, Position *> &inputs) {
253   Position *&attrPos = inputs[op];
254   if (attrPos)
255     return;
256   Attribute value = op.getValueAttr();
257   assert(value && "expected non-tree `pdl.attribute` to contain a value");
258   attrPos = builder.getAttributeLiteral(value);
259 }
260 
261 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
262                                     std::vector<PositionalPredicate> &predList,
263                                     PredicateBuilder &builder,
264                                     DenseMap<Value, Position *> &inputs) {
265   OperandRange arguments = op.getArgs();
266 
267   std::vector<Position *> allPositions;
268   allPositions.reserve(arguments.size());
269   for (Value arg : arguments)
270     allPositions.push_back(inputs.lookup(arg));
271 
272   // Push the constraint to the furthest position.
273   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
274                                     comparePosDepth);
275   PredicateBuilder::Predicate pred =
276       builder.getConstraint(op.getName(), allPositions);
277   predList.emplace_back(pos, pred);
278 }
279 
280 static void getResultPredicates(pdl::ResultOp op,
281                                 std::vector<PositionalPredicate> &predList,
282                                 PredicateBuilder &builder,
283                                 DenseMap<Value, Position *> &inputs) {
284   Position *&resultPos = inputs[op];
285   if (resultPos)
286     return;
287 
288   // Ensure that the result isn't null.
289   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
290   resultPos = builder.getResult(parentPos, op.getIndex());
291   predList.emplace_back(resultPos, builder.getIsNotNull());
292 }
293 
294 static void getResultPredicates(pdl::ResultsOp op,
295                                 std::vector<PositionalPredicate> &predList,
296                                 PredicateBuilder &builder,
297                                 DenseMap<Value, Position *> &inputs) {
298   Position *&resultPos = inputs[op];
299   if (resultPos)
300     return;
301 
302   // Ensure that the result isn't null if the result has an index.
303   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
304   bool isVariadic = op.getType().isa<pdl::RangeType>();
305   std::optional<unsigned> index = op.getIndex();
306   resultPos = builder.getResultGroup(parentPos, index, isVariadic);
307   if (index)
308     predList.emplace_back(resultPos, builder.getIsNotNull());
309 }
310 
311 static void getTypePredicates(Value typeValue,
312                               function_ref<Attribute()> typeAttrFn,
313                               PredicateBuilder &builder,
314                               DenseMap<Value, Position *> &inputs) {
315   Position *&typePos = inputs[typeValue];
316   if (typePos)
317     return;
318   Attribute typeAttr = typeAttrFn();
319   assert(typeAttr &&
320          "expected non-tree `pdl.type`/`pdl.types` to contain a value");
321   typePos = builder.getTypeLiteral(typeAttr);
322 }
323 
324 /// Collect all of the predicates that cannot be determined via walking the
325 /// tree.
326 static void getNonTreePredicates(pdl::PatternOp pattern,
327                                  std::vector<PositionalPredicate> &predList,
328                                  PredicateBuilder &builder,
329                                  DenseMap<Value, Position *> &inputs) {
330   for (Operation &op : pattern.getBodyRegion().getOps()) {
331     TypeSwitch<Operation *>(&op)
332         .Case([&](pdl::AttributeOp attrOp) {
333           getAttributePredicates(attrOp, predList, builder, inputs);
334         })
335         .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
336           getConstraintPredicates(constraintOp, predList, builder, inputs);
337         })
338         .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
339           getResultPredicates(resultOp, predList, builder, inputs);
340         })
341         .Case([&](pdl::TypeOp typeOp) {
342           getTypePredicates(
343               typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
344               inputs);
345         })
346         .Case([&](pdl::TypesOp typeOp) {
347           getTypePredicates(
348               typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
349               inputs);
350         });
351   }
352 }
353 
namespace {

/// An op accepting a value at an optional index.
///
/// Used as the entry of a parent map: for a value reached while walking down
/// from a root, `parent` is the value it was reached from. That is either an
/// operation using it as an operand, or a result op extracting from it; it is
/// null for the root itself. `index` is the corresponding operand/result
/// index, or empty when the whole operand/result list is used.
/// NOTE: members are brace-initialized positionally elsewhere in this file,
/// so their order matters.
struct OpIndex {
  Value parent;
  std::optional<unsigned> index;
};

/// The parent and operand index of each operation for each root, stored
/// as a nested map [root][operation].
using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;

} // namespace
367 
368 /// Given a pattern, determines the set of roots present in this pattern.
369 /// These are the operations whose results are not consumed by other operations.
370 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
371   // First, collect all the operations that are used as operands
372   // to other operations. These are not roots by default.
373   DenseSet<Value> used;
374   for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
375     for (Value operand : operationOp.getOperandValues())
376       TypeSwitch<Operation *>(operand.getDefiningOp())
377           .Case<pdl::ResultOp, pdl::ResultsOp>(
378               [&used](auto resultOp) { used.insert(resultOp.getParent()); });
379   }
380 
381   // Remove the specified root from the use set, so that we can
382   // always select it as a root, even if it is used by other operations.
383   if (Value root = pattern.getRewriter().getRoot())
384     used.erase(root);
385 
386   // Finally, collect all the unused operations.
387   SmallVector<Value> roots;
388   for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
389     if (!used.contains(operationOp))
390       roots.push_back(operationOp);
391 
392   return roots;
393 }
394 
395 /// Given a list of candidate roots, builds the cost graph for connecting them.
396 /// The graph is formed by traversing the DAG of operations starting from each
397 /// root and marking the depth of each connector value (operand). Then we join
398 /// the candidate roots based on the common connector values, taking the one
399 /// with the minimum depth. Along the way, we compute, for each candidate root,
400 /// a mapping from each operation (in the DAG underneath this root) to its
401 /// parent operation and the corresponding operand index.
402 static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
403                            ParentMaps &parentMaps) {
404 
405   // The entry of a queue. The entry consists of the following items:
406   // * the value in the DAG underneath the root;
407   // * the parent of the value;
408   // * the operand index of the value in its parent;
409   // * the depth of the visited value.
410   struct Entry {
411     Entry(Value value, Value parent, std::optional<unsigned> index,
412           unsigned depth)
413         : value(value), parent(parent), index(index), depth(depth) {}
414 
415     Value value;
416     Value parent;
417     std::optional<unsigned> index;
418     unsigned depth;
419   };
420 
421   // A root of a value and its depth (distance from root to the value).
422   struct RootDepth {
423     Value root;
424     unsigned depth = 0;
425   };
426 
427   // Map from candidate connector values to their roots and depths. Using a
428   // small vector with 1 entry because most values belong to a single root.
429   llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
430 
431   // Perform a breadth-first traversal of the op DAG rooted at each root.
432   for (Value root : roots) {
433     // The queue of visited values. A value may be present multiple times in
434     // the queue, for multiple parents. We only accept the first occurrence,
435     // which is guaranteed to have the lowest depth.
436     std::queue<Entry> toVisit;
437     toVisit.emplace(root, Value(), 0, 0);
438 
439     // The map from value to its parent for the current root.
440     DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
441 
442     while (!toVisit.empty()) {
443       Entry entry = toVisit.front();
444       toVisit.pop();
445       // Skip if already visited.
446       if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
447         continue;
448 
449       // Mark the root and depth of the value.
450       connectorsRootsDepths[entry.value].push_back({root, entry.depth});
451 
452       // Traverse the operands of an operation and result ops.
453       // We intentionally do not traverse attributes and types, because those
454       // are expensive to join on.
455       TypeSwitch<Operation *>(entry.value.getDefiningOp())
456           .Case<pdl::OperationOp>([&](auto operationOp) {
457             OperandRange operands = operationOp.getOperandValues();
458             // Special case when we pass all the operands in one range.
459             // For those, the index is empty.
460             if (operands.size() == 1 &&
461                 operands[0].getType().isa<pdl::RangeType>()) {
462               toVisit.emplace(operands[0], entry.value, std::nullopt,
463                               entry.depth + 1);
464               return;
465             }
466 
467             // Default case: visit all the operands.
468             for (const auto &p :
469                  llvm::enumerate(operationOp.getOperandValues()))
470               toVisit.emplace(p.value(), entry.value, p.index(),
471                               entry.depth + 1);
472           })
473           .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
474             toVisit.emplace(resultOp.getParent(), entry.value,
475                             resultOp.getIndex(), entry.depth);
476           });
477     }
478   }
479 
480   // Now build the cost graph.
481   // This is simply a minimum over all depths for the target root.
482   unsigned nextID = 0;
483   for (const auto &connectorRootsDepths : connectorsRootsDepths) {
484     Value value = connectorRootsDepths.first;
485     ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
486     // If there is only one root for this value, this will not trigger
487     // any edges in the cost graph (a perf optimization).
488     if (rootsDepths.size() == 1)
489       continue;
490 
491     for (const RootDepth &p : rootsDepths) {
492       for (const RootDepth &q : rootsDepths) {
493         if (&p == &q)
494           continue;
495         // Insert or retrieve the property of edge from p to q.
496         RootOrderingEntry &entry = graph[q.root][p.root];
497         if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
498           if (!entry.connector)
499             entry.cost.second = nextID++;
500           entry.cost.first = q.depth;
501           entry.connector = value;
502         }
503       }
504     }
505   }
506 
507   assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
508          "the pattern contains a candidate root disconnected from the others");
509 }
510 
511 /// Returns true if the operand at the given index needs to be queried using an
512 /// operand group, i.e., if it is variadic itself or follows a variadic operand.
513 static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
514   OperandRange operands = op.getOperandValues();
515   assert(index < operands.size() && "operand index out of range");
516   for (unsigned i = 0; i <= index; ++i)
517     if (operands[i].getType().isa<pdl::RangeType>())
518       return true;
519   return false;
520 }
521 
/// Visit a node during upward traversal.
///
/// `pos` is the position of the current value. `opIndex.parent` is the next
/// value on the path towards the target root. `opIndex.index` is the
/// operand/result index linking the two (empty for "all operands/results").
/// On return, `pos` has been updated to the position of `opIndex.parent`,
/// and that value's position has been recorded in `valueToPosition`.
///
/// The three cases for `opIndex.parent` are:
///   * an operation: iterate over the users of the current value, identified
///     by `rootID` in the `foreach` position, and require the relevant
///     operand(s) of the user to equal the current value. The user's tree
///     predicates are then collected, skipping the operand we came from.
///   * `pdl.result`: step to the individual result of the current operation.
///   * `pdl.results`: step to a result group, or to all results when no
///     index is given.
static void visitUpward(std::vector<PositionalPredicate> &predList,
                        OpIndex opIndex, PredicateBuilder &builder,
                        DenseMap<Value, Position *> &valueToPosition,
                        Position *&pos, unsigned rootID) {
  Value value = opIndex.parent;
  TypeSwitch<Operation *>(value.getDefiningOp())
      .Case<pdl::OperationOp>([&](auto operationOp) {
        LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");

        // Get users and iterate over them.
        Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
        Position *foreachPos = builder.getForEach(usersPos, rootID);
        OperationPosition *opPos = builder.getPassthroughOp(foreachPos);

        // Compare the operand(s) of the user against the input value(s).
        Position *operandPos;
        if (!opIndex.index) {
          // We are querying all the operands of the operation.
          operandPos = builder.getAllOperands(opPos);
        } else if (useOperandGroup(operationOp, *opIndex.index)) {
          // We are querying an operand group.
          Type type = operationOp.getOperandValues()[*opIndex.index].getType();
          bool variadic = type.isa<pdl::RangeType>();
          operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
        } else {
          // We are querying an individual operand.
          operandPos = builder.getOperand(opPos, *opIndex.index);
        }
        predList.emplace_back(operandPos, builder.getEqualTo(pos));

        // Guard against duplicate upward visits. These are not possible,
        // because if this value was already visited, it would have been
        // cheaper to start the traversal at this value rather than at the
        // `connector`, violating the optimality of our spanning tree.
        bool inserted = valueToPosition.try_emplace(value, opPos).second;
        (void)inserted;
        assert(inserted && "duplicate upward visit");

        // Obtain the tree predicates at the current value. The operand we
        // arrived through is skipped, since it was checked just above.
        getTreePredicates(predList, value, builder, valueToPosition, opPos,
                          opIndex.index);

        // Update the position
        pos = opPos;
      })
      .Case<pdl::ResultOp>([&](auto resultOp) {
        // Traverse up an individual result.
        auto *opPos = dyn_cast<OperationPosition>(pos);
        assert(opPos && "operations and results must be interleaved");
        pos = builder.getResult(opPos, *opIndex.index);

        // Insert the result position in case we have not visited it yet.
        valueToPosition.try_emplace(value, pos);
      })
      .Case<pdl::ResultsOp>([&](auto resultOp) {
        // Traverse up a group of results.
        auto *opPos = dyn_cast<OperationPosition>(pos);
        assert(opPos && "operations and results must be interleaved");
        bool isVariadic = value.getType().isa<pdl::RangeType>();
        if (opIndex.index)
          pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
        else
          pos = builder.getAllResults(opPos);

        // Insert the result position in case we have not visited it yet.
        valueToPosition.try_emplace(value, pos);
      });
}
591 
592 /// Given a pattern operation, build the set of matcher predicates necessary to
593 /// match this pattern.
594 static Value buildPredicateList(pdl::PatternOp pattern,
595                                 PredicateBuilder &builder,
596                                 std::vector<PositionalPredicate> &predList,
597                                 DenseMap<Value, Position *> &valueToPosition) {
598   SmallVector<Value> roots = detectRoots(pattern);
599 
600   // Build the root ordering graph and compute the parent maps.
601   RootOrderingGraph graph;
602   ParentMaps parentMaps;
603   buildCostGraph(roots, graph, parentMaps);
604   LLVM_DEBUG({
605     llvm::dbgs() << "Graph:\n";
606     for (auto &target : graph) {
607       llvm::dbgs() << "  * " << target.first.getLoc() << " " << target.first
608                    << "\n";
609       for (auto &source : target.second) {
610         RootOrderingEntry &entry = source.second;
611         llvm::dbgs() << "      <- " << source.first << ": " << entry.cost.first
612                      << ":" << entry.cost.second << " via "
613                      << entry.connector.getLoc() << "\n";
614       }
615     }
616   });
617 
618   // Solve the optimal branching problem for each candidate root, or use the
619   // provided one.
620   Value bestRoot = pattern.getRewriter().getRoot();
621   OptimalBranching::EdgeList bestEdges;
622   if (!bestRoot) {
623     unsigned bestCost = 0;
624     LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
625     for (Value root : roots) {
626       OptimalBranching solver(graph, root);
627       unsigned cost = solver.solve();
628       LLVM_DEBUG(llvm::dbgs() << "  * " << root << ": " << cost << "\n");
629       if (!bestRoot || bestCost > cost) {
630         bestCost = cost;
631         bestRoot = root;
632         bestEdges = solver.preOrderTraversal(roots);
633       }
634     }
635   } else {
636     OptimalBranching solver(graph, bestRoot);
637     solver.solve();
638     bestEdges = solver.preOrderTraversal(roots);
639   }
640 
641   // Print the best solution.
642   LLVM_DEBUG({
643     llvm::dbgs() << "Best tree:\n";
644     for (const std::pair<Value, Value> &edge : bestEdges) {
645       llvm::dbgs() << "  * " << edge.first;
646       if (edge.second)
647         llvm::dbgs() << " <- " << edge.second;
648       llvm::dbgs() << "\n";
649     }
650   });
651 
652   LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
653   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << bestRoot << "\n");
654 
655   // The best root is the starting point for the traversal. Get the tree
656   // predicates for the DAG rooted at bestRoot.
657   getTreePredicates(predList, bestRoot, builder, valueToPosition,
658                     builder.getRoot());
659 
660   // Traverse the selected optimal branching. For all edges in order, traverse
661   // up starting from the connector, until the candidate root is reached, and
662   // call getTreePredicates at every node along the way.
663   for (const auto &it : llvm::enumerate(bestEdges)) {
664     Value target = it.value().first;
665     Value source = it.value().second;
666 
667     // Check if we already visited the target root. This happens in two cases:
668     // 1) the initial root (bestRoot);
669     // 2) a root that is dominated by (contained in the subtree rooted at) an
670     //    already visited root.
671     if (valueToPosition.count(target))
672       continue;
673 
674     // Determine the connector.
675     Value connector = graph[target][source].connector;
676     assert(connector && "invalid edge");
677     LLVM_DEBUG(llvm::dbgs() << "  * Connector: " << connector.getLoc() << "\n");
678     DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
679     Position *pos = valueToPosition.lookup(connector);
680     assert(pos && "connector has not been traversed yet");
681 
682     // Traverse from the connector upwards towards the target root.
683     for (Value value = connector; value != target;) {
684       OpIndex opIndex = parentMap.lookup(value);
685       assert(opIndex.parent && "missing parent");
686       visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
687       value = opIndex.parent;
688     }
689   }
690 
691   getNonTreePredicates(pattern, predList, builder, valueToPosition);
692 
693   return bestRoot;
694 }
695 
696 //===----------------------------------------------------------------------===//
697 // Pattern Predicate Tree Merging
698 //===----------------------------------------------------------------------===//
699 
700 namespace {
701 
702 /// This class represents a specific predicate applied to a position, and
703 /// provides hashing and ordering operators. This class allows for computing a
704 /// frequence sum and ordering predicates based on a cost model.
705 struct OrderedPredicate {
706   OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
707       : position(ip.first), question(ip.second) {}
708   OrderedPredicate(const PositionalPredicate &ip)
709       : position(ip.position), question(ip.question) {}
710 
711   /// The position this predicate is applied to.
712   Position *position;
713 
714   /// The question that is applied by this predicate onto the position.
715   Qualifier *question;
716 
717   /// The first and second order benefit sums.
718   /// The primary sum is the number of occurrences of this predicate among all
719   /// of the patterns.
720   unsigned primary = 0;
721   /// The secondary sum is a squared summation of the primary sum of all of the
722   /// predicates within each pattern that contains this predicate. This allows
723   /// for favoring predicates that are more commonly shared within a pattern, as
724   /// opposed to those shared across patterns.
725   unsigned secondary = 0;
726 
727   /// The tie breaking ID, used to preserve a deterministic (insertion) order
728   /// among all the predicates with the same priority, depth, and position /
729   /// predicate dependency.
730   unsigned id = 0;
731 
732   /// A map between a pattern operation and the answer to the predicate question
733   /// within that pattern.
734   DenseMap<Operation *, Qualifier *> patternToAnswer;
735 
736   /// Returns true if this predicate is ordered before `rhs`, based on the cost
737   /// model.
738   bool operator<(const OrderedPredicate &rhs) const {
739     // Sort by:
740     // * higher first and secondary order sums
741     // * lower depth
742     // * lower position dependency
743     // * lower predicate dependency
744     // * lower tie breaking ID
745     auto *rhsPos = rhs.position;
746     return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
747                            rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
748            std::make_tuple(rhs.primary, rhs.secondary,
749                            position->getOperationDepth(), position->getKind(),
750                            question->getKind(), id);
751   }
752 };
753 
754 /// A DenseMapInfo for OrderedPredicate based solely on the position and
755 /// question.
756 struct OrderedPredicateDenseInfo {
757   using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
758 
759   static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
760   static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
761   static bool isEqual(const OrderedPredicate &lhs,
762                       const OrderedPredicate &rhs) {
763     return lhs.position == rhs.position && lhs.question == rhs.question;
764   }
765   static unsigned getHashValue(const OrderedPredicate &p) {
766     return llvm::hash_combine(p.position, p.question);
767   }
768 };
769 
770 /// This class wraps a set of ordered predicates that are used within a specific
771 /// pattern operation.
772 struct OrderedPredicateList {
773   OrderedPredicateList(pdl::PatternOp pattern, Value root)
774       : pattern(pattern), root(root) {}
775 
776   pdl::PatternOp pattern;
777   Value root;
778   DenseSet<OrderedPredicate *> predicates;
779 };
780 } // namespace
781 
782 /// Returns true if the given matcher refers to the same predicate as the given
783 /// ordered predicate. This means that the position and questions of the two
784 /// match.
785 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
786   return node->getPosition() == predicate->position &&
787          node->getQuestion() == predicate->question;
788 }
789 
790 /// Get or insert a child matcher for the given parent switch node, given a
791 /// predicate and parent pattern.
792 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
793                                                OrderedPredicate *predicate,
794                                                pdl::PatternOp pattern) {
795   assert(isSamePredicate(node, predicate) &&
796          "expected matcher to equal the given predicate");
797 
798   auto it = predicate->patternToAnswer.find(pattern);
799   assert(it != predicate->patternToAnswer.end() &&
800          "expected pattern to exist in predicate");
801   return node->getChildren().insert({it->second, nullptr}).first->second;
802 }
803 
804 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
805 /// order. A pattern will traverse as far as possible using common predicates
806 /// and then either diverge from the CFG or reach the end of a branch and start
807 /// creating new nodes.
808 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
809                              OrderedPredicateList &list,
810                              std::vector<OrderedPredicate *>::iterator current,
811                              std::vector<OrderedPredicate *>::iterator end) {
812   if (current == end) {
813     // We've hit the end of a pattern, so create a successful result node.
814     node =
815         std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
816 
817     // If the pattern doesn't contain this predicate, ignore it.
818   } else if (!list.predicates.contains(*current)) {
819     propagatePattern(node, list, std::next(current), end);
820 
821     // If the current matcher node is invalid, create a new one for this
822     // position and continue propagation.
823   } else if (!node) {
824     // Create a new node at this position and continue
825     node = std::make_unique<SwitchNode>((*current)->position,
826                                         (*current)->question);
827     propagatePattern(
828         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
829         list, std::next(current), end);
830 
831     // If the matcher has already been created, and it is for this predicate we
832     // continue propagation to the child.
833   } else if (isSamePredicate(node.get(), *current)) {
834     propagatePattern(
835         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
836         list, std::next(current), end);
837 
838     // If the matcher doesn't match the current predicate, insert a branch as
839     // the common set of matchers has diverged.
840   } else {
841     propagatePattern(node->getFailureNode(), list, current, end);
842   }
843 }
844 
845 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
846 /// `node` is updated in-place if it is a switch.
847 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
848   if (!node)
849     return;
850 
851   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
852     SwitchNode::ChildMapT &children = switchNode->getChildren();
853     for (auto &it : children)
854       foldSwitchToBool(it.second);
855 
856     // If the node only contains one child, collapse it into a boolean predicate
857     // node.
858     if (children.size() == 1) {
859       auto childIt = children.begin();
860       node = std::make_unique<BoolNode>(
861           node->getPosition(), node->getQuestion(), childIt->first,
862           std::move(childIt->second), std::move(node->getFailureNode()));
863     }
864   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
865     foldSwitchToBool(boolNode->getSuccessNode());
866   }
867 
868   foldSwitchToBool(node->getFailureNode());
869 }
870 
871 /// Insert an exit node at the end of the failure path of the `root`.
872 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
873   while (*root)
874     root = &(*root)->getFailureNode();
875   *root = std::make_unique<ExitNode>();
876 }
877 
878 /// Given a module containing PDL pattern operations, generate a matcher tree
879 /// using the patterns within the given module and return the root matcher node.
880 std::unique_ptr<MatcherNode>
881 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
882                                  DenseMap<Value, Position *> &valueToPosition) {
883   // The set of predicates contained within the pattern operations of the
884   // module.
885   struct PatternPredicates {
886     PatternPredicates(pdl::PatternOp pattern, Value root,
887                       std::vector<PositionalPredicate> predicates)
888         : pattern(pattern), root(root), predicates(std::move(predicates)) {}
889 
890     /// A pattern.
891     pdl::PatternOp pattern;
892 
893     /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
894     Value root;
895 
896     /// The extracted predicates for this pattern and root.
897     std::vector<PositionalPredicate> predicates;
898   };
899 
900   SmallVector<PatternPredicates, 16> patternsAndPredicates;
901   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
902     std::vector<PositionalPredicate> predicateList;
903     Value root =
904         buildPredicateList(pattern, builder, predicateList, valueToPosition);
905     patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
906   }
907 
908   // Associate a pattern result with each unique predicate.
909   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
910   for (auto &patternAndPredList : patternsAndPredicates) {
911     for (auto &predicate : patternAndPredList.predicates) {
912       auto it = uniqued.insert(predicate);
913       it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
914                                             predicate.answer);
915       // Mark the insertion order (0-based indexing).
916       if (it.second)
917         it.first->id = uniqued.size() - 1;
918     }
919   }
920 
921   // Associate each pattern to a set of its ordered predicates for later lookup.
922   std::vector<OrderedPredicateList> lists;
923   lists.reserve(patternsAndPredicates.size());
924   for (auto &patternAndPredList : patternsAndPredicates) {
925     OrderedPredicateList list(patternAndPredList.pattern,
926                               patternAndPredList.root);
927     for (auto &predicate : patternAndPredList.predicates) {
928       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
929       list.predicates.insert(orderedPredicate);
930 
931       // Increment the primary sum for each reference to a particular predicate.
932       ++orderedPredicate->primary;
933     }
934     lists.push_back(std::move(list));
935   }
936 
937   // For a particular pattern, get the total primary sum and add it to the
938   // secondary sum of each predicate. Square the primary sums to emphasize
939   // shared predicates within rather than across patterns.
940   for (auto &list : lists) {
941     unsigned total = 0;
942     for (auto *predicate : list.predicates)
943       total += predicate->primary * predicate->primary;
944     for (auto *predicate : list.predicates)
945       predicate->secondary += total;
946   }
947 
948   // Sort the set of predicates now that the cost primary and secondary sums
949   // have been computed.
950   std::vector<OrderedPredicate *> ordered;
951   ordered.reserve(uniqued.size());
952   for (auto &ip : uniqued)
953     ordered.push_back(&ip);
954   llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
955     return *lhs < *rhs;
956   });
957 
958   // Build the matchers for each of the pattern predicate lists.
959   std::unique_ptr<MatcherNode> root;
960   for (OrderedPredicateList &list : lists)
961     propagatePattern(root, list, ordered.begin(), ordered.end());
962 
963   // Collapse the graph and insert the exit node.
964   foldSwitchToBool(root);
965   insertExitNode(&root);
966   return root;
967 }
968 
969 //===----------------------------------------------------------------------===//
970 // MatcherNode
971 //===----------------------------------------------------------------------===//
972 
973 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
974                          std::unique_ptr<MatcherNode> failureNode)
975     : position(p), question(q), failureNode(std::move(failureNode)),
976       matcherTypeID(matcherTypeID) {}
977 
978 //===----------------------------------------------------------------------===//
979 // BoolNode
980 //===----------------------------------------------------------------------===//
981 
982 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
983                    std::unique_ptr<MatcherNode> successNode,
984                    std::unique_ptr<MatcherNode> failureNode)
985     : MatcherNode(TypeID::get<BoolNode>(), position, question,
986                   std::move(failureNode)),
987       answer(answer), successNode(std::move(successNode)) {}
988 
989 //===----------------------------------------------------------------------===//
990 // SuccessNode
991 //===----------------------------------------------------------------------===//
992 
993 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
994                          std::unique_ptr<MatcherNode> failureNode)
995     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
996                   /*question=*/nullptr, std::move(failureNode)),
997       pattern(pattern), root(root) {}
998 
999 //===----------------------------------------------------------------------===//
1000 // SwitchNode
1001 //===----------------------------------------------------------------------===//
1002 
1003 SwitchNode::SwitchNode(Position *position, Qualifier *question)
1004     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
1005