xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp (revision 8ec28af8eaff5acd0df3e53340159c034f08533d)
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11 
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Interfaces/InferTypeOpInterface.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/Debug.h"
21 #include <queue>
22 
23 #define DEBUG_TYPE "pdl-predicate-tree"
24 
25 using namespace mlir;
26 using namespace mlir::pdl_to_pdl_interp;
27 
28 //===----------------------------------------------------------------------===//
29 // Predicate List Building
30 //===----------------------------------------------------------------------===//
31 
32 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
33                               Value val, PredicateBuilder &builder,
34                               DenseMap<Value, Position *> &inputs,
35                               Position *pos);
36 
37 /// Compares the depths of two positions.
38 static bool comparePosDepth(Position *lhs, Position *rhs) {
39   return lhs->getOperationDepth() < rhs->getOperationDepth();
40 }
41 
42 /// Returns the number of non-range elements within `values`.
43 static unsigned getNumNonRangeValues(ValueRange values) {
44   return llvm::count_if(values.getTypes(),
45                         [](Type type) { return !isa<pdl::RangeType>(type); });
46 }
47 
48 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
49                               Value val, PredicateBuilder &builder,
50                               DenseMap<Value, Position *> &inputs,
51                               AttributePosition *pos) {
52   assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
53   predList.emplace_back(pos, builder.getIsNotNull());
54 
55   if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
56     // If the attribute has a type or value, add a constraint.
57     if (Value type = attr.getValueType())
58       getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
59     else if (Attribute value = attr.getValueAttr())
60       predList.emplace_back(pos, builder.getAttributeConstraint(value));
61   }
62 }
63 
64 /// Collect all of the predicates for the given operand position.
65 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
66                                      Value val, PredicateBuilder &builder,
67                                      DenseMap<Value, Position *> &inputs,
68                                      Position *pos) {
69   Type valueType = val.getType();
70   bool isVariadic = isa<pdl::RangeType>(valueType);
71 
72   // If this is a typed operand, add a type constraint.
73   TypeSwitch<Operation *>(val.getDefiningOp())
74       .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
75         // Prevent traversal into a null value if the operand has a proper
76         // index.
77         if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
78             cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
79           predList.emplace_back(pos, builder.getIsNotNull());
80 
81         if (Value type = op.getValueType())
82           getTreePredicates(predList, type, builder, inputs,
83                             builder.getType(pos));
84       })
85       .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
86         std::optional<unsigned> index = op.getIndex();
87 
88         // Prevent traversal into a null value if the result has a proper index.
89         if (index)
90           predList.emplace_back(pos, builder.getIsNotNull());
91 
92         // Get the parent operation of this operand.
93         OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
94         predList.emplace_back(parentPos, builder.getIsNotNull());
95 
96         // Ensure that the operands match the corresponding results of the
97         // parent operation.
98         Position *resultPos = nullptr;
99         if (std::is_same<pdl::ResultOp, decltype(op)>::value)
100           resultPos = builder.getResult(parentPos, *index);
101         else
102           resultPos = builder.getResultGroup(parentPos, index, isVariadic);
103         predList.emplace_back(resultPos, builder.getEqualTo(pos));
104 
105         // Collect the predicates of the parent operation.
106         getTreePredicates(predList, op.getParent(), builder, inputs,
107                           (Position *)parentPos);
108       });
109 }
110 
111 static void
112 getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
113                   PredicateBuilder &builder,
114                   DenseMap<Value, Position *> &inputs, OperationPosition *pos,
115                   std::optional<unsigned> ignoreOperand = std::nullopt) {
116   assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
117   pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
118   OperationPosition *opPos = cast<OperationPosition>(pos);
119 
120   // Ensure getDefiningOp returns a non-null operation.
121   if (!opPos->isRoot())
122     predList.emplace_back(pos, builder.getIsNotNull());
123 
124   // Check that this is the correct root operation.
125   if (std::optional<StringRef> opName = op.getOpName())
126     predList.emplace_back(pos, builder.getOperationName(*opName));
127 
128   // Check that the operation has the proper number of operands. If there are
129   // any variable length operands, we check a minimum instead of an exact count.
130   OperandRange operands = op.getOperandValues();
131   unsigned minOperands = getNumNonRangeValues(operands);
132   if (minOperands != operands.size()) {
133     if (minOperands)
134       predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
135   } else {
136     predList.emplace_back(pos, builder.getOperandCount(minOperands));
137   }
138 
139   // Check that the operation has the proper number of results. If there are
140   // any variable length results, we check a minimum instead of an exact count.
141   OperandRange types = op.getTypeValues();
142   unsigned minResults = getNumNonRangeValues(types);
143   if (minResults == types.size())
144     predList.emplace_back(pos, builder.getResultCount(types.size()));
145   else if (minResults)
146     predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
147 
148   // Recurse into any attributes, operands, or results.
149   for (auto [attrName, attr] :
150        llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
151     getTreePredicates(
152         predList, attr, builder, inputs,
153         builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
154   }
155 
156   // Process the operands and results of the operation. For all values up to
157   // the first variable length value, we use the concrete operand/result
158   // number. After that, we use the "group" given that we can't know the
159   // concrete indices until runtime. If there is only one variadic operand
160   // group, we treat it as all of the operands/results of the operation.
161   /// Operands.
162   if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
163     // Ignore the operands if we are performing an upward traversal (in that
164     // case, they have already been visited).
165     if (opPos->isRoot() || opPos->isOperandDefiningOp())
166       getTreePredicates(predList, operands.front(), builder, inputs,
167                         builder.getAllOperands(opPos));
168   } else {
169     bool foundVariableLength = false;
170     for (const auto &operandIt : llvm::enumerate(operands)) {
171       bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
172       foundVariableLength |= isVariadic;
173 
174       // Ignore the specified operand, usually because this position was
175       // visited in an upward traversal via an iterative choice.
176       if (ignoreOperand && *ignoreOperand == operandIt.index())
177         continue;
178 
179       Position *pos =
180           foundVariableLength
181               ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
182               : builder.getOperand(opPos, operandIt.index());
183       getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
184     }
185   }
186   /// Results.
187   if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
188     getTreePredicates(predList, types.front(), builder, inputs,
189                       builder.getType(builder.getAllResults(opPos)));
190     return;
191   }
192 
193   bool foundVariableLength = false;
194   for (auto [idx, typeValue] : llvm::enumerate(types)) {
195     bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
196     foundVariableLength |= isVariadic;
197 
198     auto *resultPos = foundVariableLength
199                           ? builder.getResultGroup(pos, idx, isVariadic)
200                           : builder.getResult(pos, idx);
201     predList.emplace_back(resultPos, builder.getIsNotNull());
202     getTreePredicates(predList, typeValue, builder, inputs,
203                       builder.getType(resultPos));
204   }
205 }
206 
207 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
208                               Value val, PredicateBuilder &builder,
209                               DenseMap<Value, Position *> &inputs,
210                               TypePosition *pos) {
211   // Check for a constraint on a constant type.
212   if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
213     if (Attribute type = typeOp.getConstantTypeAttr())
214       predList.emplace_back(pos, builder.getTypeConstraint(type));
215   } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
216     if (Attribute typeAttr = typeOp.getConstantTypesAttr())
217       predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
218   }
219 }
220 
221 /// Collect the tree predicates anchored at the given value.
222 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
223                               Value val, PredicateBuilder &builder,
224                               DenseMap<Value, Position *> &inputs,
225                               Position *pos) {
226   // Make sure this input value is accessible to the rewrite.
227   auto it = inputs.try_emplace(val, pos);
228   if (!it.second) {
229     // If this is an input value that has been visited in the tree, add a
230     // constraint to ensure that both instances refer to the same value.
231     if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
232             pdl::TypeOp>(val.getDefiningOp())) {
233       auto minMaxPositions =
234           std::minmax(pos, it.first->second, comparePosDepth);
235       predList.emplace_back(minMaxPositions.second,
236                             builder.getEqualTo(minMaxPositions.first));
237     }
238     return;
239   }
240 
241   TypeSwitch<Position *>(pos)
242       .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
243         getTreePredicates(predList, val, builder, inputs, pos);
244       })
245       .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
246         getOperandTreePredicates(predList, val, builder, inputs, pos);
247       })
248       .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
249 }
250 
251 static void getAttributePredicates(pdl::AttributeOp op,
252                                    std::vector<PositionalPredicate> &predList,
253                                    PredicateBuilder &builder,
254                                    DenseMap<Value, Position *> &inputs) {
255   Position *&attrPos = inputs[op];
256   if (attrPos)
257     return;
258   Attribute value = op.getValueAttr();
259   assert(value && "expected non-tree `pdl.attribute` to contain a value");
260   attrPos = builder.getAttributeLiteral(value);
261 }
262 
263 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
264                                     std::vector<PositionalPredicate> &predList,
265                                     PredicateBuilder &builder,
266                                     DenseMap<Value, Position *> &inputs) {
267   OperandRange arguments = op.getArgs();
268 
269   std::vector<Position *> allPositions;
270   allPositions.reserve(arguments.size());
271   for (Value arg : arguments)
272     allPositions.push_back(inputs.lookup(arg));
273 
274   // Push the constraint to the furthest position.
275   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
276                                     comparePosDepth);
277   ResultRange results = op.getResults();
278   PredicateBuilder::Predicate pred = builder.getConstraint(
279       op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
280       op.getIsNegated());
281 
282   // For each result register a position so it can be used later
283   for (auto [i, result] : llvm::enumerate(results)) {
284     ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
285     ConstraintPosition *pos = builder.getConstraintPosition(q, i);
286     auto [it, inserted] = inputs.try_emplace(result, pos);
287     // If this is an input value that has been visited in the tree, add a
288     // constraint to ensure that both instances refer to the same value.
289     if (!inserted) {
290       Position *first = pos;
291       Position *second = it->second;
292       if (comparePosDepth(second, first))
293         std::tie(second, first) = std::make_pair(first, second);
294 
295       predList.emplace_back(second, builder.getEqualTo(first));
296     }
297   }
298   predList.emplace_back(pos, pred);
299 }
300 
301 static void getResultPredicates(pdl::ResultOp op,
302                                 std::vector<PositionalPredicate> &predList,
303                                 PredicateBuilder &builder,
304                                 DenseMap<Value, Position *> &inputs) {
305   Position *&resultPos = inputs[op];
306   if (resultPos)
307     return;
308 
309   // Ensure that the result isn't null.
310   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
311   resultPos = builder.getResult(parentPos, op.getIndex());
312   predList.emplace_back(resultPos, builder.getIsNotNull());
313 }
314 
315 static void getResultPredicates(pdl::ResultsOp op,
316                                 std::vector<PositionalPredicate> &predList,
317                                 PredicateBuilder &builder,
318                                 DenseMap<Value, Position *> &inputs) {
319   Position *&resultPos = inputs[op];
320   if (resultPos)
321     return;
322 
323   // Ensure that the result isn't null if the result has an index.
324   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
325   bool isVariadic = isa<pdl::RangeType>(op.getType());
326   std::optional<unsigned> index = op.getIndex();
327   resultPos = builder.getResultGroup(parentPos, index, isVariadic);
328   if (index)
329     predList.emplace_back(resultPos, builder.getIsNotNull());
330 }
331 
332 static void getTypePredicates(Value typeValue,
333                               function_ref<Attribute()> typeAttrFn,
334                               PredicateBuilder &builder,
335                               DenseMap<Value, Position *> &inputs) {
336   Position *&typePos = inputs[typeValue];
337   if (typePos)
338     return;
339   Attribute typeAttr = typeAttrFn();
340   assert(typeAttr &&
341          "expected non-tree `pdl.type`/`pdl.types` to contain a value");
342   typePos = builder.getTypeLiteral(typeAttr);
343 }
344 
345 /// Collect all of the predicates that cannot be determined via walking the
346 /// tree.
347 static void getNonTreePredicates(pdl::PatternOp pattern,
348                                  std::vector<PositionalPredicate> &predList,
349                                  PredicateBuilder &builder,
350                                  DenseMap<Value, Position *> &inputs) {
351   for (Operation &op : pattern.getBodyRegion().getOps()) {
352     TypeSwitch<Operation *>(&op)
353         .Case([&](pdl::AttributeOp attrOp) {
354           getAttributePredicates(attrOp, predList, builder, inputs);
355         })
356         .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
357           getConstraintPredicates(constraintOp, predList, builder, inputs);
358         })
359         .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
360           getResultPredicates(resultOp, predList, builder, inputs);
361         })
362         .Case([&](pdl::TypeOp typeOp) {
363           getTypePredicates(
364               typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
365               inputs);
366         })
367         .Case([&](pdl::TypesOp typeOp) {
368           getTypePredicates(
369               typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
370               inputs);
371         });
372   }
373 }
374 
375 namespace {
376 
377 /// An op accepting a value at an optional index.
378 struct OpIndex {
379   Value parent;
380   std::optional<unsigned> index;
381 };
382 
383 /// The parent and operand index of each operation for each root, stored
384 /// as a nested map [root][operation].
385 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
386 
387 } // namespace
388 
389 /// Given a pattern, determines the set of roots present in this pattern.
390 /// These are the operations whose results are not consumed by other operations.
391 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
392   // First, collect all the operations that are used as operands
393   // to other operations. These are not roots by default.
394   DenseSet<Value> used;
395   for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
396     for (Value operand : operationOp.getOperandValues())
397       TypeSwitch<Operation *>(operand.getDefiningOp())
398           .Case<pdl::ResultOp, pdl::ResultsOp>(
399               [&used](auto resultOp) { used.insert(resultOp.getParent()); });
400   }
401 
402   // Remove the specified root from the use set, so that we can
403   // always select it as a root, even if it is used by other operations.
404   if (Value root = pattern.getRewriter().getRoot())
405     used.erase(root);
406 
407   // Finally, collect all the unused operations.
408   SmallVector<Value> roots;
409   for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
410     if (!used.contains(operationOp))
411       roots.push_back(operationOp);
412 
413   return roots;
414 }
415 
416 /// Given a list of candidate roots, builds the cost graph for connecting them.
417 /// The graph is formed by traversing the DAG of operations starting from each
418 /// root and marking the depth of each connector value (operand). Then we join
419 /// the candidate roots based on the common connector values, taking the one
420 /// with the minimum depth. Along the way, we compute, for each candidate root,
421 /// a mapping from each operation (in the DAG underneath this root) to its
422 /// parent operation and the corresponding operand index.
423 static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
424                            ParentMaps &parentMaps) {
425 
426   // The entry of a queue. The entry consists of the following items:
427   // * the value in the DAG underneath the root;
428   // * the parent of the value;
429   // * the operand index of the value in its parent;
430   // * the depth of the visited value.
431   struct Entry {
432     Entry(Value value, Value parent, std::optional<unsigned> index,
433           unsigned depth)
434         : value(value), parent(parent), index(index), depth(depth) {}
435 
436     Value value;
437     Value parent;
438     std::optional<unsigned> index;
439     unsigned depth;
440   };
441 
442   // A root of a value and its depth (distance from root to the value).
443   struct RootDepth {
444     Value root;
445     unsigned depth = 0;
446   };
447 
448   // Map from candidate connector values to their roots and depths. Using a
449   // small vector with 1 entry because most values belong to a single root.
450   llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
451 
452   // Perform a breadth-first traversal of the op DAG rooted at each root.
453   for (Value root : roots) {
454     // The queue of visited values. A value may be present multiple times in
455     // the queue, for multiple parents. We only accept the first occurrence,
456     // which is guaranteed to have the lowest depth.
457     std::queue<Entry> toVisit;
458     toVisit.emplace(root, Value(), 0, 0);
459 
460     // The map from value to its parent for the current root.
461     DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
462 
463     while (!toVisit.empty()) {
464       Entry entry = toVisit.front();
465       toVisit.pop();
466       // Skip if already visited.
467       if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
468         continue;
469 
470       // Mark the root and depth of the value.
471       connectorsRootsDepths[entry.value].push_back({root, entry.depth});
472 
473       // Traverse the operands of an operation and result ops.
474       // We intentionally do not traverse attributes and types, because those
475       // are expensive to join on.
476       TypeSwitch<Operation *>(entry.value.getDefiningOp())
477           .Case<pdl::OperationOp>([&](auto operationOp) {
478             OperandRange operands = operationOp.getOperandValues();
479             // Special case when we pass all the operands in one range.
480             // For those, the index is empty.
481             if (operands.size() == 1 &&
482                 isa<pdl::RangeType>(operands[0].getType())) {
483               toVisit.emplace(operands[0], entry.value, std::nullopt,
484                               entry.depth + 1);
485               return;
486             }
487 
488             // Default case: visit all the operands.
489             for (const auto &p :
490                  llvm::enumerate(operationOp.getOperandValues()))
491               toVisit.emplace(p.value(), entry.value, p.index(),
492                               entry.depth + 1);
493           })
494           .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
495             toVisit.emplace(resultOp.getParent(), entry.value,
496                             resultOp.getIndex(), entry.depth);
497           });
498     }
499   }
500 
501   // Now build the cost graph.
502   // This is simply a minimum over all depths for the target root.
503   unsigned nextID = 0;
504   for (const auto &connectorRootsDepths : connectorsRootsDepths) {
505     Value value = connectorRootsDepths.first;
506     ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
507     // If there is only one root for this value, this will not trigger
508     // any edges in the cost graph (a perf optimization).
509     if (rootsDepths.size() == 1)
510       continue;
511 
512     for (const RootDepth &p : rootsDepths) {
513       for (const RootDepth &q : rootsDepths) {
514         if (&p == &q)
515           continue;
516         // Insert or retrieve the property of edge from p to q.
517         RootOrderingEntry &entry = graph[q.root][p.root];
518         if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
519           if (!entry.connector)
520             entry.cost.second = nextID++;
521           entry.cost.first = q.depth;
522           entry.connector = value;
523         }
524       }
525     }
526   }
527 
528   assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
529          "the pattern contains a candidate root disconnected from the others");
530 }
531 
532 /// Returns true if the operand at the given index needs to be queried using an
533 /// operand group, i.e., if it is variadic itself or follows a variadic operand.
534 static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
535   OperandRange operands = op.getOperandValues();
536   assert(index < operands.size() && "operand index out of range");
537   for (unsigned i = 0; i <= index; ++i)
538     if (isa<pdl::RangeType>(operands[i].getType()))
539       return true;
540   return false;
541 }
542 
543 /// Visit a node during upward traversal.
544 static void visitUpward(std::vector<PositionalPredicate> &predList,
545                         OpIndex opIndex, PredicateBuilder &builder,
546                         DenseMap<Value, Position *> &valueToPosition,
547                         Position *&pos, unsigned rootID) {
548   Value value = opIndex.parent;
549   TypeSwitch<Operation *>(value.getDefiningOp())
550       .Case<pdl::OperationOp>([&](auto operationOp) {
551         LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
552 
553         // Get users and iterate over them.
554         Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
555         Position *foreachPos = builder.getForEach(usersPos, rootID);
556         OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
557 
558         // Compare the operand(s) of the user against the input value(s).
559         Position *operandPos;
560         if (!opIndex.index) {
561           // We are querying all the operands of the operation.
562           operandPos = builder.getAllOperands(opPos);
563         } else if (useOperandGroup(operationOp, *opIndex.index)) {
564           // We are querying an operand group.
565           Type type = operationOp.getOperandValues()[*opIndex.index].getType();
566           bool variadic = isa<pdl::RangeType>(type);
567           operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
568         } else {
569           // We are querying an individual operand.
570           operandPos = builder.getOperand(opPos, *opIndex.index);
571         }
572         predList.emplace_back(operandPos, builder.getEqualTo(pos));
573 
574         // Guard against duplicate upward visits. These are not possible,
575         // because if this value was already visited, it would have been
576         // cheaper to start the traversal at this value rather than at the
577         // `connector`, violating the optimality of our spanning tree.
578         bool inserted = valueToPosition.try_emplace(value, opPos).second;
579         (void)inserted;
580         assert(inserted && "duplicate upward visit");
581 
582         // Obtain the tree predicates at the current value.
583         getTreePredicates(predList, value, builder, valueToPosition, opPos,
584                           opIndex.index);
585 
586         // Update the position
587         pos = opPos;
588       })
589       .Case<pdl::ResultOp>([&](auto resultOp) {
590         // Traverse up an individual result.
591         auto *opPos = dyn_cast<OperationPosition>(pos);
592         assert(opPos && "operations and results must be interleaved");
593         pos = builder.getResult(opPos, *opIndex.index);
594 
595         // Insert the result position in case we have not visited it yet.
596         valueToPosition.try_emplace(value, pos);
597       })
598       .Case<pdl::ResultsOp>([&](auto resultOp) {
599         // Traverse up a group of results.
600         auto *opPos = dyn_cast<OperationPosition>(pos);
601         assert(opPos && "operations and results must be interleaved");
602         bool isVariadic = isa<pdl::RangeType>(value.getType());
603         if (opIndex.index)
604           pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
605         else
606           pos = builder.getAllResults(opPos);
607 
608         // Insert the result position in case we have not visited it yet.
609         valueToPosition.try_emplace(value, pos);
610       });
611 }
612 
613 /// Given a pattern operation, build the set of matcher predicates necessary to
614 /// match this pattern.
615 static Value buildPredicateList(pdl::PatternOp pattern,
616                                 PredicateBuilder &builder,
617                                 std::vector<PositionalPredicate> &predList,
618                                 DenseMap<Value, Position *> &valueToPosition) {
619   SmallVector<Value> roots = detectRoots(pattern);
620 
621   // Build the root ordering graph and compute the parent maps.
622   RootOrderingGraph graph;
623   ParentMaps parentMaps;
624   buildCostGraph(roots, graph, parentMaps);
625   LLVM_DEBUG({
626     llvm::dbgs() << "Graph:\n";
627     for (auto &target : graph) {
628       llvm::dbgs() << "  * " << target.first.getLoc() << " " << target.first
629                    << "\n";
630       for (auto &source : target.second) {
631         RootOrderingEntry &entry = source.second;
632         llvm::dbgs() << "      <- " << source.first << ": " << entry.cost.first
633                      << ":" << entry.cost.second << " via "
634                      << entry.connector.getLoc() << "\n";
635       }
636     }
637   });
638 
639   // Solve the optimal branching problem for each candidate root, or use the
640   // provided one.
641   Value bestRoot = pattern.getRewriter().getRoot();
642   OptimalBranching::EdgeList bestEdges;
643   if (!bestRoot) {
644     unsigned bestCost = 0;
645     LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
646     for (Value root : roots) {
647       OptimalBranching solver(graph, root);
648       unsigned cost = solver.solve();
649       LLVM_DEBUG(llvm::dbgs() << "  * " << root << ": " << cost << "\n");
650       if (!bestRoot || bestCost > cost) {
651         bestCost = cost;
652         bestRoot = root;
653         bestEdges = solver.preOrderTraversal(roots);
654       }
655     }
656   } else {
657     OptimalBranching solver(graph, bestRoot);
658     solver.solve();
659     bestEdges = solver.preOrderTraversal(roots);
660   }
661 
662   // Print the best solution.
663   LLVM_DEBUG({
664     llvm::dbgs() << "Best tree:\n";
665     for (const std::pair<Value, Value> &edge : bestEdges) {
666       llvm::dbgs() << "  * " << edge.first;
667       if (edge.second)
668         llvm::dbgs() << " <- " << edge.second;
669       llvm::dbgs() << "\n";
670     }
671   });
672 
673   LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
674   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << bestRoot << "\n");
675 
676   // The best root is the starting point for the traversal. Get the tree
677   // predicates for the DAG rooted at bestRoot.
678   getTreePredicates(predList, bestRoot, builder, valueToPosition,
679                     builder.getRoot());
680 
681   // Traverse the selected optimal branching. For all edges in order, traverse
682   // up starting from the connector, until the candidate root is reached, and
683   // call getTreePredicates at every node along the way.
684   for (const auto &it : llvm::enumerate(bestEdges)) {
685     Value target = it.value().first;
686     Value source = it.value().second;
687 
688     // Check if we already visited the target root. This happens in two cases:
689     // 1) the initial root (bestRoot);
690     // 2) a root that is dominated by (contained in the subtree rooted at) an
691     //    already visited root.
692     if (valueToPosition.count(target))
693       continue;
694 
695     // Determine the connector.
696     Value connector = graph[target][source].connector;
697     assert(connector && "invalid edge");
698     LLVM_DEBUG(llvm::dbgs() << "  * Connector: " << connector.getLoc() << "\n");
699     DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
700     Position *pos = valueToPosition.lookup(connector);
701     assert(pos && "connector has not been traversed yet");
702 
703     // Traverse from the connector upwards towards the target root.
704     for (Value value = connector; value != target;) {
705       OpIndex opIndex = parentMap.lookup(value);
706       assert(opIndex.parent && "missing parent");
707       visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
708       value = opIndex.parent;
709     }
710   }
711 
712   getNonTreePredicates(pattern, predList, builder, valueToPosition);
713 
714   return bestRoot;
715 }
716 
717 //===----------------------------------------------------------------------===//
718 // Pattern Predicate Tree Merging
719 //===----------------------------------------------------------------------===//
720 
721 namespace {
722 
723 /// This class represents a specific predicate applied to a position, and
724 /// provides hashing and ordering operators. This class allows for computing a
725 /// frequence sum and ordering predicates based on a cost model.
726 struct OrderedPredicate {
727   OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
728       : position(ip.first), question(ip.second) {}
729   OrderedPredicate(const PositionalPredicate &ip)
730       : position(ip.position), question(ip.question) {}
731 
732   /// The position this predicate is applied to.
733   Position *position;
734 
735   /// The question that is applied by this predicate onto the position.
736   Qualifier *question;
737 
738   /// The first and second order benefit sums.
739   /// The primary sum is the number of occurrences of this predicate among all
740   /// of the patterns.
741   unsigned primary = 0;
742   /// The secondary sum is a squared summation of the primary sum of all of the
743   /// predicates within each pattern that contains this predicate. This allows
744   /// for favoring predicates that are more commonly shared within a pattern, as
745   /// opposed to those shared across patterns.
746   unsigned secondary = 0;
747 
748   /// The tie breaking ID, used to preserve a deterministic (insertion) order
749   /// among all the predicates with the same priority, depth, and position /
750   /// predicate dependency.
751   unsigned id = 0;
752 
753   /// A map between a pattern operation and the answer to the predicate question
754   /// within that pattern.
755   DenseMap<Operation *, Qualifier *> patternToAnswer;
756 
757   /// Returns true if this predicate is ordered before `rhs`, based on the cost
758   /// model.
759   bool operator<(const OrderedPredicate &rhs) const {
760     // Sort by:
761     // * higher first and secondary order sums
762     // * lower depth
763     // * lower position dependency
764     // * lower predicate dependency
765     // * lower tie breaking ID
766     auto *rhsPos = rhs.position;
767     return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
768                            rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
769            std::make_tuple(rhs.primary, rhs.secondary,
770                            position->getOperationDepth(), position->getKind(),
771                            question->getKind(), id);
772   }
773 };
774 
775 /// A DenseMapInfo for OrderedPredicate based solely on the position and
776 /// question.
777 struct OrderedPredicateDenseInfo {
778   using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
779 
780   static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
781   static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
782   static bool isEqual(const OrderedPredicate &lhs,
783                       const OrderedPredicate &rhs) {
784     return lhs.position == rhs.position && lhs.question == rhs.question;
785   }
786   static unsigned getHashValue(const OrderedPredicate &p) {
787     return llvm::hash_combine(p.position, p.question);
788   }
789 };
790 
791 /// This class wraps a set of ordered predicates that are used within a specific
792 /// pattern operation.
793 struct OrderedPredicateList {
794   OrderedPredicateList(pdl::PatternOp pattern, Value root)
795       : pattern(pattern), root(root) {}
796 
797   pdl::PatternOp pattern;
798   Value root;
799   DenseSet<OrderedPredicate *> predicates;
800 };
801 } // namespace
802 
803 /// Returns true if the given matcher refers to the same predicate as the given
804 /// ordered predicate. This means that the position and questions of the two
805 /// match.
806 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
807   return node->getPosition() == predicate->position &&
808          node->getQuestion() == predicate->question;
809 }
810 
811 /// Get or insert a child matcher for the given parent switch node, given a
812 /// predicate and parent pattern.
813 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
814                                                OrderedPredicate *predicate,
815                                                pdl::PatternOp pattern) {
816   assert(isSamePredicate(node, predicate) &&
817          "expected matcher to equal the given predicate");
818 
819   auto it = predicate->patternToAnswer.find(pattern);
820   assert(it != predicate->patternToAnswer.end() &&
821          "expected pattern to exist in predicate");
822   return node->getChildren().insert({it->second, nullptr}).first->second;
823 }
824 
825 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
826 /// order. A pattern will traverse as far as possible using common predicates
827 /// and then either diverge from the CFG or reach the end of a branch and start
828 /// creating new nodes.
829 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
830                              OrderedPredicateList &list,
831                              std::vector<OrderedPredicate *>::iterator current,
832                              std::vector<OrderedPredicate *>::iterator end) {
833   if (current == end) {
834     // We've hit the end of a pattern, so create a successful result node.
835     node =
836         std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
837 
838     // If the pattern doesn't contain this predicate, ignore it.
839   } else if (!list.predicates.contains(*current)) {
840     propagatePattern(node, list, std::next(current), end);
841 
842     // If the current matcher node is invalid, create a new one for this
843     // position and continue propagation.
844   } else if (!node) {
845     // Create a new node at this position and continue
846     node = std::make_unique<SwitchNode>((*current)->position,
847                                         (*current)->question);
848     propagatePattern(
849         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
850         list, std::next(current), end);
851 
852     // If the matcher has already been created, and it is for this predicate we
853     // continue propagation to the child.
854   } else if (isSamePredicate(node.get(), *current)) {
855     propagatePattern(
856         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
857         list, std::next(current), end);
858 
859     // If the matcher doesn't match the current predicate, insert a branch as
860     // the common set of matchers has diverged.
861   } else {
862     propagatePattern(node->getFailureNode(), list, current, end);
863   }
864 }
865 
866 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
867 /// `node` is updated in-place if it is a switch.
868 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
869   if (!node)
870     return;
871 
872   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
873     SwitchNode::ChildMapT &children = switchNode->getChildren();
874     for (auto &it : children)
875       foldSwitchToBool(it.second);
876 
877     // If the node only contains one child, collapse it into a boolean predicate
878     // node.
879     if (children.size() == 1) {
880       auto *childIt = children.begin();
881       node = std::make_unique<BoolNode>(
882           node->getPosition(), node->getQuestion(), childIt->first,
883           std::move(childIt->second), std::move(node->getFailureNode()));
884     }
885   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
886     foldSwitchToBool(boolNode->getSuccessNode());
887   }
888 
889   foldSwitchToBool(node->getFailureNode());
890 }
891 
892 /// Insert an exit node at the end of the failure path of the `root`.
893 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
894   while (*root)
895     root = &(*root)->getFailureNode();
896   *root = std::make_unique<ExitNode>();
897 }
898 
/// Sorts the range [begin, end) with the partial order given by `cmp`.
///
/// `cmp(a, b)` returns true when `a` has to be placed before `b`. The sort is
/// stable: elements that are not ordered relative to each other keep their
/// original relative order.
///
/// The range is processed in rounds. Each round moves the elements that have
/// no predecessor among the remaining elements to the front and then continues
/// with the rest. The value type is obtained through std::iterator_traits and
/// the elements are only copied, so the function works for any iterator
/// (including raw pointers) over copyable elements.
///
/// If `cmp` is not a partial ordering (it contains a cycle), no element can be
/// placed. This triggers the assertion in builds with assertions enabled;
/// otherwise the function returns and leaves the remaining elements in their
/// current order instead of looping forever.
template <typename Iterator, typename Compare>
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
  using ValueT = typename std::iterator_traits<Iterator>::value_type;

  // Scratch buffers, reused by every round. The classification has to be
  // complete before any element is moved, because it inspects the whole
  // [begin, end) range.
  std::vector<ValueT> ready;
  std::vector<ValueT> blocked;

  while (begin != end) {
    ready.clear();
    blocked.clear();
    for (Iterator it = begin; it != end; ++it) {
      const bool hasPredecessor = std::any_of(
          begin, end, [&](const auto &other) { return cmp(other, *it); });
      (hasPredecessor ? blocked : ready).push_back(*it);
    }

    // Without a ready element there is no progress, see above.
    assert(!ready.empty() && "not a partial ordering");
    if (ready.empty())
      return;

    // Write the ready elements first, followed by the blocked ones; the next
    // round only looks at the blocked elements.
    begin = std::copy(ready.begin(), ready.end(), begin);
    std::copy(blocked.begin(), blocked.end(), begin);
  }
}
919 
920 /// Returns true if 'b' depends on a result of 'a'.
921 static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
922   auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
923   if (!cqa)
924     return false;
925 
926   auto positionDependsOnA = [&](Position *p) {
927     auto *cp = dyn_cast<ConstraintPosition>(p);
928     return cp && cp->getQuestion() == cqa;
929   };
930 
931   if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
932     // Does any argument of b use a?
933     return llvm::any_of(cqb->getArgs(), positionDependsOnA);
934   }
935   if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
936     return positionDependsOnA(b->position) ||
937            positionDependsOnA(equalTo->getValue());
938   }
939   return positionDependsOnA(b->position);
940 }
941 
942 /// Given a module containing PDL pattern operations, generate a matcher tree
943 /// using the patterns within the given module and return the root matcher node.
944 std::unique_ptr<MatcherNode>
945 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
946                                  DenseMap<Value, Position *> &valueToPosition) {
947   // The set of predicates contained within the pattern operations of the
948   // module.
949   struct PatternPredicates {
950     PatternPredicates(pdl::PatternOp pattern, Value root,
951                       std::vector<PositionalPredicate> predicates)
952         : pattern(pattern), root(root), predicates(std::move(predicates)) {}
953 
954     /// A pattern.
955     pdl::PatternOp pattern;
956 
957     /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
958     Value root;
959 
960     /// The extracted predicates for this pattern and root.
961     std::vector<PositionalPredicate> predicates;
962   };
963 
964   SmallVector<PatternPredicates, 16> patternsAndPredicates;
965   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
966     std::vector<PositionalPredicate> predicateList;
967     Value root =
968         buildPredicateList(pattern, builder, predicateList, valueToPosition);
969     patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
970   }
971 
972   // Associate a pattern result with each unique predicate.
973   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
974   for (auto &patternAndPredList : patternsAndPredicates) {
975     for (auto &predicate : patternAndPredList.predicates) {
976       auto it = uniqued.insert(predicate);
977       it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
978                                             predicate.answer);
979       // Mark the insertion order (0-based indexing).
980       if (it.second)
981         it.first->id = uniqued.size() - 1;
982     }
983   }
984 
985   // Associate each pattern to a set of its ordered predicates for later lookup.
986   std::vector<OrderedPredicateList> lists;
987   lists.reserve(patternsAndPredicates.size());
988   for (auto &patternAndPredList : patternsAndPredicates) {
989     OrderedPredicateList list(patternAndPredList.pattern,
990                               patternAndPredList.root);
991     for (auto &predicate : patternAndPredList.predicates) {
992       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
993       list.predicates.insert(orderedPredicate);
994 
995       // Increment the primary sum for each reference to a particular predicate.
996       ++orderedPredicate->primary;
997     }
998     lists.push_back(std::move(list));
999   }
1000 
1001   // For a particular pattern, get the total primary sum and add it to the
1002   // secondary sum of each predicate. Square the primary sums to emphasize
1003   // shared predicates within rather than across patterns.
1004   for (auto &list : lists) {
1005     unsigned total = 0;
1006     for (auto *predicate : list.predicates)
1007       total += predicate->primary * predicate->primary;
1008     for (auto *predicate : list.predicates)
1009       predicate->secondary += total;
1010   }
1011 
1012   // Sort the set of predicates now that the cost primary and secondary sums
1013   // have been computed.
1014   std::vector<OrderedPredicate *> ordered;
1015   ordered.reserve(uniqued.size());
1016   for (auto &ip : uniqued)
1017     ordered.push_back(&ip);
1018   llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1019     return *lhs < *rhs;
1020   });
1021 
1022   // Mostly keep the now established order, but also ensure that
1023   // ConstraintQuestions come after the results they use.
1024   stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
1025 
1026   // Build the matchers for each of the pattern predicate lists.
1027   std::unique_ptr<MatcherNode> root;
1028   for (OrderedPredicateList &list : lists)
1029     propagatePattern(root, list, ordered.begin(), ordered.end());
1030 
1031   // Collapse the graph and insert the exit node.
1032   foldSwitchToBool(root);
1033   insertExitNode(&root);
1034   return root;
1035 }
1036 
1037 //===----------------------------------------------------------------------===//
1038 // MatcherNode
1039 //===----------------------------------------------------------------------===//
1040 
1041 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
1042                          std::unique_ptr<MatcherNode> failureNode)
1043     : position(p), question(q), failureNode(std::move(failureNode)),
1044       matcherTypeID(matcherTypeID) {}
1045 
1046 //===----------------------------------------------------------------------===//
1047 // BoolNode
1048 //===----------------------------------------------------------------------===//
1049 
1050 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
1051                    std::unique_ptr<MatcherNode> successNode,
1052                    std::unique_ptr<MatcherNode> failureNode)
1053     : MatcherNode(TypeID::get<BoolNode>(), position, question,
1054                   std::move(failureNode)),
1055       answer(answer), successNode(std::move(successNode)) {}
1056 
1057 //===----------------------------------------------------------------------===//
1058 // SuccessNode
1059 //===----------------------------------------------------------------------===//
1060 
1061 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
1062                          std::unique_ptr<MatcherNode> failureNode)
1063     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
1064                   /*question=*/nullptr, std::move(failureNode)),
1065       pattern(pattern), root(root) {}
1066 
1067 //===----------------------------------------------------------------------===//
1068 // SwitchNode
1069 //===----------------------------------------------------------------------===//
1070 
1071 SwitchNode::SwitchNode(Position *position, Qualifier *question)
1072     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
1073