xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp (revision be0a7e9f27083ada6072fcc0711ffa5630daa5ec)
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11 
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Interfaces/InferTypeOpInterface.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
20 #include <queue>
21 
22 #define DEBUG_TYPE "pdl-predicate-tree"
23 
24 using namespace mlir;
25 using namespace mlir::pdl_to_pdl_interp;
26 
27 //===----------------------------------------------------------------------===//
28 // Predicate List Building
29 //===----------------------------------------------------------------------===//
30 
31 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
32                               Value val, PredicateBuilder &builder,
33                               DenseMap<Value, Position *> &inputs,
34                               Position *pos);
35 
36 /// Compares the depths of two positions.
37 static bool comparePosDepth(Position *lhs, Position *rhs) {
38   return lhs->getOperationDepth() < rhs->getOperationDepth();
39 }
40 
41 /// Returns the number of non-range elements within `values`.
42 static unsigned getNumNonRangeValues(ValueRange values) {
43   return llvm::count_if(values.getTypes(),
44                         [](Type type) { return !type.isa<pdl::RangeType>(); });
45 }
46 
47 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
48                               Value val, PredicateBuilder &builder,
49                               DenseMap<Value, Position *> &inputs,
50                               AttributePosition *pos) {
51   assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
52   pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
53   predList.emplace_back(pos, builder.getIsNotNull());
54 
55   // If the attribute has a type or value, add a constraint.
56   if (Value type = attr.type())
57     getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
58   else if (Attribute value = attr.valueAttr())
59     predList.emplace_back(pos, builder.getAttributeConstraint(value));
60 }
61 
62 /// Collect all of the predicates for the given operand position.
63 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
64                                      Value val, PredicateBuilder &builder,
65                                      DenseMap<Value, Position *> &inputs,
66                                      Position *pos) {
67   Type valueType = val.getType();
68   bool isVariadic = valueType.isa<pdl::RangeType>();
69 
70   // If this is a typed operand, add a type constraint.
71   TypeSwitch<Operation *>(val.getDefiningOp())
72       .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
73         // Prevent traversal into a null value if the operand has a proper
74         // index.
75         if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
76             cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
77           predList.emplace_back(pos, builder.getIsNotNull());
78 
79         if (Value type = op.type())
80           getTreePredicates(predList, type, builder, inputs,
81                             builder.getType(pos));
82       })
83       .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
84         Optional<unsigned> index = op.index();
85 
86         // Prevent traversal into a null value if the result has a proper index.
87         if (index)
88           predList.emplace_back(pos, builder.getIsNotNull());
89 
90         // Get the parent operation of this operand.
91         OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
92         predList.emplace_back(parentPos, builder.getIsNotNull());
93 
94         // Ensure that the operands match the corresponding results of the
95         // parent operation.
96         Position *resultPos = nullptr;
97         if (std::is_same<pdl::ResultOp, decltype(op)>::value)
98           resultPos = builder.getResult(parentPos, *index);
99         else
100           resultPos = builder.getResultGroup(parentPos, index, isVariadic);
101         predList.emplace_back(resultPos, builder.getEqualTo(pos));
102 
103         // Collect the predicates of the parent operation.
104         getTreePredicates(predList, op.parent(), builder, inputs,
105                           (Position *)parentPos);
106       });
107 }
108 
109 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
110                               Value val, PredicateBuilder &builder,
111                               DenseMap<Value, Position *> &inputs,
112                               OperationPosition *pos,
113                               Optional<unsigned> ignoreOperand = llvm::None) {
114   assert(val.getType().isa<pdl::OperationType>() && "expected operation");
115   pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
116   OperationPosition *opPos = cast<OperationPosition>(pos);
117 
118   // Ensure getDefiningOp returns a non-null operation.
119   if (!opPos->isRoot())
120     predList.emplace_back(pos, builder.getIsNotNull());
121 
122   // Check that this is the correct root operation.
123   if (Optional<StringRef> opName = op.name())
124     predList.emplace_back(pos, builder.getOperationName(*opName));
125 
126   // Check that the operation has the proper number of operands. If there are
127   // any variable length operands, we check a minimum instead of an exact count.
128   OperandRange operands = op.operands();
129   unsigned minOperands = getNumNonRangeValues(operands);
130   if (minOperands != operands.size()) {
131     if (minOperands)
132       predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
133   } else {
134     predList.emplace_back(pos, builder.getOperandCount(minOperands));
135   }
136 
137   // Check that the operation has the proper number of results. If there are
138   // any variable length results, we check a minimum instead of an exact count.
139   OperandRange types = op.types();
140   unsigned minResults = getNumNonRangeValues(types);
141   if (minResults == types.size())
142     predList.emplace_back(pos, builder.getResultCount(types.size()));
143   else if (minResults)
144     predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
145 
146   // Recurse into any attributes, operands, or results.
147   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
148     getTreePredicates(
149         predList, std::get<1>(it), builder, inputs,
150         builder.getAttribute(opPos,
151                              std::get<0>(it).cast<StringAttr>().getValue()));
152   }
153 
154   // Process the operands and results of the operation. For all values up to
155   // the first variable length value, we use the concrete operand/result
156   // number. After that, we use the "group" given that we can't know the
157   // concrete indices until runtime. If there is only one variadic operand
158   // group, we treat it as all of the operands/results of the operation.
159   /// Operands.
160   if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
161     getTreePredicates(predList, operands.front(), builder, inputs,
162                       builder.getAllOperands(opPos));
163   } else {
164     bool foundVariableLength = false;
165     for (auto operandIt : llvm::enumerate(operands)) {
166       bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
167       foundVariableLength |= isVariadic;
168 
169       // Ignore the specified operand, usually because this position was
170       // visited in an upward traversal via an iterative choice.
171       if (ignoreOperand && *ignoreOperand == operandIt.index())
172         continue;
173 
174       Position *pos =
175           foundVariableLength
176               ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
177               : builder.getOperand(opPos, operandIt.index());
178       getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
179     }
180   }
181   /// Results.
182   if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
183     getTreePredicates(predList, types.front(), builder, inputs,
184                       builder.getType(builder.getAllResults(opPos)));
185   } else {
186     bool foundVariableLength = false;
187     for (auto &resultIt : llvm::enumerate(types)) {
188       bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>();
189       foundVariableLength |= isVariadic;
190 
191       auto *resultPos =
192           foundVariableLength
193               ? builder.getResultGroup(pos, resultIt.index(), isVariadic)
194               : builder.getResult(pos, resultIt.index());
195       predList.emplace_back(resultPos, builder.getIsNotNull());
196       getTreePredicates(predList, resultIt.value(), builder, inputs,
197                         builder.getType(resultPos));
198     }
199   }
200 }
201 
202 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
203                               Value val, PredicateBuilder &builder,
204                               DenseMap<Value, Position *> &inputs,
205                               TypePosition *pos) {
206   // Check for a constraint on a constant type.
207   if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
208     if (Attribute type = typeOp.typeAttr())
209       predList.emplace_back(pos, builder.getTypeConstraint(type));
210   } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
211     if (Attribute typeAttr = typeOp.typesAttr())
212       predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
213   }
214 }
215 
216 /// Collect the tree predicates anchored at the given value.
217 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
218                               Value val, PredicateBuilder &builder,
219                               DenseMap<Value, Position *> &inputs,
220                               Position *pos) {
221   // Make sure this input value is accessible to the rewrite.
222   auto it = inputs.try_emplace(val, pos);
223   if (!it.second) {
224     // If this is an input value that has been visited in the tree, add a
225     // constraint to ensure that both instances refer to the same value.
226     if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
227             pdl::TypeOp>(val.getDefiningOp())) {
228       auto minMaxPositions =
229           std::minmax(pos, it.first->second, comparePosDepth);
230       predList.emplace_back(minMaxPositions.second,
231                             builder.getEqualTo(minMaxPositions.first));
232     }
233     return;
234   }
235 
236   TypeSwitch<Position *>(pos)
237       .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
238         getTreePredicates(predList, val, builder, inputs, pos);
239       })
240       .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
241         getOperandTreePredicates(predList, val, builder, inputs, pos);
242       })
243       .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
244 }
245 
246 /// Collect all of the predicates related to constraints within the given
247 /// pattern operation.
248 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
249                                     std::vector<PositionalPredicate> &predList,
250                                     PredicateBuilder &builder,
251                                     DenseMap<Value, Position *> &inputs) {
252   OperandRange arguments = op.args();
253   ArrayAttr parameters = op.constParamsAttr();
254 
255   std::vector<Position *> allPositions;
256   allPositions.reserve(arguments.size());
257   for (Value arg : arguments)
258     allPositions.push_back(inputs.lookup(arg));
259 
260   // Push the constraint to the furthest position.
261   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
262                                     comparePosDepth);
263   PredicateBuilder::Predicate pred =
264       builder.getConstraint(op.name(), std::move(allPositions), parameters);
265   predList.emplace_back(pos, pred);
266 }
267 
268 static void getResultPredicates(pdl::ResultOp op,
269                                 std::vector<PositionalPredicate> &predList,
270                                 PredicateBuilder &builder,
271                                 DenseMap<Value, Position *> &inputs) {
272   Position *&resultPos = inputs[op];
273   if (resultPos)
274     return;
275 
276   // Ensure that the result isn't null.
277   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
278   resultPos = builder.getResult(parentPos, op.index());
279   predList.emplace_back(resultPos, builder.getIsNotNull());
280 }
281 
282 static void getResultPredicates(pdl::ResultsOp op,
283                                 std::vector<PositionalPredicate> &predList,
284                                 PredicateBuilder &builder,
285                                 DenseMap<Value, Position *> &inputs) {
286   Position *&resultPos = inputs[op];
287   if (resultPos)
288     return;
289 
290   // Ensure that the result isn't null if the result has an index.
291   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
292   bool isVariadic = op.getType().isa<pdl::RangeType>();
293   Optional<unsigned> index = op.index();
294   resultPos = builder.getResultGroup(parentPos, index, isVariadic);
295   if (index)
296     predList.emplace_back(resultPos, builder.getIsNotNull());
297 }
298 
299 /// Collect all of the predicates that cannot be determined via walking the
300 /// tree.
301 static void getNonTreePredicates(pdl::PatternOp pattern,
302                                  std::vector<PositionalPredicate> &predList,
303                                  PredicateBuilder &builder,
304                                  DenseMap<Value, Position *> &inputs) {
305   for (Operation &op : pattern.body().getOps()) {
306     TypeSwitch<Operation *>(&op)
307         .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
308           getConstraintPredicates(constraintOp, predList, builder, inputs);
309         })
310         .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
311           getResultPredicates(resultOp, predList, builder, inputs);
312         });
313   }
314 }
315 
namespace {

/// An op accepting a value at an optional index.
///
/// The cost-graph traversal below stores one of these per visited value: the
/// value it was reached from (`parent`) and the index used on that step.
struct OpIndex {
  /// The value from which the described value was reached.
  Value parent;
  /// The operand (or result) index on that step. Empty when a single range
  /// stands for all of the operands, or when a results op has no index.
  Optional<unsigned> index;
};

/// The parent and operand index of each operation for each root, stored
/// as a nested map [root][operation].
using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;

} // namespace
329 
330 /// Given a pattern, determines the set of roots present in this pattern.
331 /// These are the operations whose results are not consumed by other operations.
332 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
333   // First, collect all the operations that are used as operands
334   // to other operations. These are not roots by default.
335   DenseSet<Value> used;
336   for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
337     for (Value operand : operationOp.operands())
338       TypeSwitch<Operation *>(operand.getDefiningOp())
339           .Case<pdl::ResultOp, pdl::ResultsOp>(
340               [&used](auto resultOp) { used.insert(resultOp.parent()); });
341   }
342 
343   // Remove the specified root from the use set, so that we can
344   // always select it as a root, even if it is used by other operations.
345   if (Value root = pattern.getRewriter().root())
346     used.erase(root);
347 
348   // Finally, collect all the unused operations.
349   SmallVector<Value> roots;
350   for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
351     if (!used.contains(operationOp))
352       roots.push_back(operationOp);
353 
354   return roots;
355 }
356 
357 /// Given a list of candidate roots, builds the cost graph for connecting them.
358 /// The graph is formed by traversing the DAG of operations starting from each
359 /// root and marking the depth of each connector value (operand). Then we join
360 /// the candidate roots based on the common connector values, taking the one
361 /// with the minimum depth. Along the way, we compute, for each candidate root,
362 /// a mapping from each operation (in the DAG underneath this root) to its
363 /// parent operation and the corresponding operand index.
364 static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
365                            ParentMaps &parentMaps) {
366 
367   // The entry of a queue. The entry consists of the following items:
368   // * the value in the DAG underneath the root;
369   // * the parent of the value;
370   // * the operand index of the value in its parent;
371   // * the depth of the visited value.
372   struct Entry {
373     Entry(Value value, Value parent, Optional<unsigned> index, unsigned depth)
374         : value(value), parent(parent), index(index), depth(depth) {}
375 
376     Value value;
377     Value parent;
378     Optional<unsigned> index;
379     unsigned depth;
380   };
381 
382   // A root of a value and its depth (distance from root to the value).
383   struct RootDepth {
384     Value root;
385     unsigned depth = 0;
386   };
387 
388   // Map from candidate connector values to their roots and depths. Using a
389   // small vector with 1 entry because most values belong to a single root.
390   llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
391 
392   // Perform a breadth-first traversal of the op DAG rooted at each root.
393   for (Value root : roots) {
394     // The queue of visited values. A value may be present multiple times in
395     // the queue, for multiple parents. We only accept the first occurrence,
396     // which is guaranteed to have the lowest depth.
397     std::queue<Entry> toVisit;
398     toVisit.emplace(root, Value(), 0, 0);
399 
400     // The map from value to its parent for the current root.
401     DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
402 
403     while (!toVisit.empty()) {
404       Entry entry = toVisit.front();
405       toVisit.pop();
406       // Skip if already visited.
407       if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
408         continue;
409 
410       // Mark the root and depth of the value.
411       connectorsRootsDepths[entry.value].push_back({root, entry.depth});
412 
413       // Traverse the operands of an operation and result ops.
414       // We intentionally do not traverse attributes and types, because those
415       // are expensive to join on.
416       TypeSwitch<Operation *>(entry.value.getDefiningOp())
417           .Case<pdl::OperationOp>([&](auto operationOp) {
418             OperandRange operands = operationOp.operands();
419             // Special case when we pass all the operands in one range.
420             // For those, the index is empty.
421             if (operands.size() == 1 &&
422                 operands[0].getType().isa<pdl::RangeType>()) {
423               toVisit.emplace(operands[0], entry.value, llvm::None,
424                               entry.depth + 1);
425               return;
426             }
427 
428             // Default case: visit all the operands.
429             for (auto p : llvm::enumerate(operationOp.operands()))
430               toVisit.emplace(p.value(), entry.value, p.index(),
431                               entry.depth + 1);
432           })
433           .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
434             toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(),
435                             entry.depth);
436           });
437     }
438   }
439 
440   // Now build the cost graph.
441   // This is simply a minimum over all depths for the target root.
442   unsigned nextID = 0;
443   for (const auto &connectorRootsDepths : connectorsRootsDepths) {
444     Value value = connectorRootsDepths.first;
445     ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
446     // If there is only one root for this value, this will not trigger
447     // any edges in the cost graph (a perf optimization).
448     if (rootsDepths.size() == 1)
449       continue;
450 
451     for (const RootDepth &p : rootsDepths) {
452       for (const RootDepth &q : rootsDepths) {
453         if (&p == &q)
454           continue;
455         // Insert or retrieve the property of edge from p to q.
456         RootOrderingCost &cost = graph[q.root][p.root];
457         if (!cost.connector /* new edge */ || cost.cost.first > q.depth) {
458           if (!cost.connector)
459             cost.cost.second = nextID++;
460           cost.cost.first = q.depth;
461           cost.connector = value;
462         }
463       }
464     }
465   }
466 
467   assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
468          "the pattern contains a candidate root disconnected from the others");
469 }
470 
471 /// Visit a node during upward traversal.
472 void visitUpward(std::vector<PositionalPredicate> &predList, OpIndex opIndex,
473                  PredicateBuilder &builder,
474                  DenseMap<Value, Position *> &valueToPosition, Position *&pos,
475                  bool &first) {
476   Value value = opIndex.parent;
477   TypeSwitch<Operation *>(value.getDefiningOp())
478       .Case<pdl::OperationOp>([&](auto operationOp) {
479         LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
480         OperationPosition *opPos = builder.getUsersOp(pos, opIndex.index);
481 
482         // Guard against traversing back to where we came from.
483         if (first) {
484           Position *parent = pos->getParent();
485           predList.emplace_back(opPos, builder.getNotEqualTo(parent));
486           first = false;
487         }
488 
489         // Guard against duplicate upward visits. These are not possible,
490         // because if this value was already visited, it would have been
491         // cheaper to start the traversal at this value rather than at the
492         // `connector`, violating the optimality of our spanning tree.
493         bool inserted = valueToPosition.try_emplace(value, opPos).second;
494         (void)inserted;
495         assert(inserted && "duplicate upward visit");
496 
497         // Obtain the tree predicates at the current value.
498         getTreePredicates(predList, value, builder, valueToPosition, opPos,
499                           opIndex.index);
500 
501         // Update the position
502         pos = opPos;
503       })
504       .Case<pdl::ResultOp>([&](auto resultOp) {
505         // Traverse up an individual result.
506         auto *opPos = dyn_cast<OperationPosition>(pos);
507         assert(opPos && "operations and results must be interleaved");
508         pos = builder.getResult(opPos, *opIndex.index);
509       })
510       .Case<pdl::ResultsOp>([&](auto resultOp) {
511         // Traverse up a group of results.
512         auto *opPos = dyn_cast<OperationPosition>(pos);
513         assert(opPos && "operations and results must be interleaved");
514         bool isVariadic = value.getType().isa<pdl::RangeType>();
515         if (opIndex.index)
516           pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
517         else
518           pos = builder.getAllResults(opPos);
519       });
520 }
521 
522 /// Given a pattern operation, build the set of matcher predicates necessary to
523 /// match this pattern.
524 static Value buildPredicateList(pdl::PatternOp pattern,
525                                 PredicateBuilder &builder,
526                                 std::vector<PositionalPredicate> &predList,
527                                 DenseMap<Value, Position *> &valueToPosition) {
528   SmallVector<Value> roots = detectRoots(pattern);
529 
530   // Build the root ordering graph and compute the parent maps.
531   RootOrderingGraph graph;
532   ParentMaps parentMaps;
533   buildCostGraph(roots, graph, parentMaps);
534   LLVM_DEBUG({
535     llvm::dbgs() << "Graph:\n";
536     for (auto &target : graph) {
537       llvm::dbgs() << "  * " << target.first << "\n";
538       for (auto &source : target.second) {
539         RootOrderingCost c = source.second;
540         llvm::dbgs() << "      <- " << source.first << ": " << c.cost.first
541                      << ":" << c.cost.second << " via " << c.connector.getLoc()
542                      << "\n";
543       }
544     }
545   });
546 
547   // Solve the optimal branching problem for each candidate root, or use the
548   // provided one.
549   Value bestRoot = pattern.getRewriter().root();
550   OptimalBranching::EdgeList bestEdges;
551   if (!bestRoot) {
552     unsigned bestCost = 0;
553     LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
554     for (Value root : roots) {
555       OptimalBranching solver(graph, root);
556       unsigned cost = solver.solve();
557       LLVM_DEBUG(llvm::dbgs() << "  * " << root << ": " << cost << "\n");
558       if (!bestRoot || bestCost > cost) {
559         bestCost = cost;
560         bestRoot = root;
561         bestEdges = solver.preOrderTraversal(roots);
562       }
563     }
564   } else {
565     OptimalBranching solver(graph, bestRoot);
566     solver.solve();
567     bestEdges = solver.preOrderTraversal(roots);
568   }
569 
570   LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
571   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << bestRoot << "\n");
572 
573   // The best root is the starting point for the traversal. Get the tree
574   // predicates for the DAG rooted at bestRoot.
575   getTreePredicates(predList, bestRoot, builder, valueToPosition,
576                     builder.getRoot());
577 
578   // Traverse the selected optimal branching. For all edges in order, traverse
579   // up starting from the connector, until the candidate root is reached, and
580   // call getTreePredicates at every node along the way.
581   for (const std::pair<Value, Value> &edge : bestEdges) {
582     Value target = edge.first;
583     Value source = edge.second;
584 
585     // Check if we already visited the target root. This happens in two cases:
586     // 1) the initial root (bestRoot);
587     // 2) a root that is dominated by (contained in the subtree rooted at) an
588     //    already visited root.
589     if (valueToPosition.count(target))
590       continue;
591 
592     // Determine the connector.
593     Value connector = graph[target][source].connector;
594     assert(connector && "invalid edge");
595     LLVM_DEBUG(llvm::dbgs() << "  * Connector: " << connector.getLoc() << "\n");
596     DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
597     Position *pos = valueToPosition.lookup(connector);
598     assert(pos && "The value has not been traversed yet");
599     bool first = true;
600 
601     // Traverse from the connector upwards towards the target root.
602     for (Value value = connector; value != target;) {
603       OpIndex opIndex = parentMap.lookup(value);
604       assert(opIndex.parent && "missing parent");
605       visitUpward(predList, opIndex, builder, valueToPosition, pos, first);
606       value = opIndex.parent;
607     }
608   }
609 
610   getNonTreePredicates(pattern, predList, builder, valueToPosition);
611 
612   return bestRoot;
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // Pattern Predicate Tree Merging
617 //===----------------------------------------------------------------------===//
618 
619 namespace {
620 
621 /// This class represents a specific predicate applied to a position, and
622 /// provides hashing and ordering operators. This class allows for computing a
623 /// frequence sum and ordering predicates based on a cost model.
624 struct OrderedPredicate {
625   OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
626       : position(ip.first), question(ip.second) {}
627   OrderedPredicate(const PositionalPredicate &ip)
628       : position(ip.position), question(ip.question) {}
629 
630   /// The position this predicate is applied to.
631   Position *position;
632 
633   /// The question that is applied by this predicate onto the position.
634   Qualifier *question;
635 
636   /// The first and second order benefit sums.
637   /// The primary sum is the number of occurrences of this predicate among all
638   /// of the patterns.
639   unsigned primary = 0;
640   /// The secondary sum is a squared summation of the primary sum of all of the
641   /// predicates within each pattern that contains this predicate. This allows
642   /// for favoring predicates that are more commonly shared within a pattern, as
643   /// opposed to those shared across patterns.
644   unsigned secondary = 0;
645 
646   /// A map between a pattern operation and the answer to the predicate question
647   /// within that pattern.
648   DenseMap<Operation *, Qualifier *> patternToAnswer;
649 
650   /// Returns true if this predicate is ordered before `rhs`, based on the cost
651   /// model.
652   bool operator<(const OrderedPredicate &rhs) const {
653     // Sort by:
654     // * higher first and secondary order sums
655     // * lower depth
656     // * lower position dependency
657     // * lower predicate dependency
658     auto *rhsPos = rhs.position;
659     return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
660                            rhsPos->getKind(), rhs.question->getKind()) >
661            std::make_tuple(rhs.primary, rhs.secondary,
662                            position->getOperationDepth(), position->getKind(),
663                            question->getKind());
664   }
665 };
666 
667 /// A DenseMapInfo for OrderedPredicate based solely on the position and
668 /// question.
669 struct OrderedPredicateDenseInfo {
670   using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
671 
672   static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
673   static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
674   static bool isEqual(const OrderedPredicate &lhs,
675                       const OrderedPredicate &rhs) {
676     return lhs.position == rhs.position && lhs.question == rhs.question;
677   }
678   static unsigned getHashValue(const OrderedPredicate &p) {
679     return llvm::hash_combine(p.position, p.question);
680   }
681 };
682 
683 /// This class wraps a set of ordered predicates that are used within a specific
684 /// pattern operation.
685 struct OrderedPredicateList {
686   OrderedPredicateList(pdl::PatternOp pattern, Value root)
687       : pattern(pattern), root(root) {}
688 
689   pdl::PatternOp pattern;
690   Value root;
691   DenseSet<OrderedPredicate *> predicates;
692 };
693 } // namespace
694 
695 /// Returns true if the given matcher refers to the same predicate as the given
696 /// ordered predicate. This means that the position and questions of the two
697 /// match.
698 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
699   return node->getPosition() == predicate->position &&
700          node->getQuestion() == predicate->question;
701 }
702 
703 /// Get or insert a child matcher for the given parent switch node, given a
704 /// predicate and parent pattern.
705 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
706                                                OrderedPredicate *predicate,
707                                                pdl::PatternOp pattern) {
708   assert(isSamePredicate(node, predicate) &&
709          "expected matcher to equal the given predicate");
710 
711   auto it = predicate->patternToAnswer.find(pattern);
712   assert(it != predicate->patternToAnswer.end() &&
713          "expected pattern to exist in predicate");
714   return node->getChildren().insert({it->second, nullptr}).first->second;
715 }
716 
717 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
718 /// order. A pattern will traverse as far as possible using common predicates
719 /// and then either diverge from the CFG or reach the end of a branch and start
720 /// creating new nodes.
721 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
722                              OrderedPredicateList &list,
723                              std::vector<OrderedPredicate *>::iterator current,
724                              std::vector<OrderedPredicate *>::iterator end) {
725   if (current == end) {
726     // We've hit the end of a pattern, so create a successful result node.
727     node =
728         std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
729 
730     // If the pattern doesn't contain this predicate, ignore it.
731   } else if (list.predicates.find(*current) == list.predicates.end()) {
732     propagatePattern(node, list, std::next(current), end);
733 
734     // If the current matcher node is invalid, create a new one for this
735     // position and continue propagation.
736   } else if (!node) {
737     // Create a new node at this position and continue
738     node = std::make_unique<SwitchNode>((*current)->position,
739                                         (*current)->question);
740     propagatePattern(
741         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
742         list, std::next(current), end);
743 
744     // If the matcher has already been created, and it is for this predicate we
745     // continue propagation to the child.
746   } else if (isSamePredicate(node.get(), *current)) {
747     propagatePattern(
748         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
749         list, std::next(current), end);
750 
751     // If the matcher doesn't match the current predicate, insert a branch as
752     // the common set of matchers has diverged.
753   } else {
754     propagatePattern(node->getFailureNode(), list, current, end);
755   }
756 }
757 
758 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
759 /// `node` is updated in-place if it is a switch.
760 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
761   if (!node)
762     return;
763 
764   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
765     SwitchNode::ChildMapT &children = switchNode->getChildren();
766     for (auto &it : children)
767       foldSwitchToBool(it.second);
768 
769     // If the node only contains one child, collapse it into a boolean predicate
770     // node.
771     if (children.size() == 1) {
772       auto childIt = children.begin();
773       node = std::make_unique<BoolNode>(
774           node->getPosition(), node->getQuestion(), childIt->first,
775           std::move(childIt->second), std::move(node->getFailureNode()));
776     }
777   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
778     foldSwitchToBool(boolNode->getSuccessNode());
779   }
780 
781   foldSwitchToBool(node->getFailureNode());
782 }
783 
784 /// Insert an exit node at the end of the failure path of the `root`.
785 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
786   while (*root)
787     root = &(*root)->getFailureNode();
788   *root = std::make_unique<ExitNode>();
789 }
790 
791 /// Given a module containing PDL pattern operations, generate a matcher tree
792 /// using the patterns within the given module and return the root matcher node.
793 std::unique_ptr<MatcherNode>
794 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
795                                  DenseMap<Value, Position *> &valueToPosition) {
796   // The set of predicates contained within the pattern operations of the
797   // module.
798   struct PatternPredicates {
799     PatternPredicates(pdl::PatternOp pattern, Value root,
800                       std::vector<PositionalPredicate> predicates)
801         : pattern(pattern), root(root), predicates(std::move(predicates)) {}
802 
803     /// A pattern.
804     pdl::PatternOp pattern;
805 
806     /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
807     Value root;
808 
809     /// The extracted predicates for this pattern and root.
810     std::vector<PositionalPredicate> predicates;
811   };
812 
813   SmallVector<PatternPredicates, 16> patternsAndPredicates;
814   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
815     std::vector<PositionalPredicate> predicateList;
816     Value root =
817         buildPredicateList(pattern, builder, predicateList, valueToPosition);
818     patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
819   }
820 
821   // Associate a pattern result with each unique predicate.
822   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
823   for (auto &patternAndPredList : patternsAndPredicates) {
824     for (auto &predicate : patternAndPredList.predicates) {
825       auto it = uniqued.insert(predicate);
826       it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
827                                             predicate.answer);
828     }
829   }
830 
831   // Associate each pattern to a set of its ordered predicates for later lookup.
832   std::vector<OrderedPredicateList> lists;
833   lists.reserve(patternsAndPredicates.size());
834   for (auto &patternAndPredList : patternsAndPredicates) {
835     OrderedPredicateList list(patternAndPredList.pattern,
836                               patternAndPredList.root);
837     for (auto &predicate : patternAndPredList.predicates) {
838       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
839       list.predicates.insert(orderedPredicate);
840 
841       // Increment the primary sum for each reference to a particular predicate.
842       ++orderedPredicate->primary;
843     }
844     lists.push_back(std::move(list));
845   }
846 
847   // For a particular pattern, get the total primary sum and add it to the
848   // secondary sum of each predicate. Square the primary sums to emphasize
849   // shared predicates within rather than across patterns.
850   for (auto &list : lists) {
851     unsigned total = 0;
852     for (auto *predicate : list.predicates)
853       total += predicate->primary * predicate->primary;
854     for (auto *predicate : list.predicates)
855       predicate->secondary += total;
856   }
857 
858   // Sort the set of predicates now that the cost primary and secondary sums
859   // have been computed.
860   std::vector<OrderedPredicate *> ordered;
861   ordered.reserve(uniqued.size());
862   for (auto &ip : uniqued)
863     ordered.push_back(&ip);
864   std::stable_sort(
865       ordered.begin(), ordered.end(),
866       [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
867 
868   // Build the matchers for each of the pattern predicate lists.
869   std::unique_ptr<MatcherNode> root;
870   for (OrderedPredicateList &list : lists)
871     propagatePattern(root, list, ordered.begin(), ordered.end());
872 
873   // Collapse the graph and insert the exit node.
874   foldSwitchToBool(root);
875   insertExitNode(&root);
876   return root;
877 }
878 
879 //===----------------------------------------------------------------------===//
880 // MatcherNode
881 //===----------------------------------------------------------------------===//
882 
883 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
884                          std::unique_ptr<MatcherNode> failureNode)
885     : position(p), question(q), failureNode(std::move(failureNode)),
886       matcherTypeID(matcherTypeID) {}
887 
888 //===----------------------------------------------------------------------===//
889 // BoolNode
890 //===----------------------------------------------------------------------===//
891 
892 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
893                    std::unique_ptr<MatcherNode> successNode,
894                    std::unique_ptr<MatcherNode> failureNode)
895     : MatcherNode(TypeID::get<BoolNode>(), position, question,
896                   std::move(failureNode)),
897       answer(answer), successNode(std::move(successNode)) {}
898 
899 //===----------------------------------------------------------------------===//
900 // SuccessNode
901 //===----------------------------------------------------------------------===//
902 
903 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
904                          std::unique_ptr<MatcherNode> failureNode)
905     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
906                   /*question=*/nullptr, std::move(failureNode)),
907       pattern(pattern), root(root) {}
908 
909 //===----------------------------------------------------------------------===//
910 // SwitchNode
911 //===----------------------------------------------------------------------===//
912 
913 SwitchNode::SwitchNode(Position *position, Qualifier *question)
914     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
915