xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp (revision 5d6cb6f78ac93aedcf96e3a3bca61401a2177f31)
18a1ca2cdSRiver Riddle //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
28a1ca2cdSRiver Riddle //
38a1ca2cdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48a1ca2cdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
58a1ca2cdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68a1ca2cdSRiver Riddle //
78a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
88a1ca2cdSRiver Riddle 
98a1ca2cdSRiver Riddle #include "PredicateTree.h"
10a76ee58fSStanislav Funiak #include "RootOrdering.h"
11a76ee58fSStanislav Funiak 
128a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h"
138a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h"
148a1ca2cdSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
1565fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
168a1ca2cdSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.h"
17a76ee58fSStanislav Funiak #include "llvm/ADT/MapVector.h"
188ec28af8SMatthias Gehre #include "llvm/ADT/SmallPtrSet.h"
19242762c9SRiver Riddle #include "llvm/ADT/TypeSwitch.h"
20a76ee58fSStanislav Funiak #include "llvm/Support/Debug.h"
21a76ee58fSStanislav Funiak #include <queue>
22a76ee58fSStanislav Funiak 
23a76ee58fSStanislav Funiak #define DEBUG_TYPE "pdl-predicate-tree"
248a1ca2cdSRiver Riddle 
258a1ca2cdSRiver Riddle using namespace mlir;
268a1ca2cdSRiver Riddle using namespace mlir::pdl_to_pdl_interp;
278a1ca2cdSRiver Riddle 
288a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
298a1ca2cdSRiver Riddle // Predicate List Building
308a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
318a1ca2cdSRiver Riddle 
32242762c9SRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList,
33242762c9SRiver Riddle                               Value val, PredicateBuilder &builder,
34242762c9SRiver Riddle                               DenseMap<Value, Position *> &inputs,
35242762c9SRiver Riddle                               Position *pos);
36242762c9SRiver Riddle 
378a1ca2cdSRiver Riddle /// Compares the depths of two positions.
388a1ca2cdSRiver Riddle static bool comparePosDepth(Position *lhs, Position *rhs) {
393a833a0eSRiver Riddle   return lhs->getOperationDepth() < rhs->getOperationDepth();
403a833a0eSRiver Riddle }
413a833a0eSRiver Riddle 
423a833a0eSRiver Riddle /// Returns the number of non-range elements within `values`.
433a833a0eSRiver Riddle static unsigned getNumNonRangeValues(ValueRange values) {
443a833a0eSRiver Riddle   return llvm::count_if(values.getTypes(),
455550c821STres Popp                         [](Type type) { return !isa<pdl::RangeType>(type); });
468a1ca2cdSRiver Riddle }
478a1ca2cdSRiver Riddle 
488a1ca2cdSRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList,
498a1ca2cdSRiver Riddle                               Value val, PredicateBuilder &builder,
508a1ca2cdSRiver Riddle                               DenseMap<Value, Position *> &inputs,
51242762c9SRiver Riddle                               AttributePosition *pos) {
525550c821STres Popp   assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
538a1ca2cdSRiver Riddle   predList.emplace_back(pos, builder.getIsNotNull());
548a1ca2cdSRiver Riddle 
558ec28af8SMatthias Gehre   if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
56242762c9SRiver Riddle     // If the attribute has a type or value, add a constraint.
5772fddfb5SRiver Riddle     if (Value type = attr.getValueType())
588a1ca2cdSRiver Riddle       getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
59310c3ee4SRiver Riddle     else if (Attribute value = attr.getValueAttr())
60242762c9SRiver Riddle       predList.emplace_back(pos, builder.getAttributeConstraint(value));
61242762c9SRiver Riddle   }
628ec28af8SMatthias Gehre }
638a1ca2cdSRiver Riddle 
643a833a0eSRiver Riddle /// Collect all of the predicates for the given operand position.
653a833a0eSRiver Riddle static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
66242762c9SRiver Riddle                                      Value val, PredicateBuilder &builder,
67242762c9SRiver Riddle                                      DenseMap<Value, Position *> &inputs,
683a833a0eSRiver Riddle                                      Position *pos) {
693a833a0eSRiver Riddle   Type valueType = val.getType();
705550c821STres Popp   bool isVariadic = isa<pdl::RangeType>(valueType);
718a1ca2cdSRiver Riddle 
72e07c968aSRiver Riddle   // If this is a typed operand, add a type constraint.
733a833a0eSRiver Riddle   TypeSwitch<Operation *>(val.getDefiningOp())
743a833a0eSRiver Riddle       .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
753a833a0eSRiver Riddle         // Prevent traversal into a null value if the operand has a proper
763a833a0eSRiver Riddle         // index.
773a833a0eSRiver Riddle         if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
783a833a0eSRiver Riddle             cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
793a833a0eSRiver Riddle           predList.emplace_back(pos, builder.getIsNotNull());
80242762c9SRiver Riddle 
8172fddfb5SRiver Riddle         if (Value type = op.getValueType())
823a833a0eSRiver Riddle           getTreePredicates(predList, type, builder, inputs,
833a833a0eSRiver Riddle                             builder.getType(pos));
843a833a0eSRiver Riddle       })
853a833a0eSRiver Riddle       .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
8622426110SRamkumar Ramachandra         std::optional<unsigned> index = op.getIndex();
873a833a0eSRiver Riddle 
883a833a0eSRiver Riddle         // Prevent traversal into a null value if the result has a proper index.
893a833a0eSRiver Riddle         if (index)
903a833a0eSRiver Riddle           predList.emplace_back(pos, builder.getIsNotNull());
913a833a0eSRiver Riddle 
923a833a0eSRiver Riddle         // Get the parent operation of this operand.
933a833a0eSRiver Riddle         OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
94242762c9SRiver Riddle         predList.emplace_back(parentPos, builder.getIsNotNull());
953a833a0eSRiver Riddle 
963a833a0eSRiver Riddle         // Ensure that the operands match the corresponding results of the
973a833a0eSRiver Riddle         // parent operation.
983a833a0eSRiver Riddle         Position *resultPos = nullptr;
993a833a0eSRiver Riddle         if (std::is_same<pdl::ResultOp, decltype(op)>::value)
1003a833a0eSRiver Riddle           resultPos = builder.getResult(parentPos, *index);
1013a833a0eSRiver Riddle         else
1023a833a0eSRiver Riddle           resultPos = builder.getResultGroup(parentPos, index, isVariadic);
103242762c9SRiver Riddle         predList.emplace_back(resultPos, builder.getEqualTo(pos));
1043a833a0eSRiver Riddle 
1053a833a0eSRiver Riddle         // Collect the predicates of the parent operation.
106310c3ee4SRiver Riddle         getTreePredicates(predList, op.getParent(), builder, inputs,
1071f13963eSRiver Riddle                           (Position *)parentPos);
1083a833a0eSRiver Riddle       });
1098a1ca2cdSRiver Riddle }
1108a1ca2cdSRiver Riddle 
11122426110SRamkumar Ramachandra static void
11222426110SRamkumar Ramachandra getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
11322426110SRamkumar Ramachandra                   PredicateBuilder &builder,
11422426110SRamkumar Ramachandra                   DenseMap<Value, Position *> &inputs, OperationPosition *pos,
11522426110SRamkumar Ramachandra                   std::optional<unsigned> ignoreOperand = std::nullopt) {
1165550c821STres Popp   assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
1178a1ca2cdSRiver Riddle   pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
1188a1ca2cdSRiver Riddle   OperationPosition *opPos = cast<OperationPosition>(pos);
1198a1ca2cdSRiver Riddle 
1208a1ca2cdSRiver Riddle   // Ensure getDefiningOp returns a non-null operation.
1218a1ca2cdSRiver Riddle   if (!opPos->isRoot())
1228a1ca2cdSRiver Riddle     predList.emplace_back(pos, builder.getIsNotNull());
1238a1ca2cdSRiver Riddle 
1248a1ca2cdSRiver Riddle   // Check that this is the correct root operation.
12522426110SRamkumar Ramachandra   if (std::optional<StringRef> opName = op.getOpName())
1268a1ca2cdSRiver Riddle     predList.emplace_back(pos, builder.getOperationName(*opName));
1278a1ca2cdSRiver Riddle 
1283a833a0eSRiver Riddle   // Check that the operation has the proper number of operands. If there are
1293a833a0eSRiver Riddle   // any variable length operands, we check a minimum instead of an exact count.
13072fddfb5SRiver Riddle   OperandRange operands = op.getOperandValues();
1313a833a0eSRiver Riddle   unsigned minOperands = getNumNonRangeValues(operands);
1323a833a0eSRiver Riddle   if (minOperands != operands.size()) {
1333a833a0eSRiver Riddle     if (minOperands)
1343a833a0eSRiver Riddle       predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
1353a833a0eSRiver Riddle   } else {
1363a833a0eSRiver Riddle     predList.emplace_back(pos, builder.getOperandCount(minOperands));
1373a833a0eSRiver Riddle   }
1383a833a0eSRiver Riddle 
1393a833a0eSRiver Riddle   // Check that the operation has the proper number of results. If there are
1403a833a0eSRiver Riddle   // any variable length results, we check a minimum instead of an exact count.
14172fddfb5SRiver Riddle   OperandRange types = op.getTypeValues();
1423a833a0eSRiver Riddle   unsigned minResults = getNumNonRangeValues(types);
1433a833a0eSRiver Riddle   if (minResults == types.size())
144242762c9SRiver Riddle     predList.emplace_back(pos, builder.getResultCount(types.size()));
1453a833a0eSRiver Riddle   else if (minResults)
1463a833a0eSRiver Riddle     predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
1478a1ca2cdSRiver Riddle 
1488a1ca2cdSRiver Riddle   // Recurse into any attributes, operands, or results.
14972fddfb5SRiver Riddle   for (auto [attrName, attr] :
15072fddfb5SRiver Riddle        llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
1518a1ca2cdSRiver Riddle     getTreePredicates(
15272fddfb5SRiver Riddle         predList, attr, builder, inputs,
1535550c821STres Popp         builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
1548a1ca2cdSRiver Riddle   }
1553a833a0eSRiver Riddle 
1563a833a0eSRiver Riddle   // Process the operands and results of the operation. For all values up to
1573a833a0eSRiver Riddle   // the first variable length value, we use the concrete operand/result
1583a833a0eSRiver Riddle   // number. After that, we use the "group" given that we can't know the
1593a833a0eSRiver Riddle   // concrete indices until runtime. If there is only one variadic operand
1603a833a0eSRiver Riddle   // group, we treat it as all of the operands/results of the operation.
1613a833a0eSRiver Riddle   /// Operands.
1625550c821STres Popp   if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
1632692eae5SStanislav Funiak     // Ignore the operands if we are performing an upward traversal (in that
1642692eae5SStanislav Funiak     // case, they have already been visited).
1652692eae5SStanislav Funiak     if (opPos->isRoot() || opPos->isOperandDefiningOp())
1663a833a0eSRiver Riddle       getTreePredicates(predList, operands.front(), builder, inputs,
1673a833a0eSRiver Riddle                         builder.getAllOperands(opPos));
1683a833a0eSRiver Riddle   } else {
1693a833a0eSRiver Riddle     bool foundVariableLength = false;
170e4853be2SMehdi Amini     for (const auto &operandIt : llvm::enumerate(operands)) {
1715550c821STres Popp       bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
1723a833a0eSRiver Riddle       foundVariableLength |= isVariadic;
1733a833a0eSRiver Riddle 
174a76ee58fSStanislav Funiak       // Ignore the specified operand, usually because this position was
175a76ee58fSStanislav Funiak       // visited in an upward traversal via an iterative choice.
176a76ee58fSStanislav Funiak       if (ignoreOperand && *ignoreOperand == operandIt.index())
177a76ee58fSStanislav Funiak         continue;
178a76ee58fSStanislav Funiak 
1793a833a0eSRiver Riddle       Position *pos =
1803a833a0eSRiver Riddle           foundVariableLength
1813a833a0eSRiver Riddle               ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
1823a833a0eSRiver Riddle               : builder.getOperand(opPos, operandIt.index());
1833a833a0eSRiver Riddle       getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
184242762c9SRiver Riddle     }
1853a833a0eSRiver Riddle   }
1863a833a0eSRiver Riddle   /// Results.
1875550c821STres Popp   if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
1883a833a0eSRiver Riddle     getTreePredicates(predList, types.front(), builder, inputs,
1893a833a0eSRiver Riddle                       builder.getType(builder.getAllResults(opPos)));
1908c258fdaSJakub Kuderski     return;
1918c258fdaSJakub Kuderski   }
1928c258fdaSJakub Kuderski 
1933a833a0eSRiver Riddle   bool foundVariableLength = false;
1948c258fdaSJakub Kuderski   for (auto [idx, typeValue] : llvm::enumerate(types)) {
1955550c821STres Popp     bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
1963a833a0eSRiver Riddle     foundVariableLength |= isVariadic;
1973a833a0eSRiver Riddle 
1988c258fdaSJakub Kuderski     auto *resultPos = foundVariableLength
1998c258fdaSJakub Kuderski                           ? builder.getResultGroup(pos, idx, isVariadic)
2008c258fdaSJakub Kuderski                           : builder.getResult(pos, idx);
201242762c9SRiver Riddle     predList.emplace_back(resultPos, builder.getIsNotNull());
2028c258fdaSJakub Kuderski     getTreePredicates(predList, typeValue, builder, inputs,
203242762c9SRiver Riddle                       builder.getType(resultPos));
2048a1ca2cdSRiver Riddle   }
2058a1ca2cdSRiver Riddle }
2068a1ca2cdSRiver Riddle 
207242762c9SRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList,
208242762c9SRiver Riddle                               Value val, PredicateBuilder &builder,
209242762c9SRiver Riddle                               DenseMap<Value, Position *> &inputs,
210242762c9SRiver Riddle                               TypePosition *pos) {
2118a1ca2cdSRiver Riddle   // Check for a constraint on a constant type.
2123a833a0eSRiver Riddle   if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
21372fddfb5SRiver Riddle     if (Attribute type = typeOp.getConstantTypeAttr())
2143a833a0eSRiver Riddle       predList.emplace_back(pos, builder.getTypeConstraint(type));
2153a833a0eSRiver Riddle   } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
21672fddfb5SRiver Riddle     if (Attribute typeAttr = typeOp.getConstantTypesAttr())
2173a833a0eSRiver Riddle       predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
2183a833a0eSRiver Riddle   }
2198a1ca2cdSRiver Riddle }
220242762c9SRiver Riddle 
221242762c9SRiver Riddle /// Collect the tree predicates anchored at the given value.
222242762c9SRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList,
223242762c9SRiver Riddle                               Value val, PredicateBuilder &builder,
224242762c9SRiver Riddle                               DenseMap<Value, Position *> &inputs,
225242762c9SRiver Riddle                               Position *pos) {
226242762c9SRiver Riddle   // Make sure this input value is accessible to the rewrite.
227242762c9SRiver Riddle   auto it = inputs.try_emplace(val, pos);
228242762c9SRiver Riddle   if (!it.second) {
229242762c9SRiver Riddle     // If this is an input value that has been visited in the tree, add a
230242762c9SRiver Riddle     // constraint to ensure that both instances refer to the same value.
2313a833a0eSRiver Riddle     if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
2323a833a0eSRiver Riddle             pdl::TypeOp>(val.getDefiningOp())) {
233242762c9SRiver Riddle       auto minMaxPositions =
234242762c9SRiver Riddle           std::minmax(pos, it.first->second, comparePosDepth);
235242762c9SRiver Riddle       predList.emplace_back(minMaxPositions.second,
236242762c9SRiver Riddle                             builder.getEqualTo(minMaxPositions.first));
2378a1ca2cdSRiver Riddle     }
238242762c9SRiver Riddle     return;
239242762c9SRiver Riddle   }
240242762c9SRiver Riddle 
241242762c9SRiver Riddle   TypeSwitch<Position *>(pos)
2423a833a0eSRiver Riddle       .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
2433a833a0eSRiver Riddle         getTreePredicates(predList, val, builder, inputs, pos);
2443a833a0eSRiver Riddle       })
2453a833a0eSRiver Riddle       .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
2463a833a0eSRiver Riddle         getOperandTreePredicates(predList, val, builder, inputs, pos);
247242762c9SRiver Riddle       })
248242762c9SRiver Riddle       .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
2498a1ca2cdSRiver Riddle }
2508a1ca2cdSRiver Riddle 
251233e9476SRiver Riddle static void getAttributePredicates(pdl::AttributeOp op,
252233e9476SRiver Riddle                                    std::vector<PositionalPredicate> &predList,
253233e9476SRiver Riddle                                    PredicateBuilder &builder,
254233e9476SRiver Riddle                                    DenseMap<Value, Position *> &inputs) {
255233e9476SRiver Riddle   Position *&attrPos = inputs[op];
256233e9476SRiver Riddle   if (attrPos)
257233e9476SRiver Riddle     return;
258310c3ee4SRiver Riddle   Attribute value = op.getValueAttr();
259233e9476SRiver Riddle   assert(value && "expected non-tree `pdl.attribute` to contain a value");
260233e9476SRiver Riddle   attrPos = builder.getAttributeLiteral(value);
261233e9476SRiver Riddle }
262233e9476SRiver Riddle 
26302c4c0d5SRiver Riddle static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
264242762c9SRiver Riddle                                     std::vector<PositionalPredicate> &predList,
265242762c9SRiver Riddle                                     PredicateBuilder &builder,
266242762c9SRiver Riddle                                     DenseMap<Value, Position *> &inputs) {
267310c3ee4SRiver Riddle   OperandRange arguments = op.getArgs();
2688a1ca2cdSRiver Riddle 
2698a1ca2cdSRiver Riddle   std::vector<Position *> allPositions;
2708a1ca2cdSRiver Riddle   allPositions.reserve(arguments.size());
2718a1ca2cdSRiver Riddle   for (Value arg : arguments)
2728a1ca2cdSRiver Riddle     allPositions.push_back(inputs.lookup(arg));
2738a1ca2cdSRiver Riddle 
2748a1ca2cdSRiver Riddle   // Push the constraint to the furthest position.
275fab2bb8bSJustin Lebar   Position *pos = *llvm::max_element(allPositions, comparePosDepth);
2768ec28af8SMatthias Gehre   ResultRange results = op.getResults();
2778ec28af8SMatthias Gehre   PredicateBuilder::Predicate pred = builder.getConstraint(
2788ec28af8SMatthias Gehre       op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
2798ec28af8SMatthias Gehre       op.getIsNegated());
2808ec28af8SMatthias Gehre 
2818ec28af8SMatthias Gehre   // For each result register a position so it can be used later
2828ec28af8SMatthias Gehre   for (auto [i, result] : llvm::enumerate(results)) {
2838ec28af8SMatthias Gehre     ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
2848ec28af8SMatthias Gehre     ConstraintPosition *pos = builder.getConstraintPosition(q, i);
2858ec28af8SMatthias Gehre     auto [it, inserted] = inputs.try_emplace(result, pos);
2868ec28af8SMatthias Gehre     // If this is an input value that has been visited in the tree, add a
2878ec28af8SMatthias Gehre     // constraint to ensure that both instances refer to the same value.
2888ec28af8SMatthias Gehre     if (!inserted) {
2898ec28af8SMatthias Gehre       Position *first = pos;
2908ec28af8SMatthias Gehre       Position *second = it->second;
2918ec28af8SMatthias Gehre       if (comparePosDepth(second, first))
2928ec28af8SMatthias Gehre         std::tie(second, first) = std::make_pair(first, second);
2938ec28af8SMatthias Gehre 
2948ec28af8SMatthias Gehre       predList.emplace_back(second, builder.getEqualTo(first));
2958ec28af8SMatthias Gehre     }
2968ec28af8SMatthias Gehre   }
2978a1ca2cdSRiver Riddle   predList.emplace_back(pos, pred);
2988a1ca2cdSRiver Riddle }
299242762c9SRiver Riddle 
300242762c9SRiver Riddle static void getResultPredicates(pdl::ResultOp op,
301242762c9SRiver Riddle                                 std::vector<PositionalPredicate> &predList,
302242762c9SRiver Riddle                                 PredicateBuilder &builder,
303242762c9SRiver Riddle                                 DenseMap<Value, Position *> &inputs) {
304242762c9SRiver Riddle   Position *&resultPos = inputs[op];
305242762c9SRiver Riddle   if (resultPos)
306242762c9SRiver Riddle     return;
3073a833a0eSRiver Riddle 
3083a833a0eSRiver Riddle   // Ensure that the result isn't null.
309310c3ee4SRiver Riddle   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
310310c3ee4SRiver Riddle   resultPos = builder.getResult(parentPos, op.getIndex());
311242762c9SRiver Riddle   predList.emplace_back(resultPos, builder.getIsNotNull());
312242762c9SRiver Riddle }
313242762c9SRiver Riddle 
3143a833a0eSRiver Riddle static void getResultPredicates(pdl::ResultsOp op,
3153a833a0eSRiver Riddle                                 std::vector<PositionalPredicate> &predList,
3163a833a0eSRiver Riddle                                 PredicateBuilder &builder,
3173a833a0eSRiver Riddle                                 DenseMap<Value, Position *> &inputs) {
3183a833a0eSRiver Riddle   Position *&resultPos = inputs[op];
3193a833a0eSRiver Riddle   if (resultPos)
3203a833a0eSRiver Riddle     return;
3213a833a0eSRiver Riddle 
3223a833a0eSRiver Riddle   // Ensure that the result isn't null if the result has an index.
323310c3ee4SRiver Riddle   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
3245550c821STres Popp   bool isVariadic = isa<pdl::RangeType>(op.getType());
32522426110SRamkumar Ramachandra   std::optional<unsigned> index = op.getIndex();
3263a833a0eSRiver Riddle   resultPos = builder.getResultGroup(parentPos, index, isVariadic);
3273a833a0eSRiver Riddle   if (index)
3283a833a0eSRiver Riddle     predList.emplace_back(resultPos, builder.getIsNotNull());
3293a833a0eSRiver Riddle }
3303a833a0eSRiver Riddle 
331233e9476SRiver Riddle static void getTypePredicates(Value typeValue,
332233e9476SRiver Riddle                               function_ref<Attribute()> typeAttrFn,
333233e9476SRiver Riddle                               PredicateBuilder &builder,
334233e9476SRiver Riddle                               DenseMap<Value, Position *> &inputs) {
335233e9476SRiver Riddle   Position *&typePos = inputs[typeValue];
336233e9476SRiver Riddle   if (typePos)
337233e9476SRiver Riddle     return;
338233e9476SRiver Riddle   Attribute typeAttr = typeAttrFn();
339233e9476SRiver Riddle   assert(typeAttr &&
340233e9476SRiver Riddle          "expected non-tree `pdl.type`/`pdl.types` to contain a value");
341233e9476SRiver Riddle   typePos = builder.getTypeLiteral(typeAttr);
342233e9476SRiver Riddle }
343233e9476SRiver Riddle 
344242762c9SRiver Riddle /// Collect all of the predicates that cannot be determined via walking the
345242762c9SRiver Riddle /// tree.
346242762c9SRiver Riddle static void getNonTreePredicates(pdl::PatternOp pattern,
347242762c9SRiver Riddle                                  std::vector<PositionalPredicate> &predList,
348242762c9SRiver Riddle                                  PredicateBuilder &builder,
349242762c9SRiver Riddle                                  DenseMap<Value, Position *> &inputs) {
35072fddfb5SRiver Riddle   for (Operation &op : pattern.getBodyRegion().getOps()) {
3513a833a0eSRiver Riddle     TypeSwitch<Operation *>(&op)
352233e9476SRiver Riddle         .Case([&](pdl::AttributeOp attrOp) {
353233e9476SRiver Riddle           getAttributePredicates(attrOp, predList, builder, inputs);
354233e9476SRiver Riddle         })
3553a833a0eSRiver Riddle         .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
356242762c9SRiver Riddle           getConstraintPredicates(constraintOp, predList, builder, inputs);
3573a833a0eSRiver Riddle         })
3583a833a0eSRiver Riddle         .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
359242762c9SRiver Riddle           getResultPredicates(resultOp, predList, builder, inputs);
360233e9476SRiver Riddle         })
361233e9476SRiver Riddle         .Case([&](pdl::TypeOp typeOp) {
362233e9476SRiver Riddle           getTypePredicates(
36372fddfb5SRiver Riddle               typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
36472fddfb5SRiver Riddle               inputs);
365233e9476SRiver Riddle         })
366233e9476SRiver Riddle         .Case([&](pdl::TypesOp typeOp) {
367233e9476SRiver Riddle           getTypePredicates(
36872fddfb5SRiver Riddle               typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
36972fddfb5SRiver Riddle               inputs);
3703a833a0eSRiver Riddle         });
371242762c9SRiver Riddle   }
3728a1ca2cdSRiver Riddle }
3738a1ca2cdSRiver Riddle 
374a76ee58fSStanislav Funiak namespace {
375a76ee58fSStanislav Funiak 
376a76ee58fSStanislav Funiak /// An op accepting a value at an optional index.
377a76ee58fSStanislav Funiak struct OpIndex {
378a76ee58fSStanislav Funiak   Value parent;
37922426110SRamkumar Ramachandra   std::optional<unsigned> index;
380a76ee58fSStanislav Funiak };
381a76ee58fSStanislav Funiak 
382a76ee58fSStanislav Funiak /// The parent and operand index of each operation for each root, stored
383a76ee58fSStanislav Funiak /// as a nested map [root][operation].
384a76ee58fSStanislav Funiak using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
385a76ee58fSStanislav Funiak 
386a76ee58fSStanislav Funiak } // namespace
387a76ee58fSStanislav Funiak 
388a76ee58fSStanislav Funiak /// Given a pattern, determines the set of roots present in this pattern.
389a76ee58fSStanislav Funiak /// These are the operations whose results are not consumed by other operations.
390a76ee58fSStanislav Funiak static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
391a76ee58fSStanislav Funiak   // First, collect all the operations that are used as operands
392a76ee58fSStanislav Funiak   // to other operations. These are not roots by default.
393a76ee58fSStanislav Funiak   DenseSet<Value> used;
39472fddfb5SRiver Riddle   for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
39572fddfb5SRiver Riddle     for (Value operand : operationOp.getOperandValues())
396a76ee58fSStanislav Funiak       TypeSwitch<Operation *>(operand.getDefiningOp())
397a76ee58fSStanislav Funiak           .Case<pdl::ResultOp, pdl::ResultsOp>(
398310c3ee4SRiver Riddle               [&used](auto resultOp) { used.insert(resultOp.getParent()); });
399a76ee58fSStanislav Funiak   }
400a76ee58fSStanislav Funiak 
401a76ee58fSStanislav Funiak   // Remove the specified root from the use set, so that we can
402a76ee58fSStanislav Funiak   // always select it as a root, even if it is used by other operations.
403310c3ee4SRiver Riddle   if (Value root = pattern.getRewriter().getRoot())
404a76ee58fSStanislav Funiak     used.erase(root);
405a76ee58fSStanislav Funiak 
406a76ee58fSStanislav Funiak   // Finally, collect all the unused operations.
407a76ee58fSStanislav Funiak   SmallVector<Value> roots;
40872fddfb5SRiver Riddle   for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
409a76ee58fSStanislav Funiak     if (!used.contains(operationOp))
410a76ee58fSStanislav Funiak       roots.push_back(operationOp);
411a76ee58fSStanislav Funiak 
412a76ee58fSStanislav Funiak   return roots;
413a76ee58fSStanislav Funiak }
414a76ee58fSStanislav Funiak 
415a76ee58fSStanislav Funiak /// Given a list of candidate roots, builds the cost graph for connecting them.
416a76ee58fSStanislav Funiak /// The graph is formed by traversing the DAG of operations starting from each
417a76ee58fSStanislav Funiak /// root and marking the depth of each connector value (operand). Then we join
418a76ee58fSStanislav Funiak /// the candidate roots based on the common connector values, taking the one
419a76ee58fSStanislav Funiak /// with the minimum depth. Along the way, we compute, for each candidate root,
420a76ee58fSStanislav Funiak /// a mapping from each operation (in the DAG underneath this root) to its
421a76ee58fSStanislav Funiak /// parent operation and the corresponding operand index.
422a76ee58fSStanislav Funiak static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
423a76ee58fSStanislav Funiak                            ParentMaps &parentMaps) {
424a76ee58fSStanislav Funiak 
425a76ee58fSStanislav Funiak   // The entry of a queue. The entry consists of the following items:
426a76ee58fSStanislav Funiak   // * the value in the DAG underneath the root;
427a76ee58fSStanislav Funiak   // * the parent of the value;
428a76ee58fSStanislav Funiak   // * the operand index of the value in its parent;
429a76ee58fSStanislav Funiak   // * the depth of the visited value.
430a76ee58fSStanislav Funiak   struct Entry {
43122426110SRamkumar Ramachandra     Entry(Value value, Value parent, std::optional<unsigned> index,
43222426110SRamkumar Ramachandra           unsigned depth)
433a76ee58fSStanislav Funiak         : value(value), parent(parent), index(index), depth(depth) {}
434a76ee58fSStanislav Funiak 
435a76ee58fSStanislav Funiak     Value value;
436a76ee58fSStanislav Funiak     Value parent;
43722426110SRamkumar Ramachandra     std::optional<unsigned> index;
438a76ee58fSStanislav Funiak     unsigned depth;
439a76ee58fSStanislav Funiak   };
440a76ee58fSStanislav Funiak 
441a76ee58fSStanislav Funiak   // A root of a value and its depth (distance from root to the value).
442a76ee58fSStanislav Funiak   struct RootDepth {
443a76ee58fSStanislav Funiak     Value root;
444a76ee58fSStanislav Funiak     unsigned depth = 0;
445a76ee58fSStanislav Funiak   };
446a76ee58fSStanislav Funiak 
447a76ee58fSStanislav Funiak   // Map from candidate connector values to their roots and depths. Using a
448a76ee58fSStanislav Funiak   // small vector with 1 entry because most values belong to a single root.
449a76ee58fSStanislav Funiak   llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
450a76ee58fSStanislav Funiak 
451a76ee58fSStanislav Funiak   // Perform a breadth-first traversal of the op DAG rooted at each root.
452a76ee58fSStanislav Funiak   for (Value root : roots) {
453a76ee58fSStanislav Funiak     // The queue of visited values. A value may be present multiple times in
454a76ee58fSStanislav Funiak     // the queue, for multiple parents. We only accept the first occurrence,
455a76ee58fSStanislav Funiak     // which is guaranteed to have the lowest depth.
456a76ee58fSStanislav Funiak     std::queue<Entry> toVisit;
457a76ee58fSStanislav Funiak     toVisit.emplace(root, Value(), 0, 0);
458a76ee58fSStanislav Funiak 
459a76ee58fSStanislav Funiak     // The map from value to its parent for the current root.
460a76ee58fSStanislav Funiak     DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
461a76ee58fSStanislav Funiak 
462a76ee58fSStanislav Funiak     while (!toVisit.empty()) {
463a76ee58fSStanislav Funiak       Entry entry = toVisit.front();
464a76ee58fSStanislav Funiak       toVisit.pop();
465a76ee58fSStanislav Funiak       // Skip if already visited.
466a76ee58fSStanislav Funiak       if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
467a76ee58fSStanislav Funiak         continue;
468a76ee58fSStanislav Funiak 
469a76ee58fSStanislav Funiak       // Mark the root and depth of the value.
470a76ee58fSStanislav Funiak       connectorsRootsDepths[entry.value].push_back({root, entry.depth});
471a76ee58fSStanislav Funiak 
472a76ee58fSStanislav Funiak       // Traverse the operands of an operation and result ops.
473a76ee58fSStanislav Funiak       // We intentionally do not traverse attributes and types, because those
474a76ee58fSStanislav Funiak       // are expensive to join on.
475a76ee58fSStanislav Funiak       TypeSwitch<Operation *>(entry.value.getDefiningOp())
476a76ee58fSStanislav Funiak           .Case<pdl::OperationOp>([&](auto operationOp) {
47772fddfb5SRiver Riddle             OperandRange operands = operationOp.getOperandValues();
478a76ee58fSStanislav Funiak             // Special case when we pass all the operands in one range.
479a76ee58fSStanislav Funiak             // For those, the index is empty.
480a76ee58fSStanislav Funiak             if (operands.size() == 1 &&
4815550c821STres Popp                 isa<pdl::RangeType>(operands[0].getType())) {
4821a36588eSKazu Hirata               toVisit.emplace(operands[0], entry.value, std::nullopt,
483a76ee58fSStanislav Funiak                               entry.depth + 1);
484a76ee58fSStanislav Funiak               return;
485a76ee58fSStanislav Funiak             }
486a76ee58fSStanislav Funiak 
487a76ee58fSStanislav Funiak             // Default case: visit all the operands.
48872fddfb5SRiver Riddle             for (const auto &p :
48972fddfb5SRiver Riddle                  llvm::enumerate(operationOp.getOperandValues()))
490a76ee58fSStanislav Funiak               toVisit.emplace(p.value(), entry.value, p.index(),
491a76ee58fSStanislav Funiak                               entry.depth + 1);
492a76ee58fSStanislav Funiak           })
493a76ee58fSStanislav Funiak           .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
494310c3ee4SRiver Riddle             toVisit.emplace(resultOp.getParent(), entry.value,
495310c3ee4SRiver Riddle                             resultOp.getIndex(), entry.depth);
496a76ee58fSStanislav Funiak           });
497a76ee58fSStanislav Funiak     }
498a76ee58fSStanislav Funiak   }
499a76ee58fSStanislav Funiak 
500a76ee58fSStanislav Funiak   // Now build the cost graph.
501a76ee58fSStanislav Funiak   // This is simply a minimum over all depths for the target root.
502a76ee58fSStanislav Funiak   unsigned nextID = 0;
503a76ee58fSStanislav Funiak   for (const auto &connectorRootsDepths : connectorsRootsDepths) {
504a76ee58fSStanislav Funiak     Value value = connectorRootsDepths.first;
505a76ee58fSStanislav Funiak     ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
506a76ee58fSStanislav Funiak     // If there is only one root for this value, this will not trigger
507a76ee58fSStanislav Funiak     // any edges in the cost graph (a perf optimization).
508a76ee58fSStanislav Funiak     if (rootsDepths.size() == 1)
509a76ee58fSStanislav Funiak       continue;
510a76ee58fSStanislav Funiak 
511a76ee58fSStanislav Funiak     for (const RootDepth &p : rootsDepths) {
512a76ee58fSStanislav Funiak       for (const RootDepth &q : rootsDepths) {
513a76ee58fSStanislav Funiak         if (&p == &q)
514a76ee58fSStanislav Funiak           continue;
515a76ee58fSStanislav Funiak         // Insert or retrieve the property of edge from p to q.
5169eb8e7b1SStanislav Funiak         RootOrderingEntry &entry = graph[q.root][p.root];
5179eb8e7b1SStanislav Funiak         if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
5189eb8e7b1SStanislav Funiak           if (!entry.connector)
5199eb8e7b1SStanislav Funiak             entry.cost.second = nextID++;
5209eb8e7b1SStanislav Funiak           entry.cost.first = q.depth;
5219eb8e7b1SStanislav Funiak           entry.connector = value;
522a76ee58fSStanislav Funiak         }
523a76ee58fSStanislav Funiak       }
524a76ee58fSStanislav Funiak     }
525a76ee58fSStanislav Funiak   }
526a76ee58fSStanislav Funiak 
527a76ee58fSStanislav Funiak   assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
528a76ee58fSStanislav Funiak          "the pattern contains a candidate root disconnected from the others");
529a76ee58fSStanislav Funiak }
530a76ee58fSStanislav Funiak 
5312692eae5SStanislav Funiak /// Returns true if the operand at the given index needs to be queried using an
5322692eae5SStanislav Funiak /// operand group, i.e., if it is variadic itself or follows a variadic operand.
5332692eae5SStanislav Funiak static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
53472fddfb5SRiver Riddle   OperandRange operands = op.getOperandValues();
5352692eae5SStanislav Funiak   assert(index < operands.size() && "operand index out of range");
5362692eae5SStanislav Funiak   for (unsigned i = 0; i <= index; ++i)
5375550c821STres Popp     if (isa<pdl::RangeType>(operands[i].getType()))
5382692eae5SStanislav Funiak       return true;
5392692eae5SStanislav Funiak   return false;
5402692eae5SStanislav Funiak }
5412692eae5SStanislav Funiak 
542a76ee58fSStanislav Funiak /// Visit a node during upward traversal.
5432692eae5SStanislav Funiak static void visitUpward(std::vector<PositionalPredicate> &predList,
5442692eae5SStanislav Funiak                         OpIndex opIndex, PredicateBuilder &builder,
5452692eae5SStanislav Funiak                         DenseMap<Value, Position *> &valueToPosition,
5462692eae5SStanislav Funiak                         Position *&pos, unsigned rootID) {
547a76ee58fSStanislav Funiak   Value value = opIndex.parent;
548a76ee58fSStanislav Funiak   TypeSwitch<Operation *>(value.getDefiningOp())
549a76ee58fSStanislav Funiak       .Case<pdl::OperationOp>([&](auto operationOp) {
550a76ee58fSStanislav Funiak         LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
551a76ee58fSStanislav Funiak 
5522692eae5SStanislav Funiak         // Get users and iterate over them.
5532692eae5SStanislav Funiak         Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
5542692eae5SStanislav Funiak         Position *foreachPos = builder.getForEach(usersPos, rootID);
5552692eae5SStanislav Funiak         OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
5562692eae5SStanislav Funiak 
5572692eae5SStanislav Funiak         // Compare the operand(s) of the user against the input value(s).
5582692eae5SStanislav Funiak         Position *operandPos;
5592692eae5SStanislav Funiak         if (!opIndex.index) {
5602692eae5SStanislav Funiak           // We are querying all the operands of the operation.
5612692eae5SStanislav Funiak           operandPos = builder.getAllOperands(opPos);
5622692eae5SStanislav Funiak         } else if (useOperandGroup(operationOp, *opIndex.index)) {
5632692eae5SStanislav Funiak           // We are querying an operand group.
56472fddfb5SRiver Riddle           Type type = operationOp.getOperandValues()[*opIndex.index].getType();
5655550c821STres Popp           bool variadic = isa<pdl::RangeType>(type);
5662692eae5SStanislav Funiak           operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
5672692eae5SStanislav Funiak         } else {
5682692eae5SStanislav Funiak           // We are querying an individual operand.
5692692eae5SStanislav Funiak           operandPos = builder.getOperand(opPos, *opIndex.index);
570a76ee58fSStanislav Funiak         }
5712692eae5SStanislav Funiak         predList.emplace_back(operandPos, builder.getEqualTo(pos));
572a76ee58fSStanislav Funiak 
573a76ee58fSStanislav Funiak         // Guard against duplicate upward visits. These are not possible,
574a76ee58fSStanislav Funiak         // because if this value was already visited, it would have been
575a76ee58fSStanislav Funiak         // cheaper to start the traversal at this value rather than at the
576a76ee58fSStanislav Funiak         // `connector`, violating the optimality of our spanning tree.
577a76ee58fSStanislav Funiak         bool inserted = valueToPosition.try_emplace(value, opPos).second;
5781b0312d2SBenjamin Kramer         (void)inserted;
579a76ee58fSStanislav Funiak         assert(inserted && "duplicate upward visit");
580a76ee58fSStanislav Funiak 
581a76ee58fSStanislav Funiak         // Obtain the tree predicates at the current value.
582a76ee58fSStanislav Funiak         getTreePredicates(predList, value, builder, valueToPosition, opPos,
583a76ee58fSStanislav Funiak                           opIndex.index);
584a76ee58fSStanislav Funiak 
585a76ee58fSStanislav Funiak         // Update the position
586a76ee58fSStanislav Funiak         pos = opPos;
587a76ee58fSStanislav Funiak       })
588a76ee58fSStanislav Funiak       .Case<pdl::ResultOp>([&](auto resultOp) {
589a76ee58fSStanislav Funiak         // Traverse up an individual result.
590a76ee58fSStanislav Funiak         auto *opPos = dyn_cast<OperationPosition>(pos);
591a76ee58fSStanislav Funiak         assert(opPos && "operations and results must be interleaved");
592a76ee58fSStanislav Funiak         pos = builder.getResult(opPos, *opIndex.index);
5932692eae5SStanislav Funiak 
5942692eae5SStanislav Funiak         // Insert the result position in case we have not visited it yet.
5952692eae5SStanislav Funiak         valueToPosition.try_emplace(value, pos);
596a76ee58fSStanislav Funiak       })
597a76ee58fSStanislav Funiak       .Case<pdl::ResultsOp>([&](auto resultOp) {
598a76ee58fSStanislav Funiak         // Traverse up a group of results.
599a76ee58fSStanislav Funiak         auto *opPos = dyn_cast<OperationPosition>(pos);
600a76ee58fSStanislav Funiak         assert(opPos && "operations and results must be interleaved");
6015550c821STres Popp         bool isVariadic = isa<pdl::RangeType>(value.getType());
602a76ee58fSStanislav Funiak         if (opIndex.index)
603a76ee58fSStanislav Funiak           pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
604a76ee58fSStanislav Funiak         else
605a76ee58fSStanislav Funiak           pos = builder.getAllResults(opPos);
6062692eae5SStanislav Funiak 
6072692eae5SStanislav Funiak         // Insert the result position in case we have not visited it yet.
6082692eae5SStanislav Funiak         valueToPosition.try_emplace(value, pos);
609a76ee58fSStanislav Funiak       });
610a76ee58fSStanislav Funiak }
611a76ee58fSStanislav Funiak 
6128a1ca2cdSRiver Riddle /// Given a pattern operation, build the set of matcher predicates necessary to
6138a1ca2cdSRiver Riddle /// match this pattern.
614a76ee58fSStanislav Funiak static Value buildPredicateList(pdl::PatternOp pattern,
6158a1ca2cdSRiver Riddle                                 PredicateBuilder &builder,
6168a1ca2cdSRiver Riddle                                 std::vector<PositionalPredicate> &predList,
6178a1ca2cdSRiver Riddle                                 DenseMap<Value, Position *> &valueToPosition) {
618a76ee58fSStanislav Funiak   SmallVector<Value> roots = detectRoots(pattern);
619a76ee58fSStanislav Funiak 
620a76ee58fSStanislav Funiak   // Build the root ordering graph and compute the parent maps.
621a76ee58fSStanislav Funiak   RootOrderingGraph graph;
622a76ee58fSStanislav Funiak   ParentMaps parentMaps;
623a76ee58fSStanislav Funiak   buildCostGraph(roots, graph, parentMaps);
624a76ee58fSStanislav Funiak   LLVM_DEBUG({
625a76ee58fSStanislav Funiak     llvm::dbgs() << "Graph:\n";
626a76ee58fSStanislav Funiak     for (auto &target : graph) {
6272692eae5SStanislav Funiak       llvm::dbgs() << "  * " << target.first.getLoc() << " " << target.first
6282692eae5SStanislav Funiak                    << "\n";
629a76ee58fSStanislav Funiak       for (auto &source : target.second) {
6309eb8e7b1SStanislav Funiak         RootOrderingEntry &entry = source.second;
6319eb8e7b1SStanislav Funiak         llvm::dbgs() << "      <- " << source.first << ": " << entry.cost.first
6329eb8e7b1SStanislav Funiak                      << ":" << entry.cost.second << " via "
6339eb8e7b1SStanislav Funiak                      << entry.connector.getLoc() << "\n";
634a76ee58fSStanislav Funiak       }
635a76ee58fSStanislav Funiak     }
636a76ee58fSStanislav Funiak   });
637a76ee58fSStanislav Funiak 
638a76ee58fSStanislav Funiak   // Solve the optimal branching problem for each candidate root, or use the
639a76ee58fSStanislav Funiak   // provided one.
640310c3ee4SRiver Riddle   Value bestRoot = pattern.getRewriter().getRoot();
641a76ee58fSStanislav Funiak   OptimalBranching::EdgeList bestEdges;
642a76ee58fSStanislav Funiak   if (!bestRoot) {
643a76ee58fSStanislav Funiak     unsigned bestCost = 0;
644a76ee58fSStanislav Funiak     LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
645a76ee58fSStanislav Funiak     for (Value root : roots) {
646a76ee58fSStanislav Funiak       OptimalBranching solver(graph, root);
647a76ee58fSStanislav Funiak       unsigned cost = solver.solve();
648a76ee58fSStanislav Funiak       LLVM_DEBUG(llvm::dbgs() << "  * " << root << ": " << cost << "\n");
649a76ee58fSStanislav Funiak       if (!bestRoot || bestCost > cost) {
650a76ee58fSStanislav Funiak         bestCost = cost;
651a76ee58fSStanislav Funiak         bestRoot = root;
652a76ee58fSStanislav Funiak         bestEdges = solver.preOrderTraversal(roots);
653a76ee58fSStanislav Funiak       }
654a76ee58fSStanislav Funiak     }
655a76ee58fSStanislav Funiak   } else {
656a76ee58fSStanislav Funiak     OptimalBranching solver(graph, bestRoot);
657a76ee58fSStanislav Funiak     solver.solve();
658a76ee58fSStanislav Funiak     bestEdges = solver.preOrderTraversal(roots);
659a76ee58fSStanislav Funiak   }
660a76ee58fSStanislav Funiak 
6612692eae5SStanislav Funiak   // Print the best solution.
6622692eae5SStanislav Funiak   LLVM_DEBUG({
6632692eae5SStanislav Funiak     llvm::dbgs() << "Best tree:\n";
6642692eae5SStanislav Funiak     for (const std::pair<Value, Value> &edge : bestEdges) {
6652692eae5SStanislav Funiak       llvm::dbgs() << "  * " << edge.first;
6662692eae5SStanislav Funiak       if (edge.second)
6672692eae5SStanislav Funiak         llvm::dbgs() << " <- " << edge.second;
6682692eae5SStanislav Funiak       llvm::dbgs() << "\n";
6692692eae5SStanislav Funiak     }
6702692eae5SStanislav Funiak   });
6712692eae5SStanislav Funiak 
672a76ee58fSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
673a76ee58fSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << bestRoot << "\n");
674a76ee58fSStanislav Funiak 
675a76ee58fSStanislav Funiak   // The best root is the starting point for the traversal. Get the tree
676a76ee58fSStanislav Funiak   // predicates for the DAG rooted at bestRoot.
677a76ee58fSStanislav Funiak   getTreePredicates(predList, bestRoot, builder, valueToPosition,
678a76ee58fSStanislav Funiak                     builder.getRoot());
679a76ee58fSStanislav Funiak 
680a76ee58fSStanislav Funiak   // Traverse the selected optimal branching. For all edges in order, traverse
681a76ee58fSStanislav Funiak   // up starting from the connector, until the candidate root is reached, and
682a76ee58fSStanislav Funiak   // call getTreePredicates at every node along the way.
68350da0134SAdrian Kuegel   for (const auto &it : llvm::enumerate(bestEdges)) {
6842692eae5SStanislav Funiak     Value target = it.value().first;
6852692eae5SStanislav Funiak     Value source = it.value().second;
686a76ee58fSStanislav Funiak 
687a76ee58fSStanislav Funiak     // Check if we already visited the target root. This happens in two cases:
688a76ee58fSStanislav Funiak     // 1) the initial root (bestRoot);
689a76ee58fSStanislav Funiak     // 2) a root that is dominated by (contained in the subtree rooted at) an
690a76ee58fSStanislav Funiak     //    already visited root.
691a76ee58fSStanislav Funiak     if (valueToPosition.count(target))
692a76ee58fSStanislav Funiak       continue;
693a76ee58fSStanislav Funiak 
694a76ee58fSStanislav Funiak     // Determine the connector.
695a76ee58fSStanislav Funiak     Value connector = graph[target][source].connector;
696a76ee58fSStanislav Funiak     assert(connector && "invalid edge");
697a76ee58fSStanislav Funiak     LLVM_DEBUG(llvm::dbgs() << "  * Connector: " << connector.getLoc() << "\n");
698a76ee58fSStanislav Funiak     DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
699a76ee58fSStanislav Funiak     Position *pos = valueToPosition.lookup(connector);
7002692eae5SStanislav Funiak     assert(pos && "connector has not been traversed yet");
701a76ee58fSStanislav Funiak 
702a76ee58fSStanislav Funiak     // Traverse from the connector upwards towards the target root.
703a76ee58fSStanislav Funiak     for (Value value = connector; value != target;) {
704a76ee58fSStanislav Funiak       OpIndex opIndex = parentMap.lookup(value);
705a76ee58fSStanislav Funiak       assert(opIndex.parent && "missing parent");
7062692eae5SStanislav Funiak       visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
707a76ee58fSStanislav Funiak       value = opIndex.parent;
708a76ee58fSStanislav Funiak     }
709a76ee58fSStanislav Funiak   }
710a76ee58fSStanislav Funiak 
711242762c9SRiver Riddle   getNonTreePredicates(pattern, predList, builder, valueToPosition);
712a76ee58fSStanislav Funiak 
713a76ee58fSStanislav Funiak   return bestRoot;
7148a1ca2cdSRiver Riddle }
7158a1ca2cdSRiver Riddle 
7168a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
7178a1ca2cdSRiver Riddle // Pattern Predicate Tree Merging
7188a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
7198a1ca2cdSRiver Riddle 
7208a1ca2cdSRiver Riddle namespace {
7218a1ca2cdSRiver Riddle 
7228a1ca2cdSRiver Riddle /// This class represents a specific predicate applied to a position, and
7238a1ca2cdSRiver Riddle /// provides hashing and ordering operators. This class allows for computing a
7248a1ca2cdSRiver Riddle /// frequence sum and ordering predicates based on a cost model.
7258a1ca2cdSRiver Riddle struct OrderedPredicate {
7268a1ca2cdSRiver Riddle   OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
7278a1ca2cdSRiver Riddle       : position(ip.first), question(ip.second) {}
7288a1ca2cdSRiver Riddle   OrderedPredicate(const PositionalPredicate &ip)
7298a1ca2cdSRiver Riddle       : position(ip.position), question(ip.question) {}
7308a1ca2cdSRiver Riddle 
7318a1ca2cdSRiver Riddle   /// The position this predicate is applied to.
7328a1ca2cdSRiver Riddle   Position *position;
7338a1ca2cdSRiver Riddle 
7348a1ca2cdSRiver Riddle   /// The question that is applied by this predicate onto the position.
7358a1ca2cdSRiver Riddle   Qualifier *question;
7368a1ca2cdSRiver Riddle 
7378a1ca2cdSRiver Riddle   /// The first and second order benefit sums.
7388a1ca2cdSRiver Riddle   /// The primary sum is the number of occurrences of this predicate among all
7398a1ca2cdSRiver Riddle   /// of the patterns.
7408a1ca2cdSRiver Riddle   unsigned primary = 0;
7418a1ca2cdSRiver Riddle   /// The secondary sum is a squared summation of the primary sum of all of the
7428a1ca2cdSRiver Riddle   /// predicates within each pattern that contains this predicate. This allows
7438a1ca2cdSRiver Riddle   /// for favoring predicates that are more commonly shared within a pattern, as
7448a1ca2cdSRiver Riddle   /// opposed to those shared across patterns.
7458a1ca2cdSRiver Riddle   unsigned secondary = 0;
7468a1ca2cdSRiver Riddle 
747138803e0SStanislav Funiak   /// The tie breaking ID, used to preserve a deterministic (insertion) order
748138803e0SStanislav Funiak   /// among all the predicates with the same priority, depth, and position /
749138803e0SStanislav Funiak   /// predicate dependency.
750138803e0SStanislav Funiak   unsigned id = 0;
751138803e0SStanislav Funiak 
7528a1ca2cdSRiver Riddle   /// A map between a pattern operation and the answer to the predicate question
7538a1ca2cdSRiver Riddle   /// within that pattern.
7548a1ca2cdSRiver Riddle   DenseMap<Operation *, Qualifier *> patternToAnswer;
7558a1ca2cdSRiver Riddle 
756ddd556f1SRiver Riddle   /// Returns true if this predicate is ordered before `rhs`, based on the cost
757ddd556f1SRiver Riddle   /// model.
758ddd556f1SRiver Riddle   bool operator<(const OrderedPredicate &rhs) const {
7598a1ca2cdSRiver Riddle     // Sort by:
760ddd556f1SRiver Riddle     // * higher first and secondary order sums
7618a1ca2cdSRiver Riddle     // * lower depth
762ddd556f1SRiver Riddle     // * lower position dependency
763ddd556f1SRiver Riddle     // * lower predicate dependency
764138803e0SStanislav Funiak     // * lower tie breaking ID
765ddd556f1SRiver Riddle     auto *rhsPos = rhs.position;
7663a833a0eSRiver Riddle     return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
767138803e0SStanislav Funiak                            rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
768ddd556f1SRiver Riddle            std::make_tuple(rhs.primary, rhs.secondary,
7693a833a0eSRiver Riddle                            position->getOperationDepth(), position->getKind(),
770138803e0SStanislav Funiak                            question->getKind(), id);
7718a1ca2cdSRiver Riddle   }
7728a1ca2cdSRiver Riddle };
7738a1ca2cdSRiver Riddle 
7748a1ca2cdSRiver Riddle /// A DenseMapInfo for OrderedPredicate based solely on the position and
7758a1ca2cdSRiver Riddle /// question.
7768a1ca2cdSRiver Riddle struct OrderedPredicateDenseInfo {
7778a1ca2cdSRiver Riddle   using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
7788a1ca2cdSRiver Riddle 
7798a1ca2cdSRiver Riddle   static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
7808a1ca2cdSRiver Riddle   static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
7818a1ca2cdSRiver Riddle   static bool isEqual(const OrderedPredicate &lhs,
7828a1ca2cdSRiver Riddle                       const OrderedPredicate &rhs) {
7838a1ca2cdSRiver Riddle     return lhs.position == rhs.position && lhs.question == rhs.question;
7848a1ca2cdSRiver Riddle   }
7858a1ca2cdSRiver Riddle   static unsigned getHashValue(const OrderedPredicate &p) {
7868a1ca2cdSRiver Riddle     return llvm::hash_combine(p.position, p.question);
7878a1ca2cdSRiver Riddle   }
7888a1ca2cdSRiver Riddle };
7898a1ca2cdSRiver Riddle 
7908a1ca2cdSRiver Riddle /// This class wraps a set of ordered predicates that are used within a specific
7918a1ca2cdSRiver Riddle /// pattern operation.
7928a1ca2cdSRiver Riddle struct OrderedPredicateList {
793a76ee58fSStanislav Funiak   OrderedPredicateList(pdl::PatternOp pattern, Value root)
794a76ee58fSStanislav Funiak       : pattern(pattern), root(root) {}
7958a1ca2cdSRiver Riddle 
7968a1ca2cdSRiver Riddle   pdl::PatternOp pattern;
797a76ee58fSStanislav Funiak   Value root;
7988a1ca2cdSRiver Riddle   DenseSet<OrderedPredicate *> predicates;
7998a1ca2cdSRiver Riddle };
800be0a7e9fSMehdi Amini } // namespace
8018a1ca2cdSRiver Riddle 
8028a1ca2cdSRiver Riddle /// Returns true if the given matcher refers to the same predicate as the given
8038a1ca2cdSRiver Riddle /// ordered predicate. This means that the position and questions of the two
8048a1ca2cdSRiver Riddle /// match.
8058a1ca2cdSRiver Riddle static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
8068a1ca2cdSRiver Riddle   return node->getPosition() == predicate->position &&
8078a1ca2cdSRiver Riddle          node->getQuestion() == predicate->question;
8088a1ca2cdSRiver Riddle }
8098a1ca2cdSRiver Riddle 
8108a1ca2cdSRiver Riddle /// Get or insert a child matcher for the given parent switch node, given a
8118a1ca2cdSRiver Riddle /// predicate and parent pattern.
8128a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
8138a1ca2cdSRiver Riddle                                                OrderedPredicate *predicate,
8148a1ca2cdSRiver Riddle                                                pdl::PatternOp pattern) {
8158a1ca2cdSRiver Riddle   assert(isSamePredicate(node, predicate) &&
8168a1ca2cdSRiver Riddle          "expected matcher to equal the given predicate");
8178a1ca2cdSRiver Riddle 
8188a1ca2cdSRiver Riddle   auto it = predicate->patternToAnswer.find(pattern);
8198a1ca2cdSRiver Riddle   assert(it != predicate->patternToAnswer.end() &&
8208a1ca2cdSRiver Riddle          "expected pattern to exist in predicate");
821*5d6cb6f7SKazu Hirata   return node->getChildren()[it->second];
8228a1ca2cdSRiver Riddle }
8238a1ca2cdSRiver Riddle 
8248a1ca2cdSRiver Riddle /// Build the matcher CFG by "pushing" patterns through by sorted predicate
8258a1ca2cdSRiver Riddle /// order. A pattern will traverse as far as possible using common predicates
8268a1ca2cdSRiver Riddle /// and then either diverge from the CFG or reach the end of a branch and start
8278a1ca2cdSRiver Riddle /// creating new nodes.
8288a1ca2cdSRiver Riddle static void propagatePattern(std::unique_ptr<MatcherNode> &node,
8298a1ca2cdSRiver Riddle                              OrderedPredicateList &list,
8308a1ca2cdSRiver Riddle                              std::vector<OrderedPredicate *>::iterator current,
8318a1ca2cdSRiver Riddle                              std::vector<OrderedPredicate *>::iterator end) {
8328a1ca2cdSRiver Riddle   if (current == end) {
8338a1ca2cdSRiver Riddle     // We've hit the end of a pattern, so create a successful result node.
834a76ee58fSStanislav Funiak     node =
835a76ee58fSStanislav Funiak         std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
8368a1ca2cdSRiver Riddle 
8378a1ca2cdSRiver Riddle     // If the pattern doesn't contain this predicate, ignore it.
838ce14f7b1SKazu Hirata   } else if (!list.predicates.contains(*current)) {
8398a1ca2cdSRiver Riddle     propagatePattern(node, list, std::next(current), end);
8408a1ca2cdSRiver Riddle 
8418a1ca2cdSRiver Riddle     // If the current matcher node is invalid, create a new one for this
8428a1ca2cdSRiver Riddle     // position and continue propagation.
8438a1ca2cdSRiver Riddle   } else if (!node) {
8448a1ca2cdSRiver Riddle     // Create a new node at this position and continue
8458a1ca2cdSRiver Riddle     node = std::make_unique<SwitchNode>((*current)->position,
8468a1ca2cdSRiver Riddle                                         (*current)->question);
8478a1ca2cdSRiver Riddle     propagatePattern(
8488a1ca2cdSRiver Riddle         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
8498a1ca2cdSRiver Riddle         list, std::next(current), end);
8508a1ca2cdSRiver Riddle 
8518a1ca2cdSRiver Riddle     // If the matcher has already been created, and it is for this predicate we
8528a1ca2cdSRiver Riddle     // continue propagation to the child.
8538a1ca2cdSRiver Riddle   } else if (isSamePredicate(node.get(), *current)) {
8548a1ca2cdSRiver Riddle     propagatePattern(
8558a1ca2cdSRiver Riddle         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
8568a1ca2cdSRiver Riddle         list, std::next(current), end);
8578a1ca2cdSRiver Riddle 
8588a1ca2cdSRiver Riddle     // If the matcher doesn't match the current predicate, insert a branch as
8598a1ca2cdSRiver Riddle     // the common set of matchers has diverged.
8608a1ca2cdSRiver Riddle   } else {
8618a1ca2cdSRiver Riddle     propagatePattern(node->getFailureNode(), list, current, end);
8628a1ca2cdSRiver Riddle   }
8638a1ca2cdSRiver Riddle }
8648a1ca2cdSRiver Riddle 
8658a1ca2cdSRiver Riddle /// Fold any switch nodes nested under `node` to boolean nodes when possible.
8668a1ca2cdSRiver Riddle /// `node` is updated in-place if it is a switch.
8678a1ca2cdSRiver Riddle static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
8688a1ca2cdSRiver Riddle   if (!node)
8698a1ca2cdSRiver Riddle     return;
8708a1ca2cdSRiver Riddle 
8718a1ca2cdSRiver Riddle   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
8728a1ca2cdSRiver Riddle     SwitchNode::ChildMapT &children = switchNode->getChildren();
8738a1ca2cdSRiver Riddle     for (auto &it : children)
8748a1ca2cdSRiver Riddle       foldSwitchToBool(it.second);
8758a1ca2cdSRiver Riddle 
8768a1ca2cdSRiver Riddle     // If the node only contains one child, collapse it into a boolean predicate
8778a1ca2cdSRiver Riddle     // node.
8788a1ca2cdSRiver Riddle     if (children.size() == 1) {
8791cef577bSMehdi Amini       auto *childIt = children.begin();
8808a1ca2cdSRiver Riddle       node = std::make_unique<BoolNode>(
8818a1ca2cdSRiver Riddle           node->getPosition(), node->getQuestion(), childIt->first,
8828a1ca2cdSRiver Riddle           std::move(childIt->second), std::move(node->getFailureNode()));
8838a1ca2cdSRiver Riddle     }
8848a1ca2cdSRiver Riddle   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
8858a1ca2cdSRiver Riddle     foldSwitchToBool(boolNode->getSuccessNode());
8868a1ca2cdSRiver Riddle   }
8878a1ca2cdSRiver Riddle 
8888a1ca2cdSRiver Riddle   foldSwitchToBool(node->getFailureNode());
8898a1ca2cdSRiver Riddle }
8908a1ca2cdSRiver Riddle 
8918a1ca2cdSRiver Riddle /// Insert an exit node at the end of the failure path of the `root`.
8928a1ca2cdSRiver Riddle static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
8938a1ca2cdSRiver Riddle   while (*root)
8948a1ca2cdSRiver Riddle     root = &(*root)->getFailureNode();
8958a1ca2cdSRiver Riddle   *root = std::make_unique<ExitNode>();
8968a1ca2cdSRiver Riddle }
8978a1ca2cdSRiver Riddle 
8988ec28af8SMatthias Gehre /// Sorts the range begin/end with the partial order given by cmp.
8998ec28af8SMatthias Gehre template <typename Iterator, typename Compare>
9008ec28af8SMatthias Gehre static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
9018ec28af8SMatthias Gehre   while (begin != end) {
9028ec28af8SMatthias Gehre     // Cannot compute sortBeforeOthers in the predicate of stable_partition
9038ec28af8SMatthias Gehre     // because stable_partition will not keep the [begin, end) range intact
9048ec28af8SMatthias Gehre     // while it runs.
9058ec28af8SMatthias Gehre     llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
9068ec28af8SMatthias Gehre     for (auto i = begin; i != end; ++i) {
9078ec28af8SMatthias Gehre       if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
9088ec28af8SMatthias Gehre         sortBeforeOthers.insert(*i);
9098ec28af8SMatthias Gehre     }
9108ec28af8SMatthias Gehre 
9118ec28af8SMatthias Gehre     auto const next = std::stable_partition(begin, end, [&](auto const &a) {
9128ec28af8SMatthias Gehre       return sortBeforeOthers.contains(a);
9138ec28af8SMatthias Gehre     });
9148ec28af8SMatthias Gehre     assert(next != begin && "not a partial ordering");
9158ec28af8SMatthias Gehre     begin = next;
9168ec28af8SMatthias Gehre   }
9178ec28af8SMatthias Gehre }
9188ec28af8SMatthias Gehre 
9198ec28af8SMatthias Gehre /// Returns true if 'b' depends on a result of 'a'.
9208ec28af8SMatthias Gehre static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
9218ec28af8SMatthias Gehre   auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
9228ec28af8SMatthias Gehre   if (!cqa)
9238ec28af8SMatthias Gehre     return false;
9248ec28af8SMatthias Gehre 
9258ec28af8SMatthias Gehre   auto positionDependsOnA = [&](Position *p) {
9268ec28af8SMatthias Gehre     auto *cp = dyn_cast<ConstraintPosition>(p);
9278ec28af8SMatthias Gehre     return cp && cp->getQuestion() == cqa;
9288ec28af8SMatthias Gehre   };
9298ec28af8SMatthias Gehre 
9308ec28af8SMatthias Gehre   if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
9318ec28af8SMatthias Gehre     // Does any argument of b use a?
9328ec28af8SMatthias Gehre     return llvm::any_of(cqb->getArgs(), positionDependsOnA);
9338ec28af8SMatthias Gehre   }
9348ec28af8SMatthias Gehre   if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
9358ec28af8SMatthias Gehre     return positionDependsOnA(b->position) ||
9368ec28af8SMatthias Gehre            positionDependsOnA(equalTo->getValue());
9378ec28af8SMatthias Gehre   }
9388ec28af8SMatthias Gehre   return positionDependsOnA(b->position);
9398ec28af8SMatthias Gehre }
9408ec28af8SMatthias Gehre 
9418a1ca2cdSRiver Riddle /// Given a module containing PDL pattern operations, generate a matcher tree
9428a1ca2cdSRiver Riddle /// using the patterns within the given module and return the root matcher node.
9438a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode>
9448a1ca2cdSRiver Riddle MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
9458a1ca2cdSRiver Riddle                                  DenseMap<Value, Position *> &valueToPosition) {
946a76ee58fSStanislav Funiak   // The set of predicates contained within the pattern operations of the
947a76ee58fSStanislav Funiak   // module.
948a76ee58fSStanislav Funiak   struct PatternPredicates {
949a76ee58fSStanislav Funiak     PatternPredicates(pdl::PatternOp pattern, Value root,
950a76ee58fSStanislav Funiak                       std::vector<PositionalPredicate> predicates)
951a76ee58fSStanislav Funiak         : pattern(pattern), root(root), predicates(std::move(predicates)) {}
952a76ee58fSStanislav Funiak 
953a76ee58fSStanislav Funiak     /// A pattern.
954a76ee58fSStanislav Funiak     pdl::PatternOp pattern;
955a76ee58fSStanislav Funiak 
956a76ee58fSStanislav Funiak     /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
957a76ee58fSStanislav Funiak     Value root;
958a76ee58fSStanislav Funiak 
959a76ee58fSStanislav Funiak     /// The extracted predicates for this pattern and root.
960a76ee58fSStanislav Funiak     std::vector<PositionalPredicate> predicates;
961a76ee58fSStanislav Funiak   };
962a76ee58fSStanislav Funiak 
963a76ee58fSStanislav Funiak   SmallVector<PatternPredicates, 16> patternsAndPredicates;
9648a1ca2cdSRiver Riddle   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
9658a1ca2cdSRiver Riddle     std::vector<PositionalPredicate> predicateList;
966a76ee58fSStanislav Funiak     Value root =
9678a1ca2cdSRiver Riddle         buildPredicateList(pattern, builder, predicateList, valueToPosition);
968a76ee58fSStanislav Funiak     patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
9698a1ca2cdSRiver Riddle   }
9708a1ca2cdSRiver Riddle 
9718a1ca2cdSRiver Riddle   // Associate a pattern result with each unique predicate.
9728a1ca2cdSRiver Riddle   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
9738a1ca2cdSRiver Riddle   for (auto &patternAndPredList : patternsAndPredicates) {
974a76ee58fSStanislav Funiak     for (auto &predicate : patternAndPredList.predicates) {
9758a1ca2cdSRiver Riddle       auto it = uniqued.insert(predicate);
976a76ee58fSStanislav Funiak       it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
9778a1ca2cdSRiver Riddle                                             predicate.answer);
978138803e0SStanislav Funiak       // Mark the insertion order (0-based indexing).
979138803e0SStanislav Funiak       if (it.second)
980138803e0SStanislav Funiak         it.first->id = uniqued.size() - 1;
9818a1ca2cdSRiver Riddle     }
9828a1ca2cdSRiver Riddle   }
9838a1ca2cdSRiver Riddle 
9848a1ca2cdSRiver Riddle   // Associate each pattern to a set of its ordered predicates for later lookup.
9858a1ca2cdSRiver Riddle   std::vector<OrderedPredicateList> lists;
9868a1ca2cdSRiver Riddle   lists.reserve(patternsAndPredicates.size());
9878a1ca2cdSRiver Riddle   for (auto &patternAndPredList : patternsAndPredicates) {
988a76ee58fSStanislav Funiak     OrderedPredicateList list(patternAndPredList.pattern,
989a76ee58fSStanislav Funiak                               patternAndPredList.root);
990a76ee58fSStanislav Funiak     for (auto &predicate : patternAndPredList.predicates) {
9918a1ca2cdSRiver Riddle       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
9928a1ca2cdSRiver Riddle       list.predicates.insert(orderedPredicate);
9938a1ca2cdSRiver Riddle 
9948a1ca2cdSRiver Riddle       // Increment the primary sum for each reference to a particular predicate.
9958a1ca2cdSRiver Riddle       ++orderedPredicate->primary;
9968a1ca2cdSRiver Riddle     }
9978a1ca2cdSRiver Riddle     lists.push_back(std::move(list));
9988a1ca2cdSRiver Riddle   }
9998a1ca2cdSRiver Riddle 
10008a1ca2cdSRiver Riddle   // For a particular pattern, get the total primary sum and add it to the
10018a1ca2cdSRiver Riddle   // secondary sum of each predicate. Square the primary sums to emphasize
10028a1ca2cdSRiver Riddle   // shared predicates within rather than across patterns.
10038a1ca2cdSRiver Riddle   for (auto &list : lists) {
10048a1ca2cdSRiver Riddle     unsigned total = 0;
10058a1ca2cdSRiver Riddle     for (auto *predicate : list.predicates)
10068a1ca2cdSRiver Riddle       total += predicate->primary * predicate->primary;
10078a1ca2cdSRiver Riddle     for (auto *predicate : list.predicates)
10088a1ca2cdSRiver Riddle       predicate->secondary += total;
10098a1ca2cdSRiver Riddle   }
10108a1ca2cdSRiver Riddle 
10118a1ca2cdSRiver Riddle   // Sort the set of predicates now that the cost primary and secondary sums
10128a1ca2cdSRiver Riddle   // have been computed.
10138a1ca2cdSRiver Riddle   std::vector<OrderedPredicate *> ordered;
10148a1ca2cdSRiver Riddle   ordered.reserve(uniqued.size());
10158a1ca2cdSRiver Riddle   for (auto &ip : uniqued)
10168a1ca2cdSRiver Riddle     ordered.push_back(&ip);
1017138803e0SStanislav Funiak   llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1018138803e0SStanislav Funiak     return *lhs < *rhs;
1019138803e0SStanislav Funiak   });
10208a1ca2cdSRiver Riddle 
10218ec28af8SMatthias Gehre   // Mostly keep the now established order, but also ensure that
10228ec28af8SMatthias Gehre   // ConstraintQuestions come after the results they use.
10238ec28af8SMatthias Gehre   stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
10248ec28af8SMatthias Gehre 
10258a1ca2cdSRiver Riddle   // Build the matchers for each of the pattern predicate lists.
10268a1ca2cdSRiver Riddle   std::unique_ptr<MatcherNode> root;
10278a1ca2cdSRiver Riddle   for (OrderedPredicateList &list : lists)
10288a1ca2cdSRiver Riddle     propagatePattern(root, list, ordered.begin(), ordered.end());
10298a1ca2cdSRiver Riddle 
10308a1ca2cdSRiver Riddle   // Collapse the graph and insert the exit node.
10318a1ca2cdSRiver Riddle   foldSwitchToBool(root);
10328a1ca2cdSRiver Riddle   insertExitNode(&root);
10338a1ca2cdSRiver Riddle   return root;
10348a1ca2cdSRiver Riddle }
10358a1ca2cdSRiver Riddle 
10368a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
10378a1ca2cdSRiver Riddle // MatcherNode
10388a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
10398a1ca2cdSRiver Riddle 
10408a1ca2cdSRiver Riddle MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
10418a1ca2cdSRiver Riddle                          std::unique_ptr<MatcherNode> failureNode)
10428a1ca2cdSRiver Riddle     : position(p), question(q), failureNode(std::move(failureNode)),
10438a1ca2cdSRiver Riddle       matcherTypeID(matcherTypeID) {}
10448a1ca2cdSRiver Riddle 
10458a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
10468a1ca2cdSRiver Riddle // BoolNode
10478a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
10488a1ca2cdSRiver Riddle 
10498a1ca2cdSRiver Riddle BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
10508a1ca2cdSRiver Riddle                    std::unique_ptr<MatcherNode> successNode,
10518a1ca2cdSRiver Riddle                    std::unique_ptr<MatcherNode> failureNode)
10528a1ca2cdSRiver Riddle     : MatcherNode(TypeID::get<BoolNode>(), position, question,
10538a1ca2cdSRiver Riddle                   std::move(failureNode)),
10548a1ca2cdSRiver Riddle       answer(answer), successNode(std::move(successNode)) {}
10558a1ca2cdSRiver Riddle 
10568a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
10578a1ca2cdSRiver Riddle // SuccessNode
10588a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
10598a1ca2cdSRiver Riddle 
1060a76ee58fSStanislav Funiak SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
10618a1ca2cdSRiver Riddle                          std::unique_ptr<MatcherNode> failureNode)
10628a1ca2cdSRiver Riddle     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
10638a1ca2cdSRiver Riddle                   /*question=*/nullptr, std::move(failureNode)),
1064a76ee58fSStanislav Funiak       pattern(pattern), root(root) {}
10658a1ca2cdSRiver Riddle 
10668a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
10678a1ca2cdSRiver Riddle // SwitchNode
10688a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
10698a1ca2cdSRiver Riddle 
10708a1ca2cdSRiver Riddle SwitchNode::SwitchNode(Position *position, Qualifier *question)
10718a1ca2cdSRiver Riddle     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
1072