xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp (revision a76ee58f3cbcec6e31ff0d25e7d9a89b81a2ccc8)
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11 
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Interfaces/InferTypeOpInterface.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
20 #include <queue>
21 
22 #define DEBUG_TYPE "pdl-predicate-tree"
23 
24 using namespace mlir;
25 using namespace mlir::pdl_to_pdl_interp;
26 
27 //===----------------------------------------------------------------------===//
28 // Predicate List Building
29 //===----------------------------------------------------------------------===//
30 
/// Collect the tree predicates anchored at `val`, whose position is `pos`.
/// This generic overload is declared up front so that the position-specific
/// overloads defined below can recurse through it before its definition.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              Position *pos);
35 
36 /// Compares the depths of two positions.
37 static bool comparePosDepth(Position *lhs, Position *rhs) {
38   return lhs->getOperationDepth() < rhs->getOperationDepth();
39 }
40 
41 /// Returns the number of non-range elements within `values`.
42 static unsigned getNumNonRangeValues(ValueRange values) {
43   return llvm::count_if(values.getTypes(),
44                         [](Type type) { return !type.isa<pdl::RangeType>(); });
45 }
46 
47 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
48                               Value val, PredicateBuilder &builder,
49                               DenseMap<Value, Position *> &inputs,
50                               AttributePosition *pos) {
51   assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
52   pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
53   predList.emplace_back(pos, builder.getIsNotNull());
54 
55   // If the attribute has a type or value, add a constraint.
56   if (Value type = attr.type())
57     getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
58   else if (Attribute value = attr.valueAttr())
59     predList.emplace_back(pos, builder.getAttributeConstraint(value));
60 }
61 
/// Collect all of the predicates for the given operand position.
///
/// `val` is the operand value and `pos` its operand (group) position. Values
/// defined by pdl.operand / pdl.operands may add a non-null check and a type
/// constraint. Values defined by pdl.result / pdl.results link the operand
/// back to the result of its defining operation, whose predicates are then
/// collected as well. Values defined by any other op add nothing here.
static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
                                     Value val, PredicateBuilder &builder,
                                     DenseMap<Value, Position *> &inputs,
                                     Position *pos) {
  Type valueType = val.getType();
  bool isVariadic = valueType.isa<pdl::RangeType>();

  // Dispatch on the op defining the operand value.
  TypeSwitch<Operation *>(val.getDefiningOp())
      .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
        // Prevent traversal into a null value if the operand has a proper
        // index. For pdl.operands, `pos` is expected to be an
        // OperandGroupPosition (the `cast` asserts this); when its group
        // number is empty (presumably the "all operands" group) no non-null
        // check is emitted.
        if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
            cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
          predList.emplace_back(pos, builder.getIsNotNull());

        // If this is a typed operand, add a type constraint.
        if (Value type = op.type())
          getTreePredicates(predList, type, builder, inputs,
                            builder.getType(pos));
      })
      .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
        Optional<unsigned> index = op.index();

        // Prevent traversal into a null value if the result has a proper index.
        if (index)
          predList.emplace_back(pos, builder.getIsNotNull());

        // Get the parent operation of this operand.
        OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
        predList.emplace_back(parentPos, builder.getIsNotNull());

        // Ensure that the operands match the corresponding results of the
        // parent operation.
        Position *resultPos = nullptr;
        if (std::is_same<pdl::ResultOp, decltype(op)>::value)
          resultPos = builder.getResult(parentPos, *index);
        else
          resultPos = builder.getResultGroup(parentPos, index, isVariadic);
        predList.emplace_back(resultPos, builder.getEqualTo(pos));

        // Collect the predicates of the parent operation. The cast selects the
        // generic Position* overload, which records the parent in `inputs`
        // before dispatching on the position kind.
        getTreePredicates(predList, op.parent(), builder, inputs,
                          (Position *)parentPos);
      });
}
108 
/// Collect the predicates for the operation `val` (defined by a
/// pdl.operation) at position `pos`. These are:
/// - a non-null check for non-root positions;
/// - the operation name, if one is given;
/// - the operand and result counts;
/// - then, recursively, the predicates of its attributes, operands and
///   result types.
///
/// `ignoreOperand`, when set, names an operand index to skip. It is used when
/// this position was reached by an upward traversal through that operand.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              OperationPosition *pos,
                              Optional<unsigned> ignoreOperand = llvm::None) {
  assert(val.getType().isa<pdl::OperationType>() && "expected operation");
  pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
  OperationPosition *opPos = cast<OperationPosition>(pos);

  // Ensure getDefiningOp returns a non-null operation.
  if (!opPos->isRoot())
    predList.emplace_back(pos, builder.getIsNotNull());

  // Check that the operation has the expected name, if one is specified.
  if (Optional<StringRef> opName = op.name())
    predList.emplace_back(pos, builder.getOperationName(*opName));

  // Check that the operation has the proper number of operands. If there are
  // any variable length operands, we check a minimum instead of an exact count.
  // If every operand is variable length, no operand-count check is emitted.
  OperandRange operands = op.operands();
  unsigned minOperands = getNumNonRangeValues(operands);
  if (minOperands != operands.size()) {
    if (minOperands)
      predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
  } else {
    predList.emplace_back(pos, builder.getOperandCount(minOperands));
  }

  // Check that the operation has the proper number of results. If there are
  // any variable length results, we check a minimum instead of an exact count.
  // If every result is variable length, no result-count check is emitted.
  OperandRange types = op.types();
  unsigned minResults = getNumNonRangeValues(types);
  if (minResults == types.size())
    predList.emplace_back(pos, builder.getResultCount(types.size()));
  else if (minResults)
    predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));

  // Recurse into any attributes, operands, or results.
  for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
    getTreePredicates(
        predList, std::get<1>(it), builder, inputs,
        builder.getAttribute(opPos,
                             std::get<0>(it).cast<StringAttr>().getValue()));
  }

  // Process the operands and results of the operation. For all values up to
  // the first variable length value, we use the concrete operand/result
  // number. After that, we use the "group" given that we can't know the
  // concrete indices until runtime. If there is only one variadic operand
  // group, we treat it as all of the operands/results of the operation.
  // Operands.
  // NOTE(review): `ignoreOperand` is not consulted in the single-range branch
  // below. buildCostGraph records an empty index for that case, so an upward
  // traversal through the range re-visits it here. Confirm this is intended.
  if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
    getTreePredicates(predList, operands.front(), builder, inputs,
                      builder.getAllOperands(opPos));
  } else {
    bool foundVariableLength = false;
    for (auto operandIt : llvm::enumerate(operands)) {
      bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
      foundVariableLength |= isVariadic;

      // Ignore the specified operand, usually because this position was
      // visited in an upward traversal via an iterative choice.
      if (ignoreOperand && *ignoreOperand == operandIt.index())
        continue;

      // Note: this local `pos` shadows the parameter of the same name.
      Position *pos =
          foundVariableLength
              ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
              : builder.getOperand(opPos, operandIt.index());
      getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
    }
  }
  // Results.
  if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
    getTreePredicates(predList, types.front(), builder, inputs,
                      builder.getType(builder.getAllResults(opPos)));
  } else {
    bool foundVariableLength = false;
    for (auto &resultIt : llvm::enumerate(types)) {
      bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>();
      foundVariableLength |= isVariadic;

      auto *resultPos =
          foundVariableLength
              ? builder.getResultGroup(pos, resultIt.index(), isVariadic)
              : builder.getResult(pos, resultIt.index());
      predList.emplace_back(resultPos, builder.getIsNotNull());
      getTreePredicates(predList, resultIt.value(), builder, inputs,
                        builder.getType(resultPos));
    }
  }
}
201 
202 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
203                               Value val, PredicateBuilder &builder,
204                               DenseMap<Value, Position *> &inputs,
205                               TypePosition *pos) {
206   // Check for a constraint on a constant type.
207   if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
208     if (Attribute type = typeOp.typeAttr())
209       predList.emplace_back(pos, builder.getTypeConstraint(type));
210   } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
211     if (Attribute typeAttr = typeOp.typesAttr())
212       predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
213   }
214 }
215 
/// Collect the tree predicates anchored at the given value.
///
/// On the first visit, `pos` is recorded as the position of `val` in
/// `inputs`, and the call dispatches to the overload for the kind of `pos`.
///
/// On a repeat visit, no traversal happens. If `val` is defined by
/// pdl.attribute, pdl.operand, pdl.operands, pdl.operation or pdl.type, an
/// equality predicate is added. It ties the deeper of the two positions to
/// the shallower one.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              Position *pos) {
  // Make sure this input value is accessible to the rewrite.
  auto it = inputs.try_emplace(val, pos);
  if (!it.second) {
    // If this is an input value that has been visited in the tree, add a
    // constraint to ensure that both instances refer to the same value.
    if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
            pdl::TypeOp>(val.getDefiningOp())) {
      // `first` is the shallower position, `second` the deeper one; the
      // predicate is attached to the deeper position.
      auto minMaxPositions =
          std::minmax(pos, it.first->second, comparePosDepth);
      predList.emplace_back(minMaxPositions.second,
                            builder.getEqualTo(minMaxPositions.first));
    }
    return;
  }

  // Dispatch to the overload matching the concrete position kind.
  TypeSwitch<Position *>(pos)
      .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
        getTreePredicates(predList, val, builder, inputs, pos);
      })
      .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
        getOperandTreePredicates(predList, val, builder, inputs, pos);
      })
      .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
}
245 
246 /// Collect all of the predicates related to constraints within the given
247 /// pattern operation.
248 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
249                                     std::vector<PositionalPredicate> &predList,
250                                     PredicateBuilder &builder,
251                                     DenseMap<Value, Position *> &inputs) {
252   OperandRange arguments = op.args();
253   ArrayAttr parameters = op.constParamsAttr();
254 
255   std::vector<Position *> allPositions;
256   allPositions.reserve(arguments.size());
257   for (Value arg : arguments)
258     allPositions.push_back(inputs.lookup(arg));
259 
260   // Push the constraint to the furthest position.
261   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
262                                     comparePosDepth);
263   PredicateBuilder::Predicate pred =
264       builder.getConstraint(op.name(), std::move(allPositions), parameters);
265   predList.emplace_back(pos, pred);
266 }
267 
268 static void getResultPredicates(pdl::ResultOp op,
269                                 std::vector<PositionalPredicate> &predList,
270                                 PredicateBuilder &builder,
271                                 DenseMap<Value, Position *> &inputs) {
272   Position *&resultPos = inputs[op];
273   if (resultPos)
274     return;
275 
276   // Ensure that the result isn't null.
277   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
278   resultPos = builder.getResult(parentPos, op.index());
279   predList.emplace_back(resultPos, builder.getIsNotNull());
280 }
281 
282 static void getResultPredicates(pdl::ResultsOp op,
283                                 std::vector<PositionalPredicate> &predList,
284                                 PredicateBuilder &builder,
285                                 DenseMap<Value, Position *> &inputs) {
286   Position *&resultPos = inputs[op];
287   if (resultPos)
288     return;
289 
290   // Ensure that the result isn't null if the result has an index.
291   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
292   bool isVariadic = op.getType().isa<pdl::RangeType>();
293   Optional<unsigned> index = op.index();
294   resultPos = builder.getResultGroup(parentPos, index, isVariadic);
295   if (index)
296     predList.emplace_back(resultPos, builder.getIsNotNull());
297 }
298 
299 /// Collect all of the predicates that cannot be determined via walking the
300 /// tree.
301 static void getNonTreePredicates(pdl::PatternOp pattern,
302                                  std::vector<PositionalPredicate> &predList,
303                                  PredicateBuilder &builder,
304                                  DenseMap<Value, Position *> &inputs) {
305   for (Operation &op : pattern.body().getOps()) {
306     TypeSwitch<Operation *>(&op)
307         .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
308           getConstraintPredicates(constraintOp, predList, builder, inputs);
309         })
310         .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
311           getResultPredicates(resultOp, predList, builder, inputs);
312         });
313   }
314 }
315 
namespace {

/// An op accepting a value at an optional index.
/// In buildCostGraph:
/// - for an operand, `parent` is the consuming operation and `index` is the
///   operand number (empty when the operation's only operand is a range);
/// - for an operation reached through a pdl.result(s), `parent` is that
///   result value and `index` is the result index.
struct OpIndex {
  Value parent;
  Optional<unsigned> index;
};

/// The parent and operand index of each operation for each root, stored
/// as a nested map [root][operation].
using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;

} // namespace
329 
330 /// Given a pattern, determines the set of roots present in this pattern.
331 /// These are the operations whose results are not consumed by other operations.
332 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
333   // First, collect all the operations that are used as operands
334   // to other operations. These are not roots by default.
335   DenseSet<Value> used;
336   for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
337     for (Value operand : operationOp.operands())
338       TypeSwitch<Operation *>(operand.getDefiningOp())
339           .Case<pdl::ResultOp, pdl::ResultsOp>(
340               [&used](auto resultOp) { used.insert(resultOp.parent()); });
341   }
342 
343   // Remove the specified root from the use set, so that we can
344   // always select it as a root, even if it is used by other operations.
345   if (Value root = pattern.getRewriter().root())
346     used.erase(root);
347 
348   // Finally, collect all the unused operations.
349   SmallVector<Value> roots;
350   for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
351     if (!used.contains(operationOp))
352       roots.push_back(operationOp);
353 
354   return roots;
355 }
356 
/// Given a list of candidate roots, builds the cost graph for connecting them.
/// The graph is formed by traversing the DAG of operations starting from each
/// root and marking the depth of each connector value (operand). Then we join
/// the candidate roots based on the common connector values, taking the one
/// with the minimum depth. Along the way, we compute, for each candidate root,
/// a mapping from each operation (in the DAG underneath this root) to its
/// parent operation and the corresponding operand index.
///
/// `graph` is indexed as graph[target][source]. The cost of an edge is the
/// depth of the connector below the target root, paired with an insertion
/// ID that makes ties deterministic.
static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
                           ParentMaps &parentMaps) {

  // The entry of a queue. The entry consists of the following items:
  // * the value in the DAG underneath the root;
  // * the parent of the value;
  // * the operand index of the value in its parent (for an operation reached
  //   through a pdl.result(s), the result index instead);
  // * the depth of the visited value.
  struct Entry {
    Entry(Value value, Value parent, Optional<unsigned> index, unsigned depth)
        : value(value), parent(parent), index(index), depth(depth) {}

    Value value;
    Value parent;
    Optional<unsigned> index;
    unsigned depth;
  };

  // A root of a value and its depth (distance from root to the value).
  struct RootDepth {
    Value root;
    unsigned depth = 0;
  };

  // Map from candidate connector values to their roots and depths. Using a
  // small vector with 1 entry because most values belong to a single root.
  // A MapVector keeps insertion order, so the edge IDs below are deterministic.
  llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;

  // Perform a breadth-first traversal of the op DAG rooted at each root.
  for (Value root : roots) {
    // The queue of visited values. A value may be present multiple times in
    // the queue, for multiple parents. We only accept the first occurrence,
    // which is guaranteed to have the lowest depth.
    std::queue<Entry> toVisit;
    toVisit.emplace(root, Value(), 0, 0);

    // The map from value to its parent for the current root.
    DenseMap<Value, OpIndex> &parentMap = parentMaps[root];

    while (!toVisit.empty()) {
      Entry entry = toVisit.front();
      toVisit.pop();
      // Skip if already visited.
      if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
        continue;

      // Mark the root and depth of the value.
      connectorsRootsDepths[entry.value].push_back({root, entry.depth});

      // Traverse the operands of an operation and result ops.
      // We intentionally do not traverse attributes and types, because those
      // are expensive to join on.
      TypeSwitch<Operation *>(entry.value.getDefiningOp())
          .Case<pdl::OperationOp>([&](auto operationOp) {
            OperandRange operands = operationOp.operands();
            // Special case when we pass all the operands in one range.
            // For those, the index is empty.
            if (operands.size() == 1 &&
                operands[0].getType().isa<pdl::RangeType>()) {
              toVisit.emplace(operands[0], entry.value, llvm::None,
                              entry.depth + 1);
              return;
            }

            // Default case: visit all the operands.
            for (auto p : llvm::enumerate(operationOp.operands()))
              toVisit.emplace(p.value(), entry.value, p.index(),
                              entry.depth + 1);
          })
          .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
            // Stepping from a result to its operation does not increase the
            // depth; only operand edges do.
            toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(),
                            entry.depth);
          });
    }
  }

  // Now build the cost graph.
  // This is simply a minimum over all depths for the target root.
  unsigned nextID = 0;
  for (const auto &connectorRootsDepths : connectorsRootsDepths) {
    Value value = connectorRootsDepths.first;
    ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
    // If there is only one root for this value, this will not trigger
    // any edges in the cost graph (a perf optimization).
    if (rootsDepths.size() == 1)
      continue;

    for (const RootDepth &p : rootsDepths) {
      for (const RootDepth &q : rootsDepths) {
        if (&p == &q)
          continue;
        // Insert or retrieve the property of edge from p to q.
        RootOrderingCost &cost = graph[q.root][p.root];
        if (!cost.connector /* new edge */ || cost.cost.first > q.depth) {
          // The ID is assigned once, when the edge is first created.
          if (!cost.connector)
            cost.cost.second = nextID++;
          cost.cost.first = q.depth;
          cost.connector = value;
        }
      }
    }
  }

  assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
         "the pattern contains a candidate root disconnected from the others");
}
470 
471 /// Visit a node during upward traversal.
472 void visitUpward(std::vector<PositionalPredicate> &predList, OpIndex opIndex,
473                  PredicateBuilder &builder,
474                  DenseMap<Value, Position *> &valueToPosition, Position *&pos,
475                  bool &first) {
476   Value value = opIndex.parent;
477   TypeSwitch<Operation *>(value.getDefiningOp())
478       .Case<pdl::OperationOp>([&](auto operationOp) {
479         LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
480         OperationPosition *opPos = builder.getUsersOp(pos, opIndex.index);
481 
482         // Guard against traversing back to where we came from.
483         if (first) {
484           Position *parent = pos->getParent();
485           predList.emplace_back(opPos, builder.getNotEqualTo(parent));
486           first = false;
487         }
488 
489         // Guard against duplicate upward visits. These are not possible,
490         // because if this value was already visited, it would have been
491         // cheaper to start the traversal at this value rather than at the
492         // `connector`, violating the optimality of our spanning tree.
493         bool inserted = valueToPosition.try_emplace(value, opPos).second;
494         assert(inserted && "duplicate upward visit");
495 
496         // Obtain the tree predicates at the current value.
497         getTreePredicates(predList, value, builder, valueToPosition, opPos,
498                           opIndex.index);
499 
500         // Update the position
501         pos = opPos;
502       })
503       .Case<pdl::ResultOp>([&](auto resultOp) {
504         // Traverse up an individual result.
505         auto *opPos = dyn_cast<OperationPosition>(pos);
506         assert(opPos && "operations and results must be interleaved");
507         pos = builder.getResult(opPos, *opIndex.index);
508       })
509       .Case<pdl::ResultsOp>([&](auto resultOp) {
510         // Traverse up a group of results.
511         auto *opPos = dyn_cast<OperationPosition>(pos);
512         assert(opPos && "operations and results must be interleaved");
513         bool isVariadic = value.getType().isa<pdl::RangeType>();
514         if (opIndex.index)
515           pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
516         else
517           pos = builder.getAllResults(opPos);
518       });
519 }
520 
521 /// Given a pattern operation, build the set of matcher predicates necessary to
522 /// match this pattern.
523 static Value buildPredicateList(pdl::PatternOp pattern,
524                                 PredicateBuilder &builder,
525                                 std::vector<PositionalPredicate> &predList,
526                                 DenseMap<Value, Position *> &valueToPosition) {
527   SmallVector<Value> roots = detectRoots(pattern);
528 
529   // Build the root ordering graph and compute the parent maps.
530   RootOrderingGraph graph;
531   ParentMaps parentMaps;
532   buildCostGraph(roots, graph, parentMaps);
533   LLVM_DEBUG({
534     llvm::dbgs() << "Graph:\n";
535     for (auto &target : graph) {
536       llvm::dbgs() << "  * " << target.first << "\n";
537       for (auto &source : target.second) {
538         RootOrderingCost c = source.second;
539         llvm::dbgs() << "      <- " << source.first << ": " << c.cost.first
540                      << ":" << c.cost.second << " via " << c.connector.getLoc()
541                      << "\n";
542       }
543     }
544   });
545 
546   // Solve the optimal branching problem for each candidate root, or use the
547   // provided one.
548   Value bestRoot = pattern.getRewriter().root();
549   OptimalBranching::EdgeList bestEdges;
550   if (!bestRoot) {
551     unsigned bestCost = 0;
552     LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
553     for (Value root : roots) {
554       OptimalBranching solver(graph, root);
555       unsigned cost = solver.solve();
556       LLVM_DEBUG(llvm::dbgs() << "  * " << root << ": " << cost << "\n");
557       if (!bestRoot || bestCost > cost) {
558         bestCost = cost;
559         bestRoot = root;
560         bestEdges = solver.preOrderTraversal(roots);
561       }
562     }
563   } else {
564     OptimalBranching solver(graph, bestRoot);
565     solver.solve();
566     bestEdges = solver.preOrderTraversal(roots);
567   }
568 
569   LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
570   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << bestRoot << "\n");
571 
572   // The best root is the starting point for the traversal. Get the tree
573   // predicates for the DAG rooted at bestRoot.
574   getTreePredicates(predList, bestRoot, builder, valueToPosition,
575                     builder.getRoot());
576 
577   // Traverse the selected optimal branching. For all edges in order, traverse
578   // up starting from the connector, until the candidate root is reached, and
579   // call getTreePredicates at every node along the way.
580   for (const std::pair<Value, Value> &edge : bestEdges) {
581     Value target = edge.first;
582     Value source = edge.second;
583 
584     // Check if we already visited the target root. This happens in two cases:
585     // 1) the initial root (bestRoot);
586     // 2) a root that is dominated by (contained in the subtree rooted at) an
587     //    already visited root.
588     if (valueToPosition.count(target))
589       continue;
590 
591     // Determine the connector.
592     Value connector = graph[target][source].connector;
593     assert(connector && "invalid edge");
594     LLVM_DEBUG(llvm::dbgs() << "  * Connector: " << connector.getLoc() << "\n");
595     DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
596     Position *pos = valueToPosition.lookup(connector);
597     assert(pos && "The value has not been traversed yet");
598     bool first = true;
599 
600     // Traverse from the connector upwards towards the target root.
601     for (Value value = connector; value != target;) {
602       OpIndex opIndex = parentMap.lookup(value);
603       assert(opIndex.parent && "missing parent");
604       visitUpward(predList, opIndex, builder, valueToPosition, pos, first);
605       value = opIndex.parent;
606     }
607   }
608 
609   getNonTreePredicates(pattern, predList, builder, valueToPosition);
610 
611   return bestRoot;
612 }
613 
614 //===----------------------------------------------------------------------===//
615 // Pattern Predicate Tree Merging
616 //===----------------------------------------------------------------------===//
617 
namespace {

/// This class represents a specific predicate applied to a position, and
/// provides hashing and ordering operators. This class allows for computing a
/// frequency sum and ordering predicates based on a cost model.
struct OrderedPredicate {
  // These constructors must stay implicit: OrderedPredicateDenseInfo returns
  // std::pair keys from functions declared to return OrderedPredicate.
  OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
      : position(ip.first), question(ip.second) {}
  OrderedPredicate(const PositionalPredicate &ip)
      : position(ip.position), question(ip.question) {}

  /// The position this predicate is applied to.
  Position *position;

  /// The question that is applied by this predicate onto the position.
  Qualifier *question;

  /// The first and second order benefit sums.
  /// The primary sum is the number of occurrences of this predicate among all
  /// of the patterns.
  unsigned primary = 0;
  /// The secondary sum is a squared summation of the primary sum of all of the
  /// predicates within each pattern that contains this predicate. This allows
  /// for favoring predicates that are more commonly shared within a pattern, as
  /// opposed to those shared across patterns.
  unsigned secondary = 0;

  /// A map between a pattern operation and the answer to the predicate question
  /// within that pattern.
  DenseMap<Operation *, Qualifier *> patternToAnswer;

  /// Returns true if this predicate is ordered before `rhs`, based on the cost
  /// model.
  bool operator<(const OrderedPredicate &rhs) const {
    // Sort by:
    // * higher first and secondary order sums
    // * lower depth
    // * lower position dependency
    // * lower predicate dependency
    // The tuples are compared with `>`. The depth and kind fields are taken
    // from the opposite operand, so for those fields smaller values sort
    // first.
    auto *rhsPos = rhs.position;
    return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
                           rhsPos->getKind(), rhs.question->getKind()) >
           std::make_tuple(rhs.primary, rhs.secondary,
                           position->getOperationDepth(), position->getKind(),
                           question->getKind());
  }
};

/// A DenseMapInfo for OrderedPredicate based solely on the position and
/// question.
struct OrderedPredicateDenseInfo {
  using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;

  static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
  static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
  static bool isEqual(const OrderedPredicate &lhs,
                      const OrderedPredicate &rhs) {
    return lhs.position == rhs.position && lhs.question == rhs.question;
  }
  static unsigned getHashValue(const OrderedPredicate &p) {
    return llvm::hash_combine(p.position, p.question);
  }
};

/// This class wraps a set of ordered predicates that are used within a specific
/// pattern operation.
struct OrderedPredicateList {
  OrderedPredicateList(pdl::PatternOp pattern, Value root)
      : pattern(pattern), root(root) {}

  pdl::PatternOp pattern;
  Value root;
  DenseSet<OrderedPredicate *> predicates;
};
} // namespace
693 
694 /// Returns true if the given matcher refers to the same predicate as the given
695 /// ordered predicate. This means that the position and questions of the two
696 /// match.
697 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
698   return node->getPosition() == predicate->position &&
699          node->getQuestion() == predicate->question;
700 }
701 
702 /// Get or insert a child matcher for the given parent switch node, given a
703 /// predicate and parent pattern.
704 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
705                                                OrderedPredicate *predicate,
706                                                pdl::PatternOp pattern) {
707   assert(isSamePredicate(node, predicate) &&
708          "expected matcher to equal the given predicate");
709 
710   auto it = predicate->patternToAnswer.find(pattern);
711   assert(it != predicate->patternToAnswer.end() &&
712          "expected pattern to exist in predicate");
713   return node->getChildren().insert({it->second, nullptr}).first->second;
714 }
715 
716 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
717 /// order. A pattern will traverse as far as possible using common predicates
718 /// and then either diverge from the CFG or reach the end of a branch and start
719 /// creating new nodes.
720 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
721                              OrderedPredicateList &list,
722                              std::vector<OrderedPredicate *>::iterator current,
723                              std::vector<OrderedPredicate *>::iterator end) {
724   if (current == end) {
725     // We've hit the end of a pattern, so create a successful result node.
726     node =
727         std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
728 
729     // If the pattern doesn't contain this predicate, ignore it.
730   } else if (list.predicates.find(*current) == list.predicates.end()) {
731     propagatePattern(node, list, std::next(current), end);
732 
733     // If the current matcher node is invalid, create a new one for this
734     // position and continue propagation.
735   } else if (!node) {
736     // Create a new node at this position and continue
737     node = std::make_unique<SwitchNode>((*current)->position,
738                                         (*current)->question);
739     propagatePattern(
740         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
741         list, std::next(current), end);
742 
743     // If the matcher has already been created, and it is for this predicate we
744     // continue propagation to the child.
745   } else if (isSamePredicate(node.get(), *current)) {
746     propagatePattern(
747         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
748         list, std::next(current), end);
749 
750     // If the matcher doesn't match the current predicate, insert a branch as
751     // the common set of matchers has diverged.
752   } else {
753     propagatePattern(node->getFailureNode(), list, current, end);
754   }
755 }
756 
757 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
758 /// `node` is updated in-place if it is a switch.
759 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
760   if (!node)
761     return;
762 
763   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
764     SwitchNode::ChildMapT &children = switchNode->getChildren();
765     for (auto &it : children)
766       foldSwitchToBool(it.second);
767 
768     // If the node only contains one child, collapse it into a boolean predicate
769     // node.
770     if (children.size() == 1) {
771       auto childIt = children.begin();
772       node = std::make_unique<BoolNode>(
773           node->getPosition(), node->getQuestion(), childIt->first,
774           std::move(childIt->second), std::move(node->getFailureNode()));
775     }
776   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
777     foldSwitchToBool(boolNode->getSuccessNode());
778   }
779 
780   foldSwitchToBool(node->getFailureNode());
781 }
782 
783 /// Insert an exit node at the end of the failure path of the `root`.
784 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
785   while (*root)
786     root = &(*root)->getFailureNode();
787   *root = std::make_unique<ExitNode>();
788 }
789 
790 /// Given a module containing PDL pattern operations, generate a matcher tree
791 /// using the patterns within the given module and return the root matcher node.
792 std::unique_ptr<MatcherNode>
793 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
794                                  DenseMap<Value, Position *> &valueToPosition) {
795   // The set of predicates contained within the pattern operations of the
796   // module.
797   struct PatternPredicates {
798     PatternPredicates(pdl::PatternOp pattern, Value root,
799                       std::vector<PositionalPredicate> predicates)
800         : pattern(pattern), root(root), predicates(std::move(predicates)) {}
801 
802     /// A pattern.
803     pdl::PatternOp pattern;
804 
805     /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
806     Value root;
807 
808     /// The extracted predicates for this pattern and root.
809     std::vector<PositionalPredicate> predicates;
810   };
811 
812   SmallVector<PatternPredicates, 16> patternsAndPredicates;
813   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
814     std::vector<PositionalPredicate> predicateList;
815     Value root =
816         buildPredicateList(pattern, builder, predicateList, valueToPosition);
817     patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
818   }
819 
820   // Associate a pattern result with each unique predicate.
821   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
822   for (auto &patternAndPredList : patternsAndPredicates) {
823     for (auto &predicate : patternAndPredList.predicates) {
824       auto it = uniqued.insert(predicate);
825       it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
826                                             predicate.answer);
827     }
828   }
829 
830   // Associate each pattern to a set of its ordered predicates for later lookup.
831   std::vector<OrderedPredicateList> lists;
832   lists.reserve(patternsAndPredicates.size());
833   for (auto &patternAndPredList : patternsAndPredicates) {
834     OrderedPredicateList list(patternAndPredList.pattern,
835                               patternAndPredList.root);
836     for (auto &predicate : patternAndPredList.predicates) {
837       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
838       list.predicates.insert(orderedPredicate);
839 
840       // Increment the primary sum for each reference to a particular predicate.
841       ++orderedPredicate->primary;
842     }
843     lists.push_back(std::move(list));
844   }
845 
846   // For a particular pattern, get the total primary sum and add it to the
847   // secondary sum of each predicate. Square the primary sums to emphasize
848   // shared predicates within rather than across patterns.
849   for (auto &list : lists) {
850     unsigned total = 0;
851     for (auto *predicate : list.predicates)
852       total += predicate->primary * predicate->primary;
853     for (auto *predicate : list.predicates)
854       predicate->secondary += total;
855   }
856 
857   // Sort the set of predicates now that the cost primary and secondary sums
858   // have been computed.
859   std::vector<OrderedPredicate *> ordered;
860   ordered.reserve(uniqued.size());
861   for (auto &ip : uniqued)
862     ordered.push_back(&ip);
863   std::stable_sort(
864       ordered.begin(), ordered.end(),
865       [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
866 
867   // Build the matchers for each of the pattern predicate lists.
868   std::unique_ptr<MatcherNode> root;
869   for (OrderedPredicateList &list : lists)
870     propagatePattern(root, list, ordered.begin(), ordered.end());
871 
872   // Collapse the graph and insert the exit node.
873   foldSwitchToBool(root);
874   insertExitNode(&root);
875   return root;
876 }
877 
878 //===----------------------------------------------------------------------===//
879 // MatcherNode
880 //===----------------------------------------------------------------------===//
881 
882 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
883                          std::unique_ptr<MatcherNode> failureNode)
884     : position(p), question(q), failureNode(std::move(failureNode)),
885       matcherTypeID(matcherTypeID) {}
886 
887 //===----------------------------------------------------------------------===//
888 // BoolNode
889 //===----------------------------------------------------------------------===//
890 
891 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
892                    std::unique_ptr<MatcherNode> successNode,
893                    std::unique_ptr<MatcherNode> failureNode)
894     : MatcherNode(TypeID::get<BoolNode>(), position, question,
895                   std::move(failureNode)),
896       answer(answer), successNode(std::move(successNode)) {}
897 
898 //===----------------------------------------------------------------------===//
899 // SuccessNode
900 //===----------------------------------------------------------------------===//
901 
902 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
903                          std::unique_ptr<MatcherNode> failureNode)
904     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
905                   /*question=*/nullptr, std::move(failureNode)),
906       pattern(pattern), root(root) {}
907 
908 //===----------------------------------------------------------------------===//
909 // SwitchNode
910 //===----------------------------------------------------------------------===//
911 
912 SwitchNode::SwitchNode(Position *position, Qualifier *question)
913     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
914