18a1ca2cdSRiver Riddle //===- PredicateTree.cpp - Predicate tree merging -------------------------===// 28a1ca2cdSRiver Riddle // 38a1ca2cdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48a1ca2cdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 58a1ca2cdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68a1ca2cdSRiver Riddle // 78a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 88a1ca2cdSRiver Riddle 98a1ca2cdSRiver Riddle #include "PredicateTree.h" 10a76ee58fSStanislav Funiak #include "RootOrdering.h" 11a76ee58fSStanislav Funiak 128a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h" 138a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h" 148a1ca2cdSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 1565fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h" 168a1ca2cdSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.h" 17a76ee58fSStanislav Funiak #include "llvm/ADT/MapVector.h" 188ec28af8SMatthias Gehre #include "llvm/ADT/SmallPtrSet.h" 19242762c9SRiver Riddle #include "llvm/ADT/TypeSwitch.h" 20a76ee58fSStanislav Funiak #include "llvm/Support/Debug.h" 21a76ee58fSStanislav Funiak #include <queue> 22a76ee58fSStanislav Funiak 23a76ee58fSStanislav Funiak #define DEBUG_TYPE "pdl-predicate-tree" 248a1ca2cdSRiver Riddle 258a1ca2cdSRiver Riddle using namespace mlir; 268a1ca2cdSRiver Riddle using namespace mlir::pdl_to_pdl_interp; 278a1ca2cdSRiver Riddle 288a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 298a1ca2cdSRiver Riddle // Predicate List Building 308a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 318a1ca2cdSRiver Riddle 32242762c9SRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList, 33242762c9SRiver Riddle Value val, PredicateBuilder &builder, 
34242762c9SRiver Riddle DenseMap<Value, Position *> &inputs, 35242762c9SRiver Riddle Position *pos); 36242762c9SRiver Riddle 378a1ca2cdSRiver Riddle /// Compares the depths of two positions. 388a1ca2cdSRiver Riddle static bool comparePosDepth(Position *lhs, Position *rhs) { 393a833a0eSRiver Riddle return lhs->getOperationDepth() < rhs->getOperationDepth(); 403a833a0eSRiver Riddle } 413a833a0eSRiver Riddle 423a833a0eSRiver Riddle /// Returns the number of non-range elements within `values`. 433a833a0eSRiver Riddle static unsigned getNumNonRangeValues(ValueRange values) { 443a833a0eSRiver Riddle return llvm::count_if(values.getTypes(), 455550c821STres Popp [](Type type) { return !isa<pdl::RangeType>(type); }); 468a1ca2cdSRiver Riddle } 478a1ca2cdSRiver Riddle 488a1ca2cdSRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList, 498a1ca2cdSRiver Riddle Value val, PredicateBuilder &builder, 508a1ca2cdSRiver Riddle DenseMap<Value, Position *> &inputs, 51242762c9SRiver Riddle AttributePosition *pos) { 525550c821STres Popp assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type"); 538a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 548a1ca2cdSRiver Riddle 558ec28af8SMatthias Gehre if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) { 56242762c9SRiver Riddle // If the attribute has a type or value, add a constraint. 5772fddfb5SRiver Riddle if (Value type = attr.getValueType()) 588a1ca2cdSRiver Riddle getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); 59310c3ee4SRiver Riddle else if (Attribute value = attr.getValueAttr()) 60242762c9SRiver Riddle predList.emplace_back(pos, builder.getAttributeConstraint(value)); 61242762c9SRiver Riddle } 628ec28af8SMatthias Gehre } 638a1ca2cdSRiver Riddle 643a833a0eSRiver Riddle /// Collect all of the predicates for the given operand position. 
653a833a0eSRiver Riddle static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList, 66242762c9SRiver Riddle Value val, PredicateBuilder &builder, 67242762c9SRiver Riddle DenseMap<Value, Position *> &inputs, 683a833a0eSRiver Riddle Position *pos) { 693a833a0eSRiver Riddle Type valueType = val.getType(); 705550c821STres Popp bool isVariadic = isa<pdl::RangeType>(valueType); 718a1ca2cdSRiver Riddle 72e07c968aSRiver Riddle // If this is a typed operand, add a type constraint. 733a833a0eSRiver Riddle TypeSwitch<Operation *>(val.getDefiningOp()) 743a833a0eSRiver Riddle .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) { 753a833a0eSRiver Riddle // Prevent traversal into a null value if the operand has a proper 763a833a0eSRiver Riddle // index. 773a833a0eSRiver Riddle if (std::is_same<pdl::OperandOp, decltype(op)>::value || 783a833a0eSRiver Riddle cast<OperandGroupPosition>(pos)->getOperandGroupNumber()) 793a833a0eSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 80242762c9SRiver Riddle 8172fddfb5SRiver Riddle if (Value type = op.getValueType()) 823a833a0eSRiver Riddle getTreePredicates(predList, type, builder, inputs, 833a833a0eSRiver Riddle builder.getType(pos)); 843a833a0eSRiver Riddle }) 853a833a0eSRiver Riddle .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) { 8622426110SRamkumar Ramachandra std::optional<unsigned> index = op.getIndex(); 873a833a0eSRiver Riddle 883a833a0eSRiver Riddle // Prevent traversal into a null value if the result has a proper index. 893a833a0eSRiver Riddle if (index) 903a833a0eSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 913a833a0eSRiver Riddle 923a833a0eSRiver Riddle // Get the parent operation of this operand. 
933a833a0eSRiver Riddle OperationPosition *parentPos = builder.getOperandDefiningOp(pos); 94242762c9SRiver Riddle predList.emplace_back(parentPos, builder.getIsNotNull()); 953a833a0eSRiver Riddle 963a833a0eSRiver Riddle // Ensure that the operands match the corresponding results of the 973a833a0eSRiver Riddle // parent operation. 983a833a0eSRiver Riddle Position *resultPos = nullptr; 993a833a0eSRiver Riddle if (std::is_same<pdl::ResultOp, decltype(op)>::value) 1003a833a0eSRiver Riddle resultPos = builder.getResult(parentPos, *index); 1013a833a0eSRiver Riddle else 1023a833a0eSRiver Riddle resultPos = builder.getResultGroup(parentPos, index, isVariadic); 103242762c9SRiver Riddle predList.emplace_back(resultPos, builder.getEqualTo(pos)); 1043a833a0eSRiver Riddle 1053a833a0eSRiver Riddle // Collect the predicates of the parent operation. 106310c3ee4SRiver Riddle getTreePredicates(predList, op.getParent(), builder, inputs, 1071f13963eSRiver Riddle (Position *)parentPos); 1083a833a0eSRiver Riddle }); 1098a1ca2cdSRiver Riddle } 1108a1ca2cdSRiver Riddle 11122426110SRamkumar Ramachandra static void 11222426110SRamkumar Ramachandra getTreePredicates(std::vector<PositionalPredicate> &predList, Value val, 11322426110SRamkumar Ramachandra PredicateBuilder &builder, 11422426110SRamkumar Ramachandra DenseMap<Value, Position *> &inputs, OperationPosition *pos, 11522426110SRamkumar Ramachandra std::optional<unsigned> ignoreOperand = std::nullopt) { 1165550c821STres Popp assert(isa<pdl::OperationType>(val.getType()) && "expected operation"); 1178a1ca2cdSRiver Riddle pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp()); 1188a1ca2cdSRiver Riddle OperationPosition *opPos = cast<OperationPosition>(pos); 1198a1ca2cdSRiver Riddle 1208a1ca2cdSRiver Riddle // Ensure getDefiningOp returns a non-null operation. 
1218a1ca2cdSRiver Riddle if (!opPos->isRoot()) 1228a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 1238a1ca2cdSRiver Riddle 1248a1ca2cdSRiver Riddle // Check that this is the correct root operation. 12522426110SRamkumar Ramachandra if (std::optional<StringRef> opName = op.getOpName()) 1268a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getOperationName(*opName)); 1278a1ca2cdSRiver Riddle 1283a833a0eSRiver Riddle // Check that the operation has the proper number of operands. If there are 1293a833a0eSRiver Riddle // any variable length operands, we check a minimum instead of an exact count. 13072fddfb5SRiver Riddle OperandRange operands = op.getOperandValues(); 1313a833a0eSRiver Riddle unsigned minOperands = getNumNonRangeValues(operands); 1323a833a0eSRiver Riddle if (minOperands != operands.size()) { 1333a833a0eSRiver Riddle if (minOperands) 1343a833a0eSRiver Riddle predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands)); 1353a833a0eSRiver Riddle } else { 1363a833a0eSRiver Riddle predList.emplace_back(pos, builder.getOperandCount(minOperands)); 1373a833a0eSRiver Riddle } 1383a833a0eSRiver Riddle 1393a833a0eSRiver Riddle // Check that the operation has the proper number of results. If there are 1403a833a0eSRiver Riddle // any variable length results, we check a minimum instead of an exact count. 14172fddfb5SRiver Riddle OperandRange types = op.getTypeValues(); 1423a833a0eSRiver Riddle unsigned minResults = getNumNonRangeValues(types); 1433a833a0eSRiver Riddle if (minResults == types.size()) 144242762c9SRiver Riddle predList.emplace_back(pos, builder.getResultCount(types.size())); 1453a833a0eSRiver Riddle else if (minResults) 1463a833a0eSRiver Riddle predList.emplace_back(pos, builder.getResultCountAtLeast(minResults)); 1478a1ca2cdSRiver Riddle 1488a1ca2cdSRiver Riddle // Recurse into any attributes, operands, or results. 
14972fddfb5SRiver Riddle for (auto [attrName, attr] : 15072fddfb5SRiver Riddle llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) { 1518a1ca2cdSRiver Riddle getTreePredicates( 15272fddfb5SRiver Riddle predList, attr, builder, inputs, 1535550c821STres Popp builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue())); 1548a1ca2cdSRiver Riddle } 1553a833a0eSRiver Riddle 1563a833a0eSRiver Riddle // Process the operands and results of the operation. For all values up to 1573a833a0eSRiver Riddle // the first variable length value, we use the concrete operand/result 1583a833a0eSRiver Riddle // number. After that, we use the "group" given that we can't know the 1593a833a0eSRiver Riddle // concrete indices until runtime. If there is only one variadic operand 1603a833a0eSRiver Riddle // group, we treat it as all of the operands/results of the operation. 1613a833a0eSRiver Riddle /// Operands. 1625550c821STres Popp if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) { 1632692eae5SStanislav Funiak // Ignore the operands if we are performing an upward traversal (in that 1642692eae5SStanislav Funiak // case, they have already been visited). 1652692eae5SStanislav Funiak if (opPos->isRoot() || opPos->isOperandDefiningOp()) 1663a833a0eSRiver Riddle getTreePredicates(predList, operands.front(), builder, inputs, 1673a833a0eSRiver Riddle builder.getAllOperands(opPos)); 1683a833a0eSRiver Riddle } else { 1693a833a0eSRiver Riddle bool foundVariableLength = false; 170e4853be2SMehdi Amini for (const auto &operandIt : llvm::enumerate(operands)) { 1715550c821STres Popp bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType()); 1723a833a0eSRiver Riddle foundVariableLength |= isVariadic; 1733a833a0eSRiver Riddle 174a76ee58fSStanislav Funiak // Ignore the specified operand, usually because this position was 175a76ee58fSStanislav Funiak // visited in an upward traversal via an iterative choice. 
176a76ee58fSStanislav Funiak if (ignoreOperand && *ignoreOperand == operandIt.index()) 177a76ee58fSStanislav Funiak continue; 178a76ee58fSStanislav Funiak 1793a833a0eSRiver Riddle Position *pos = 1803a833a0eSRiver Riddle foundVariableLength 1813a833a0eSRiver Riddle ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic) 1823a833a0eSRiver Riddle : builder.getOperand(opPos, operandIt.index()); 1833a833a0eSRiver Riddle getTreePredicates(predList, operandIt.value(), builder, inputs, pos); 184242762c9SRiver Riddle } 1853a833a0eSRiver Riddle } 1863a833a0eSRiver Riddle /// Results. 1875550c821STres Popp if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) { 1883a833a0eSRiver Riddle getTreePredicates(predList, types.front(), builder, inputs, 1893a833a0eSRiver Riddle builder.getType(builder.getAllResults(opPos))); 1908c258fdaSJakub Kuderski return; 1918c258fdaSJakub Kuderski } 1928c258fdaSJakub Kuderski 1933a833a0eSRiver Riddle bool foundVariableLength = false; 1948c258fdaSJakub Kuderski for (auto [idx, typeValue] : llvm::enumerate(types)) { 1955550c821STres Popp bool isVariadic = isa<pdl::RangeType>(typeValue.getType()); 1963a833a0eSRiver Riddle foundVariableLength |= isVariadic; 1973a833a0eSRiver Riddle 1988c258fdaSJakub Kuderski auto *resultPos = foundVariableLength 1998c258fdaSJakub Kuderski ? 
builder.getResultGroup(pos, idx, isVariadic) 2008c258fdaSJakub Kuderski : builder.getResult(pos, idx); 201242762c9SRiver Riddle predList.emplace_back(resultPos, builder.getIsNotNull()); 2028c258fdaSJakub Kuderski getTreePredicates(predList, typeValue, builder, inputs, 203242762c9SRiver Riddle builder.getType(resultPos)); 2048a1ca2cdSRiver Riddle } 2058a1ca2cdSRiver Riddle } 2068a1ca2cdSRiver Riddle 207242762c9SRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList, 208242762c9SRiver Riddle Value val, PredicateBuilder &builder, 209242762c9SRiver Riddle DenseMap<Value, Position *> &inputs, 210242762c9SRiver Riddle TypePosition *pos) { 2118a1ca2cdSRiver Riddle // Check for a constraint on a constant type. 2123a833a0eSRiver Riddle if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) { 21372fddfb5SRiver Riddle if (Attribute type = typeOp.getConstantTypeAttr()) 2143a833a0eSRiver Riddle predList.emplace_back(pos, builder.getTypeConstraint(type)); 2153a833a0eSRiver Riddle } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) { 21672fddfb5SRiver Riddle if (Attribute typeAttr = typeOp.getConstantTypesAttr()) 2173a833a0eSRiver Riddle predList.emplace_back(pos, builder.getTypeConstraint(typeAttr)); 2183a833a0eSRiver Riddle } 2198a1ca2cdSRiver Riddle } 220242762c9SRiver Riddle 221242762c9SRiver Riddle /// Collect the tree predicates anchored at the given value. 222242762c9SRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList, 223242762c9SRiver Riddle Value val, PredicateBuilder &builder, 224242762c9SRiver Riddle DenseMap<Value, Position *> &inputs, 225242762c9SRiver Riddle Position *pos) { 226242762c9SRiver Riddle // Make sure this input value is accessible to the rewrite. 
227242762c9SRiver Riddle auto it = inputs.try_emplace(val, pos); 228242762c9SRiver Riddle if (!it.second) { 229242762c9SRiver Riddle // If this is an input value that has been visited in the tree, add a 230242762c9SRiver Riddle // constraint to ensure that both instances refer to the same value. 2313a833a0eSRiver Riddle if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp, 2323a833a0eSRiver Riddle pdl::TypeOp>(val.getDefiningOp())) { 233242762c9SRiver Riddle auto minMaxPositions = 234242762c9SRiver Riddle std::minmax(pos, it.first->second, comparePosDepth); 235242762c9SRiver Riddle predList.emplace_back(minMaxPositions.second, 236242762c9SRiver Riddle builder.getEqualTo(minMaxPositions.first)); 2378a1ca2cdSRiver Riddle } 238242762c9SRiver Riddle return; 239242762c9SRiver Riddle } 240242762c9SRiver Riddle 241242762c9SRiver Riddle TypeSwitch<Position *>(pos) 2423a833a0eSRiver Riddle .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) { 2433a833a0eSRiver Riddle getTreePredicates(predList, val, builder, inputs, pos); 2443a833a0eSRiver Riddle }) 2453a833a0eSRiver Riddle .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) { 2463a833a0eSRiver Riddle getOperandTreePredicates(predList, val, builder, inputs, pos); 247242762c9SRiver Riddle }) 248242762c9SRiver Riddle .Default([](auto *) { llvm_unreachable("unexpected position kind"); }); 2498a1ca2cdSRiver Riddle } 2508a1ca2cdSRiver Riddle 251233e9476SRiver Riddle static void getAttributePredicates(pdl::AttributeOp op, 252233e9476SRiver Riddle std::vector<PositionalPredicate> &predList, 253233e9476SRiver Riddle PredicateBuilder &builder, 254233e9476SRiver Riddle DenseMap<Value, Position *> &inputs) { 255233e9476SRiver Riddle Position *&attrPos = inputs[op]; 256233e9476SRiver Riddle if (attrPos) 257233e9476SRiver Riddle return; 258310c3ee4SRiver Riddle Attribute value = op.getValueAttr(); 259233e9476SRiver Riddle assert(value && "expected non-tree `pdl.attribute` to contain a 
value"); 260233e9476SRiver Riddle attrPos = builder.getAttributeLiteral(value); 261233e9476SRiver Riddle } 262233e9476SRiver Riddle 26302c4c0d5SRiver Riddle static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, 264242762c9SRiver Riddle std::vector<PositionalPredicate> &predList, 265242762c9SRiver Riddle PredicateBuilder &builder, 266242762c9SRiver Riddle DenseMap<Value, Position *> &inputs) { 267310c3ee4SRiver Riddle OperandRange arguments = op.getArgs(); 2688a1ca2cdSRiver Riddle 2698a1ca2cdSRiver Riddle std::vector<Position *> allPositions; 2708a1ca2cdSRiver Riddle allPositions.reserve(arguments.size()); 2718a1ca2cdSRiver Riddle for (Value arg : arguments) 2728a1ca2cdSRiver Riddle allPositions.push_back(inputs.lookup(arg)); 2738a1ca2cdSRiver Riddle 2748a1ca2cdSRiver Riddle // Push the constraint to the furthest position. 275fab2bb8bSJustin Lebar Position *pos = *llvm::max_element(allPositions, comparePosDepth); 2768ec28af8SMatthias Gehre ResultRange results = op.getResults(); 2778ec28af8SMatthias Gehre PredicateBuilder::Predicate pred = builder.getConstraint( 2788ec28af8SMatthias Gehre op.getName(), allPositions, SmallVector<Type>(results.getTypes()), 2798ec28af8SMatthias Gehre op.getIsNegated()); 2808ec28af8SMatthias Gehre 2818ec28af8SMatthias Gehre // For each result register a position so it can be used later 2828ec28af8SMatthias Gehre for (auto [i, result] : llvm::enumerate(results)) { 2838ec28af8SMatthias Gehre ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first); 2848ec28af8SMatthias Gehre ConstraintPosition *pos = builder.getConstraintPosition(q, i); 2858ec28af8SMatthias Gehre auto [it, inserted] = inputs.try_emplace(result, pos); 2868ec28af8SMatthias Gehre // If this is an input value that has been visited in the tree, add a 2878ec28af8SMatthias Gehre // constraint to ensure that both instances refer to the same value. 
2888ec28af8SMatthias Gehre if (!inserted) { 2898ec28af8SMatthias Gehre Position *first = pos; 2908ec28af8SMatthias Gehre Position *second = it->second; 2918ec28af8SMatthias Gehre if (comparePosDepth(second, first)) 2928ec28af8SMatthias Gehre std::tie(second, first) = std::make_pair(first, second); 2938ec28af8SMatthias Gehre 2948ec28af8SMatthias Gehre predList.emplace_back(second, builder.getEqualTo(first)); 2958ec28af8SMatthias Gehre } 2968ec28af8SMatthias Gehre } 2978a1ca2cdSRiver Riddle predList.emplace_back(pos, pred); 2988a1ca2cdSRiver Riddle } 299242762c9SRiver Riddle 300242762c9SRiver Riddle static void getResultPredicates(pdl::ResultOp op, 301242762c9SRiver Riddle std::vector<PositionalPredicate> &predList, 302242762c9SRiver Riddle PredicateBuilder &builder, 303242762c9SRiver Riddle DenseMap<Value, Position *> &inputs) { 304242762c9SRiver Riddle Position *&resultPos = inputs[op]; 305242762c9SRiver Riddle if (resultPos) 306242762c9SRiver Riddle return; 3073a833a0eSRiver Riddle 3083a833a0eSRiver Riddle // Ensure that the result isn't null. 309310c3ee4SRiver Riddle auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent())); 310310c3ee4SRiver Riddle resultPos = builder.getResult(parentPos, op.getIndex()); 311242762c9SRiver Riddle predList.emplace_back(resultPos, builder.getIsNotNull()); 312242762c9SRiver Riddle } 313242762c9SRiver Riddle 3143a833a0eSRiver Riddle static void getResultPredicates(pdl::ResultsOp op, 3153a833a0eSRiver Riddle std::vector<PositionalPredicate> &predList, 3163a833a0eSRiver Riddle PredicateBuilder &builder, 3173a833a0eSRiver Riddle DenseMap<Value, Position *> &inputs) { 3183a833a0eSRiver Riddle Position *&resultPos = inputs[op]; 3193a833a0eSRiver Riddle if (resultPos) 3203a833a0eSRiver Riddle return; 3213a833a0eSRiver Riddle 3223a833a0eSRiver Riddle // Ensure that the result isn't null if the result has an index. 
323310c3ee4SRiver Riddle auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent())); 3245550c821STres Popp bool isVariadic = isa<pdl::RangeType>(op.getType()); 32522426110SRamkumar Ramachandra std::optional<unsigned> index = op.getIndex(); 3263a833a0eSRiver Riddle resultPos = builder.getResultGroup(parentPos, index, isVariadic); 3273a833a0eSRiver Riddle if (index) 3283a833a0eSRiver Riddle predList.emplace_back(resultPos, builder.getIsNotNull()); 3293a833a0eSRiver Riddle } 3303a833a0eSRiver Riddle 331233e9476SRiver Riddle static void getTypePredicates(Value typeValue, 332233e9476SRiver Riddle function_ref<Attribute()> typeAttrFn, 333233e9476SRiver Riddle PredicateBuilder &builder, 334233e9476SRiver Riddle DenseMap<Value, Position *> &inputs) { 335233e9476SRiver Riddle Position *&typePos = inputs[typeValue]; 336233e9476SRiver Riddle if (typePos) 337233e9476SRiver Riddle return; 338233e9476SRiver Riddle Attribute typeAttr = typeAttrFn(); 339233e9476SRiver Riddle assert(typeAttr && 340233e9476SRiver Riddle "expected non-tree `pdl.type`/`pdl.types` to contain a value"); 341233e9476SRiver Riddle typePos = builder.getTypeLiteral(typeAttr); 342233e9476SRiver Riddle } 343233e9476SRiver Riddle 344242762c9SRiver Riddle /// Collect all of the predicates that cannot be determined via walking the 345242762c9SRiver Riddle /// tree. 
346242762c9SRiver Riddle static void getNonTreePredicates(pdl::PatternOp pattern, 347242762c9SRiver Riddle std::vector<PositionalPredicate> &predList, 348242762c9SRiver Riddle PredicateBuilder &builder, 349242762c9SRiver Riddle DenseMap<Value, Position *> &inputs) { 35072fddfb5SRiver Riddle for (Operation &op : pattern.getBodyRegion().getOps()) { 3513a833a0eSRiver Riddle TypeSwitch<Operation *>(&op) 352233e9476SRiver Riddle .Case([&](pdl::AttributeOp attrOp) { 353233e9476SRiver Riddle getAttributePredicates(attrOp, predList, builder, inputs); 354233e9476SRiver Riddle }) 3553a833a0eSRiver Riddle .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) { 356242762c9SRiver Riddle getConstraintPredicates(constraintOp, predList, builder, inputs); 3573a833a0eSRiver Riddle }) 3583a833a0eSRiver Riddle .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) { 359242762c9SRiver Riddle getResultPredicates(resultOp, predList, builder, inputs); 360233e9476SRiver Riddle }) 361233e9476SRiver Riddle .Case([&](pdl::TypeOp typeOp) { 362233e9476SRiver Riddle getTypePredicates( 36372fddfb5SRiver Riddle typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder, 36472fddfb5SRiver Riddle inputs); 365233e9476SRiver Riddle }) 366233e9476SRiver Riddle .Case([&](pdl::TypesOp typeOp) { 367233e9476SRiver Riddle getTypePredicates( 36872fddfb5SRiver Riddle typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder, 36972fddfb5SRiver Riddle inputs); 3703a833a0eSRiver Riddle }); 371242762c9SRiver Riddle } 3728a1ca2cdSRiver Riddle } 3738a1ca2cdSRiver Riddle 374a76ee58fSStanislav Funiak namespace { 375a76ee58fSStanislav Funiak 376a76ee58fSStanislav Funiak /// An op accepting a value at an optional index. 
377a76ee58fSStanislav Funiak struct OpIndex { 378a76ee58fSStanislav Funiak Value parent; 37922426110SRamkumar Ramachandra std::optional<unsigned> index; 380a76ee58fSStanislav Funiak }; 381a76ee58fSStanislav Funiak 382a76ee58fSStanislav Funiak /// The parent and operand index of each operation for each root, stored 383a76ee58fSStanislav Funiak /// as a nested map [root][operation]. 384a76ee58fSStanislav Funiak using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>; 385a76ee58fSStanislav Funiak 386a76ee58fSStanislav Funiak } // namespace 387a76ee58fSStanislav Funiak 388a76ee58fSStanislav Funiak /// Given a pattern, determines the set of roots present in this pattern. 389a76ee58fSStanislav Funiak /// These are the operations whose results are not consumed by other operations. 390a76ee58fSStanislav Funiak static SmallVector<Value> detectRoots(pdl::PatternOp pattern) { 391a76ee58fSStanislav Funiak // First, collect all the operations that are used as operands 392a76ee58fSStanislav Funiak // to other operations. These are not roots by default. 393a76ee58fSStanislav Funiak DenseSet<Value> used; 39472fddfb5SRiver Riddle for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) { 39572fddfb5SRiver Riddle for (Value operand : operationOp.getOperandValues()) 396a76ee58fSStanislav Funiak TypeSwitch<Operation *>(operand.getDefiningOp()) 397a76ee58fSStanislav Funiak .Case<pdl::ResultOp, pdl::ResultsOp>( 398310c3ee4SRiver Riddle [&used](auto resultOp) { used.insert(resultOp.getParent()); }); 399a76ee58fSStanislav Funiak } 400a76ee58fSStanislav Funiak 401a76ee58fSStanislav Funiak // Remove the specified root from the use set, so that we can 402a76ee58fSStanislav Funiak // always select it as a root, even if it is used by other operations. 
403310c3ee4SRiver Riddle if (Value root = pattern.getRewriter().getRoot()) 404a76ee58fSStanislav Funiak used.erase(root); 405a76ee58fSStanislav Funiak 406a76ee58fSStanislav Funiak // Finally, collect all the unused operations. 407a76ee58fSStanislav Funiak SmallVector<Value> roots; 40872fddfb5SRiver Riddle for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) 409a76ee58fSStanislav Funiak if (!used.contains(operationOp)) 410a76ee58fSStanislav Funiak roots.push_back(operationOp); 411a76ee58fSStanislav Funiak 412a76ee58fSStanislav Funiak return roots; 413a76ee58fSStanislav Funiak } 414a76ee58fSStanislav Funiak 415a76ee58fSStanislav Funiak /// Given a list of candidate roots, builds the cost graph for connecting them. 416a76ee58fSStanislav Funiak /// The graph is formed by traversing the DAG of operations starting from each 417a76ee58fSStanislav Funiak /// root and marking the depth of each connector value (operand). Then we join 418a76ee58fSStanislav Funiak /// the candidate roots based on the common connector values, taking the one 419a76ee58fSStanislav Funiak /// with the minimum depth. Along the way, we compute, for each candidate root, 420a76ee58fSStanislav Funiak /// a mapping from each operation (in the DAG underneath this root) to its 421a76ee58fSStanislav Funiak /// parent operation and the corresponding operand index. 422a76ee58fSStanislav Funiak static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph, 423a76ee58fSStanislav Funiak ParentMaps &parentMaps) { 424a76ee58fSStanislav Funiak 425a76ee58fSStanislav Funiak // The entry of a queue. The entry consists of the following items: 426a76ee58fSStanislav Funiak // * the value in the DAG underneath the root; 427a76ee58fSStanislav Funiak // * the parent of the value; 428a76ee58fSStanislav Funiak // * the operand index of the value in its parent; 429a76ee58fSStanislav Funiak // * the depth of the visited value. 
430a76ee58fSStanislav Funiak struct Entry { 43122426110SRamkumar Ramachandra Entry(Value value, Value parent, std::optional<unsigned> index, 43222426110SRamkumar Ramachandra unsigned depth) 433a76ee58fSStanislav Funiak : value(value), parent(parent), index(index), depth(depth) {} 434a76ee58fSStanislav Funiak 435a76ee58fSStanislav Funiak Value value; 436a76ee58fSStanislav Funiak Value parent; 43722426110SRamkumar Ramachandra std::optional<unsigned> index; 438a76ee58fSStanislav Funiak unsigned depth; 439a76ee58fSStanislav Funiak }; 440a76ee58fSStanislav Funiak 441a76ee58fSStanislav Funiak // A root of a value and its depth (distance from root to the value). 442a76ee58fSStanislav Funiak struct RootDepth { 443a76ee58fSStanislav Funiak Value root; 444a76ee58fSStanislav Funiak unsigned depth = 0; 445a76ee58fSStanislav Funiak }; 446a76ee58fSStanislav Funiak 447a76ee58fSStanislav Funiak // Map from candidate connector values to their roots and depths. Using a 448a76ee58fSStanislav Funiak // small vector with 1 entry because most values belong to a single root. 449a76ee58fSStanislav Funiak llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths; 450a76ee58fSStanislav Funiak 451a76ee58fSStanislav Funiak // Perform a breadth-first traversal of the op DAG rooted at each root. 452a76ee58fSStanislav Funiak for (Value root : roots) { 453a76ee58fSStanislav Funiak // The queue of visited values. A value may be present multiple times in 454a76ee58fSStanislav Funiak // the queue, for multiple parents. We only accept the first occurrence, 455a76ee58fSStanislav Funiak // which is guaranteed to have the lowest depth. 456a76ee58fSStanislav Funiak std::queue<Entry> toVisit; 457a76ee58fSStanislav Funiak toVisit.emplace(root, Value(), 0, 0); 458a76ee58fSStanislav Funiak 459a76ee58fSStanislav Funiak // The map from value to its parent for the current root. 
460a76ee58fSStanislav Funiak DenseMap<Value, OpIndex> &parentMap = parentMaps[root]; 461a76ee58fSStanislav Funiak 462a76ee58fSStanislav Funiak while (!toVisit.empty()) { 463a76ee58fSStanislav Funiak Entry entry = toVisit.front(); 464a76ee58fSStanislav Funiak toVisit.pop(); 465a76ee58fSStanislav Funiak // Skip if already visited. 466a76ee58fSStanislav Funiak if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second) 467a76ee58fSStanislav Funiak continue; 468a76ee58fSStanislav Funiak 469a76ee58fSStanislav Funiak // Mark the root and depth of the value. 470a76ee58fSStanislav Funiak connectorsRootsDepths[entry.value].push_back({root, entry.depth}); 471a76ee58fSStanislav Funiak 472a76ee58fSStanislav Funiak // Traverse the operands of an operation and result ops. 473a76ee58fSStanislav Funiak // We intentionally do not traverse attributes and types, because those 474a76ee58fSStanislav Funiak // are expensive to join on. 475a76ee58fSStanislav Funiak TypeSwitch<Operation *>(entry.value.getDefiningOp()) 476a76ee58fSStanislav Funiak .Case<pdl::OperationOp>([&](auto operationOp) { 47772fddfb5SRiver Riddle OperandRange operands = operationOp.getOperandValues(); 478a76ee58fSStanislav Funiak // Special case when we pass all the operands in one range. 479a76ee58fSStanislav Funiak // For those, the index is empty. 480a76ee58fSStanislav Funiak if (operands.size() == 1 && 4815550c821STres Popp isa<pdl::RangeType>(operands[0].getType())) { 4821a36588eSKazu Hirata toVisit.emplace(operands[0], entry.value, std::nullopt, 483a76ee58fSStanislav Funiak entry.depth + 1); 484a76ee58fSStanislav Funiak return; 485a76ee58fSStanislav Funiak } 486a76ee58fSStanislav Funiak 487a76ee58fSStanislav Funiak // Default case: visit all the operands. 
48872fddfb5SRiver Riddle for (const auto &p : 48972fddfb5SRiver Riddle llvm::enumerate(operationOp.getOperandValues())) 490a76ee58fSStanislav Funiak toVisit.emplace(p.value(), entry.value, p.index(), 491a76ee58fSStanislav Funiak entry.depth + 1); 492a76ee58fSStanislav Funiak }) 493a76ee58fSStanislav Funiak .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) { 494310c3ee4SRiver Riddle toVisit.emplace(resultOp.getParent(), entry.value, 495310c3ee4SRiver Riddle resultOp.getIndex(), entry.depth); 496a76ee58fSStanislav Funiak }); 497a76ee58fSStanislav Funiak } 498a76ee58fSStanislav Funiak } 499a76ee58fSStanislav Funiak 500a76ee58fSStanislav Funiak // Now build the cost graph. 501a76ee58fSStanislav Funiak // This is simply a minimum over all depths for the target root. 502a76ee58fSStanislav Funiak unsigned nextID = 0; 503a76ee58fSStanislav Funiak for (const auto &connectorRootsDepths : connectorsRootsDepths) { 504a76ee58fSStanislav Funiak Value value = connectorRootsDepths.first; 505a76ee58fSStanislav Funiak ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second; 506a76ee58fSStanislav Funiak // If there is only one root for this value, this will not trigger 507a76ee58fSStanislav Funiak // any edges in the cost graph (a perf optimization). 508a76ee58fSStanislav Funiak if (rootsDepths.size() == 1) 509a76ee58fSStanislav Funiak continue; 510a76ee58fSStanislav Funiak 511a76ee58fSStanislav Funiak for (const RootDepth &p : rootsDepths) { 512a76ee58fSStanislav Funiak for (const RootDepth &q : rootsDepths) { 513a76ee58fSStanislav Funiak if (&p == &q) 514a76ee58fSStanislav Funiak continue; 515a76ee58fSStanislav Funiak // Insert or retrieve the property of edge from p to q. 
5169eb8e7b1SStanislav Funiak RootOrderingEntry &entry = graph[q.root][p.root]; 5179eb8e7b1SStanislav Funiak if (!entry.connector /* new edge */ || entry.cost.first > q.depth) { 5189eb8e7b1SStanislav Funiak if (!entry.connector) 5199eb8e7b1SStanislav Funiak entry.cost.second = nextID++; 5209eb8e7b1SStanislav Funiak entry.cost.first = q.depth; 5219eb8e7b1SStanislav Funiak entry.connector = value; 522a76ee58fSStanislav Funiak } 523a76ee58fSStanislav Funiak } 524a76ee58fSStanislav Funiak } 525a76ee58fSStanislav Funiak } 526a76ee58fSStanislav Funiak 527a76ee58fSStanislav Funiak assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) && 528a76ee58fSStanislav Funiak "the pattern contains a candidate root disconnected from the others"); 529a76ee58fSStanislav Funiak } 530a76ee58fSStanislav Funiak 5312692eae5SStanislav Funiak /// Returns true if the operand at the given index needs to be queried using an 5322692eae5SStanislav Funiak /// operand group, i.e., if it is variadic itself or follows a variadic operand. 5332692eae5SStanislav Funiak static bool useOperandGroup(pdl::OperationOp op, unsigned index) { 53472fddfb5SRiver Riddle OperandRange operands = op.getOperandValues(); 5352692eae5SStanislav Funiak assert(index < operands.size() && "operand index out of range"); 5362692eae5SStanislav Funiak for (unsigned i = 0; i <= index; ++i) 5375550c821STres Popp if (isa<pdl::RangeType>(operands[i].getType())) 5382692eae5SStanislav Funiak return true; 5392692eae5SStanislav Funiak return false; 5402692eae5SStanislav Funiak } 5412692eae5SStanislav Funiak 542a76ee58fSStanislav Funiak /// Visit a node during upward traversal. 
5432692eae5SStanislav Funiak static void visitUpward(std::vector<PositionalPredicate> &predList, 5442692eae5SStanislav Funiak OpIndex opIndex, PredicateBuilder &builder, 5452692eae5SStanislav Funiak DenseMap<Value, Position *> &valueToPosition, 5462692eae5SStanislav Funiak Position *&pos, unsigned rootID) { 547a76ee58fSStanislav Funiak Value value = opIndex.parent; 548a76ee58fSStanislav Funiak TypeSwitch<Operation *>(value.getDefiningOp()) 549a76ee58fSStanislav Funiak .Case<pdl::OperationOp>([&](auto operationOp) { 550a76ee58fSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 551a76ee58fSStanislav Funiak 5522692eae5SStanislav Funiak // Get users and iterate over them. 5532692eae5SStanislav Funiak Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true); 5542692eae5SStanislav Funiak Position *foreachPos = builder.getForEach(usersPos, rootID); 5552692eae5SStanislav Funiak OperationPosition *opPos = builder.getPassthroughOp(foreachPos); 5562692eae5SStanislav Funiak 5572692eae5SStanislav Funiak // Compare the operand(s) of the user against the input value(s). 5582692eae5SStanislav Funiak Position *operandPos; 5592692eae5SStanislav Funiak if (!opIndex.index) { 5602692eae5SStanislav Funiak // We are querying all the operands of the operation. 5612692eae5SStanislav Funiak operandPos = builder.getAllOperands(opPos); 5622692eae5SStanislav Funiak } else if (useOperandGroup(operationOp, *opIndex.index)) { 5632692eae5SStanislav Funiak // We are querying an operand group. 56472fddfb5SRiver Riddle Type type = operationOp.getOperandValues()[*opIndex.index].getType(); 5655550c821STres Popp bool variadic = isa<pdl::RangeType>(type); 5662692eae5SStanislav Funiak operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic); 5672692eae5SStanislav Funiak } else { 5682692eae5SStanislav Funiak // We are querying an individual operand. 
5692692eae5SStanislav Funiak operandPos = builder.getOperand(opPos, *opIndex.index); 570a76ee58fSStanislav Funiak } 5712692eae5SStanislav Funiak predList.emplace_back(operandPos, builder.getEqualTo(pos)); 572a76ee58fSStanislav Funiak 573a76ee58fSStanislav Funiak // Guard against duplicate upward visits. These are not possible, 574a76ee58fSStanislav Funiak // because if this value was already visited, it would have been 575a76ee58fSStanislav Funiak // cheaper to start the traversal at this value rather than at the 576a76ee58fSStanislav Funiak // `connector`, violating the optimality of our spanning tree. 577a76ee58fSStanislav Funiak bool inserted = valueToPosition.try_emplace(value, opPos).second; 5781b0312d2SBenjamin Kramer (void)inserted; 579a76ee58fSStanislav Funiak assert(inserted && "duplicate upward visit"); 580a76ee58fSStanislav Funiak 581a76ee58fSStanislav Funiak // Obtain the tree predicates at the current value. 582a76ee58fSStanislav Funiak getTreePredicates(predList, value, builder, valueToPosition, opPos, 583a76ee58fSStanislav Funiak opIndex.index); 584a76ee58fSStanislav Funiak 585a76ee58fSStanislav Funiak // Update the position 586a76ee58fSStanislav Funiak pos = opPos; 587a76ee58fSStanislav Funiak }) 588a76ee58fSStanislav Funiak .Case<pdl::ResultOp>([&](auto resultOp) { 589a76ee58fSStanislav Funiak // Traverse up an individual result. 590a76ee58fSStanislav Funiak auto *opPos = dyn_cast<OperationPosition>(pos); 591a76ee58fSStanislav Funiak assert(opPos && "operations and results must be interleaved"); 592a76ee58fSStanislav Funiak pos = builder.getResult(opPos, *opIndex.index); 5932692eae5SStanislav Funiak 5942692eae5SStanislav Funiak // Insert the result position in case we have not visited it yet. 5952692eae5SStanislav Funiak valueToPosition.try_emplace(value, pos); 596a76ee58fSStanislav Funiak }) 597a76ee58fSStanislav Funiak .Case<pdl::ResultsOp>([&](auto resultOp) { 598a76ee58fSStanislav Funiak // Traverse up a group of results. 
599a76ee58fSStanislav Funiak auto *opPos = dyn_cast<OperationPosition>(pos); 600a76ee58fSStanislav Funiak assert(opPos && "operations and results must be interleaved"); 6015550c821STres Popp bool isVariadic = isa<pdl::RangeType>(value.getType()); 602a76ee58fSStanislav Funiak if (opIndex.index) 603a76ee58fSStanislav Funiak pos = builder.getResultGroup(opPos, opIndex.index, isVariadic); 604a76ee58fSStanislav Funiak else 605a76ee58fSStanislav Funiak pos = builder.getAllResults(opPos); 6062692eae5SStanislav Funiak 6072692eae5SStanislav Funiak // Insert the result position in case we have not visited it yet. 6082692eae5SStanislav Funiak valueToPosition.try_emplace(value, pos); 609a76ee58fSStanislav Funiak }); 610a76ee58fSStanislav Funiak } 611a76ee58fSStanislav Funiak 6128a1ca2cdSRiver Riddle /// Given a pattern operation, build the set of matcher predicates necessary to 6138a1ca2cdSRiver Riddle /// match this pattern. 614a76ee58fSStanislav Funiak static Value buildPredicateList(pdl::PatternOp pattern, 6158a1ca2cdSRiver Riddle PredicateBuilder &builder, 6168a1ca2cdSRiver Riddle std::vector<PositionalPredicate> &predList, 6178a1ca2cdSRiver Riddle DenseMap<Value, Position *> &valueToPosition) { 618a76ee58fSStanislav Funiak SmallVector<Value> roots = detectRoots(pattern); 619a76ee58fSStanislav Funiak 620a76ee58fSStanislav Funiak // Build the root ordering graph and compute the parent maps. 
621a76ee58fSStanislav Funiak RootOrderingGraph graph; 622a76ee58fSStanislav Funiak ParentMaps parentMaps; 623a76ee58fSStanislav Funiak buildCostGraph(roots, graph, parentMaps); 624a76ee58fSStanislav Funiak LLVM_DEBUG({ 625a76ee58fSStanislav Funiak llvm::dbgs() << "Graph:\n"; 626a76ee58fSStanislav Funiak for (auto &target : graph) { 6272692eae5SStanislav Funiak llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first 6282692eae5SStanislav Funiak << "\n"; 629a76ee58fSStanislav Funiak for (auto &source : target.second) { 6309eb8e7b1SStanislav Funiak RootOrderingEntry &entry = source.second; 6319eb8e7b1SStanislav Funiak llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first 6329eb8e7b1SStanislav Funiak << ":" << entry.cost.second << " via " 6339eb8e7b1SStanislav Funiak << entry.connector.getLoc() << "\n"; 634a76ee58fSStanislav Funiak } 635a76ee58fSStanislav Funiak } 636a76ee58fSStanislav Funiak }); 637a76ee58fSStanislav Funiak 638a76ee58fSStanislav Funiak // Solve the optimal branching problem for each candidate root, or use the 639a76ee58fSStanislav Funiak // provided one. 
640310c3ee4SRiver Riddle Value bestRoot = pattern.getRewriter().getRoot(); 641a76ee58fSStanislav Funiak OptimalBranching::EdgeList bestEdges; 642a76ee58fSStanislav Funiak if (!bestRoot) { 643a76ee58fSStanislav Funiak unsigned bestCost = 0; 644a76ee58fSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n"); 645a76ee58fSStanislav Funiak for (Value root : roots) { 646a76ee58fSStanislav Funiak OptimalBranching solver(graph, root); 647a76ee58fSStanislav Funiak unsigned cost = solver.solve(); 648a76ee58fSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n"); 649a76ee58fSStanislav Funiak if (!bestRoot || bestCost > cost) { 650a76ee58fSStanislav Funiak bestCost = cost; 651a76ee58fSStanislav Funiak bestRoot = root; 652a76ee58fSStanislav Funiak bestEdges = solver.preOrderTraversal(roots); 653a76ee58fSStanislav Funiak } 654a76ee58fSStanislav Funiak } 655a76ee58fSStanislav Funiak } else { 656a76ee58fSStanislav Funiak OptimalBranching solver(graph, bestRoot); 657a76ee58fSStanislav Funiak solver.solve(); 658a76ee58fSStanislav Funiak bestEdges = solver.preOrderTraversal(roots); 659a76ee58fSStanislav Funiak } 660a76ee58fSStanislav Funiak 6612692eae5SStanislav Funiak // Print the best solution. 
6622692eae5SStanislav Funiak LLVM_DEBUG({ 6632692eae5SStanislav Funiak llvm::dbgs() << "Best tree:\n"; 6642692eae5SStanislav Funiak for (const std::pair<Value, Value> &edge : bestEdges) { 6652692eae5SStanislav Funiak llvm::dbgs() << " * " << edge.first; 6662692eae5SStanislav Funiak if (edge.second) 6672692eae5SStanislav Funiak llvm::dbgs() << " <- " << edge.second; 6682692eae5SStanislav Funiak llvm::dbgs() << "\n"; 6692692eae5SStanislav Funiak } 6702692eae5SStanislav Funiak }); 6712692eae5SStanislav Funiak 672a76ee58fSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n"); 673a76ee58fSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n"); 674a76ee58fSStanislav Funiak 675a76ee58fSStanislav Funiak // The best root is the starting point for the traversal. Get the tree 676a76ee58fSStanislav Funiak // predicates for the DAG rooted at bestRoot. 677a76ee58fSStanislav Funiak getTreePredicates(predList, bestRoot, builder, valueToPosition, 678a76ee58fSStanislav Funiak builder.getRoot()); 679a76ee58fSStanislav Funiak 680a76ee58fSStanislav Funiak // Traverse the selected optimal branching. For all edges in order, traverse 681a76ee58fSStanislav Funiak // up starting from the connector, until the candidate root is reached, and 682a76ee58fSStanislav Funiak // call getTreePredicates at every node along the way. 68350da0134SAdrian Kuegel for (const auto &it : llvm::enumerate(bestEdges)) { 6842692eae5SStanislav Funiak Value target = it.value().first; 6852692eae5SStanislav Funiak Value source = it.value().second; 686a76ee58fSStanislav Funiak 687a76ee58fSStanislav Funiak // Check if we already visited the target root. This happens in two cases: 688a76ee58fSStanislav Funiak // 1) the initial root (bestRoot); 689a76ee58fSStanislav Funiak // 2) a root that is dominated by (contained in the subtree rooted at) an 690a76ee58fSStanislav Funiak // already visited root. 
691a76ee58fSStanislav Funiak if (valueToPosition.count(target)) 692a76ee58fSStanislav Funiak continue; 693a76ee58fSStanislav Funiak 694a76ee58fSStanislav Funiak // Determine the connector. 695a76ee58fSStanislav Funiak Value connector = graph[target][source].connector; 696a76ee58fSStanislav Funiak assert(connector && "invalid edge"); 697a76ee58fSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n"); 698a76ee58fSStanislav Funiak DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target); 699a76ee58fSStanislav Funiak Position *pos = valueToPosition.lookup(connector); 7002692eae5SStanislav Funiak assert(pos && "connector has not been traversed yet"); 701a76ee58fSStanislav Funiak 702a76ee58fSStanislav Funiak // Traverse from the connector upwards towards the target root. 703a76ee58fSStanislav Funiak for (Value value = connector; value != target;) { 704a76ee58fSStanislav Funiak OpIndex opIndex = parentMap.lookup(value); 705a76ee58fSStanislav Funiak assert(opIndex.parent && "missing parent"); 7062692eae5SStanislav Funiak visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index()); 707a76ee58fSStanislav Funiak value = opIndex.parent; 708a76ee58fSStanislav Funiak } 709a76ee58fSStanislav Funiak } 710a76ee58fSStanislav Funiak 711242762c9SRiver Riddle getNonTreePredicates(pattern, predList, builder, valueToPosition); 712a76ee58fSStanislav Funiak 713a76ee58fSStanislav Funiak return bestRoot; 7148a1ca2cdSRiver Riddle } 7158a1ca2cdSRiver Riddle 7168a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 7178a1ca2cdSRiver Riddle // Pattern Predicate Tree Merging 7188a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 7198a1ca2cdSRiver Riddle 7208a1ca2cdSRiver Riddle namespace { 7218a1ca2cdSRiver Riddle 7228a1ca2cdSRiver Riddle /// This class represents a specific predicate applied to a position, and 7238a1ca2cdSRiver Riddle /// 
provides hashing and ordering operators. This class allows for computing a 7248a1ca2cdSRiver Riddle /// frequence sum and ordering predicates based on a cost model. 7258a1ca2cdSRiver Riddle struct OrderedPredicate { 7268a1ca2cdSRiver Riddle OrderedPredicate(const std::pair<Position *, Qualifier *> &ip) 7278a1ca2cdSRiver Riddle : position(ip.first), question(ip.second) {} 7288a1ca2cdSRiver Riddle OrderedPredicate(const PositionalPredicate &ip) 7298a1ca2cdSRiver Riddle : position(ip.position), question(ip.question) {} 7308a1ca2cdSRiver Riddle 7318a1ca2cdSRiver Riddle /// The position this predicate is applied to. 7328a1ca2cdSRiver Riddle Position *position; 7338a1ca2cdSRiver Riddle 7348a1ca2cdSRiver Riddle /// The question that is applied by this predicate onto the position. 7358a1ca2cdSRiver Riddle Qualifier *question; 7368a1ca2cdSRiver Riddle 7378a1ca2cdSRiver Riddle /// The first and second order benefit sums. 7388a1ca2cdSRiver Riddle /// The primary sum is the number of occurrences of this predicate among all 7398a1ca2cdSRiver Riddle /// of the patterns. 7408a1ca2cdSRiver Riddle unsigned primary = 0; 7418a1ca2cdSRiver Riddle /// The secondary sum is a squared summation of the primary sum of all of the 7428a1ca2cdSRiver Riddle /// predicates within each pattern that contains this predicate. This allows 7438a1ca2cdSRiver Riddle /// for favoring predicates that are more commonly shared within a pattern, as 7448a1ca2cdSRiver Riddle /// opposed to those shared across patterns. 7458a1ca2cdSRiver Riddle unsigned secondary = 0; 7468a1ca2cdSRiver Riddle 747138803e0SStanislav Funiak /// The tie breaking ID, used to preserve a deterministic (insertion) order 748138803e0SStanislav Funiak /// among all the predicates with the same priority, depth, and position / 749138803e0SStanislav Funiak /// predicate dependency. 
750138803e0SStanislav Funiak unsigned id = 0; 751138803e0SStanislav Funiak 7528a1ca2cdSRiver Riddle /// A map between a pattern operation and the answer to the predicate question 7538a1ca2cdSRiver Riddle /// within that pattern. 7548a1ca2cdSRiver Riddle DenseMap<Operation *, Qualifier *> patternToAnswer; 7558a1ca2cdSRiver Riddle 756ddd556f1SRiver Riddle /// Returns true if this predicate is ordered before `rhs`, based on the cost 757ddd556f1SRiver Riddle /// model. 758ddd556f1SRiver Riddle bool operator<(const OrderedPredicate &rhs) const { 7598a1ca2cdSRiver Riddle // Sort by: 760ddd556f1SRiver Riddle // * higher first and secondary order sums 7618a1ca2cdSRiver Riddle // * lower depth 762ddd556f1SRiver Riddle // * lower position dependency 763ddd556f1SRiver Riddle // * lower predicate dependency 764138803e0SStanislav Funiak // * lower tie breaking ID 765ddd556f1SRiver Riddle auto *rhsPos = rhs.position; 7663a833a0eSRiver Riddle return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(), 767138803e0SStanislav Funiak rhsPos->getKind(), rhs.question->getKind(), rhs.id) > 768ddd556f1SRiver Riddle std::make_tuple(rhs.primary, rhs.secondary, 7693a833a0eSRiver Riddle position->getOperationDepth(), position->getKind(), 770138803e0SStanislav Funiak question->getKind(), id); 7718a1ca2cdSRiver Riddle } 7728a1ca2cdSRiver Riddle }; 7738a1ca2cdSRiver Riddle 7748a1ca2cdSRiver Riddle /// A DenseMapInfo for OrderedPredicate based solely on the position and 7758a1ca2cdSRiver Riddle /// question. 
7768a1ca2cdSRiver Riddle struct OrderedPredicateDenseInfo { 7778a1ca2cdSRiver Riddle using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>; 7788a1ca2cdSRiver Riddle 7798a1ca2cdSRiver Riddle static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); } 7808a1ca2cdSRiver Riddle static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); } 7818a1ca2cdSRiver Riddle static bool isEqual(const OrderedPredicate &lhs, 7828a1ca2cdSRiver Riddle const OrderedPredicate &rhs) { 7838a1ca2cdSRiver Riddle return lhs.position == rhs.position && lhs.question == rhs.question; 7848a1ca2cdSRiver Riddle } 7858a1ca2cdSRiver Riddle static unsigned getHashValue(const OrderedPredicate &p) { 7868a1ca2cdSRiver Riddle return llvm::hash_combine(p.position, p.question); 7878a1ca2cdSRiver Riddle } 7888a1ca2cdSRiver Riddle }; 7898a1ca2cdSRiver Riddle 7908a1ca2cdSRiver Riddle /// This class wraps a set of ordered predicates that are used within a specific 7918a1ca2cdSRiver Riddle /// pattern operation. 7928a1ca2cdSRiver Riddle struct OrderedPredicateList { 793a76ee58fSStanislav Funiak OrderedPredicateList(pdl::PatternOp pattern, Value root) 794a76ee58fSStanislav Funiak : pattern(pattern), root(root) {} 7958a1ca2cdSRiver Riddle 7968a1ca2cdSRiver Riddle pdl::PatternOp pattern; 797a76ee58fSStanislav Funiak Value root; 7988a1ca2cdSRiver Riddle DenseSet<OrderedPredicate *> predicates; 7998a1ca2cdSRiver Riddle }; 800be0a7e9fSMehdi Amini } // namespace 8018a1ca2cdSRiver Riddle 8028a1ca2cdSRiver Riddle /// Returns true if the given matcher refers to the same predicate as the given 8038a1ca2cdSRiver Riddle /// ordered predicate. This means that the position and questions of the two 8048a1ca2cdSRiver Riddle /// match. 
8058a1ca2cdSRiver Riddle static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { 8068a1ca2cdSRiver Riddle return node->getPosition() == predicate->position && 8078a1ca2cdSRiver Riddle node->getQuestion() == predicate->question; 8088a1ca2cdSRiver Riddle } 8098a1ca2cdSRiver Riddle 8108a1ca2cdSRiver Riddle /// Get or insert a child matcher for the given parent switch node, given a 8118a1ca2cdSRiver Riddle /// predicate and parent pattern. 8128a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, 8138a1ca2cdSRiver Riddle OrderedPredicate *predicate, 8148a1ca2cdSRiver Riddle pdl::PatternOp pattern) { 8158a1ca2cdSRiver Riddle assert(isSamePredicate(node, predicate) && 8168a1ca2cdSRiver Riddle "expected matcher to equal the given predicate"); 8178a1ca2cdSRiver Riddle 8188a1ca2cdSRiver Riddle auto it = predicate->patternToAnswer.find(pattern); 8198a1ca2cdSRiver Riddle assert(it != predicate->patternToAnswer.end() && 8208a1ca2cdSRiver Riddle "expected pattern to exist in predicate"); 821*5d6cb6f7SKazu Hirata return node->getChildren()[it->second]; 8228a1ca2cdSRiver Riddle } 8238a1ca2cdSRiver Riddle 8248a1ca2cdSRiver Riddle /// Build the matcher CFG by "pushing" patterns through by sorted predicate 8258a1ca2cdSRiver Riddle /// order. A pattern will traverse as far as possible using common predicates 8268a1ca2cdSRiver Riddle /// and then either diverge from the CFG or reach the end of a branch and start 8278a1ca2cdSRiver Riddle /// creating new nodes. 8288a1ca2cdSRiver Riddle static void propagatePattern(std::unique_ptr<MatcherNode> &node, 8298a1ca2cdSRiver Riddle OrderedPredicateList &list, 8308a1ca2cdSRiver Riddle std::vector<OrderedPredicate *>::iterator current, 8318a1ca2cdSRiver Riddle std::vector<OrderedPredicate *>::iterator end) { 8328a1ca2cdSRiver Riddle if (current == end) { 8338a1ca2cdSRiver Riddle // We've hit the end of a pattern, so create a successful result node. 
834a76ee58fSStanislav Funiak node = 835a76ee58fSStanislav Funiak std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node)); 8368a1ca2cdSRiver Riddle 8378a1ca2cdSRiver Riddle // If the pattern doesn't contain this predicate, ignore it. 838ce14f7b1SKazu Hirata } else if (!list.predicates.contains(*current)) { 8398a1ca2cdSRiver Riddle propagatePattern(node, list, std::next(current), end); 8408a1ca2cdSRiver Riddle 8418a1ca2cdSRiver Riddle // If the current matcher node is invalid, create a new one for this 8428a1ca2cdSRiver Riddle // position and continue propagation. 8438a1ca2cdSRiver Riddle } else if (!node) { 8448a1ca2cdSRiver Riddle // Create a new node at this position and continue 8458a1ca2cdSRiver Riddle node = std::make_unique<SwitchNode>((*current)->position, 8468a1ca2cdSRiver Riddle (*current)->question); 8478a1ca2cdSRiver Riddle propagatePattern( 8488a1ca2cdSRiver Riddle getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 8498a1ca2cdSRiver Riddle list, std::next(current), end); 8508a1ca2cdSRiver Riddle 8518a1ca2cdSRiver Riddle // If the matcher has already been created, and it is for this predicate we 8528a1ca2cdSRiver Riddle // continue propagation to the child. 8538a1ca2cdSRiver Riddle } else if (isSamePredicate(node.get(), *current)) { 8548a1ca2cdSRiver Riddle propagatePattern( 8558a1ca2cdSRiver Riddle getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 8568a1ca2cdSRiver Riddle list, std::next(current), end); 8578a1ca2cdSRiver Riddle 8588a1ca2cdSRiver Riddle // If the matcher doesn't match the current predicate, insert a branch as 8598a1ca2cdSRiver Riddle // the common set of matchers has diverged. 8608a1ca2cdSRiver Riddle } else { 8618a1ca2cdSRiver Riddle propagatePattern(node->getFailureNode(), list, current, end); 8628a1ca2cdSRiver Riddle } 8638a1ca2cdSRiver Riddle } 8648a1ca2cdSRiver Riddle 8658a1ca2cdSRiver Riddle /// Fold any switch nodes nested under `node` to boolean nodes when possible. 
8668a1ca2cdSRiver Riddle /// `node` is updated in-place if it is a switch. 8678a1ca2cdSRiver Riddle static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) { 8688a1ca2cdSRiver Riddle if (!node) 8698a1ca2cdSRiver Riddle return; 8708a1ca2cdSRiver Riddle 8718a1ca2cdSRiver Riddle if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) { 8728a1ca2cdSRiver Riddle SwitchNode::ChildMapT &children = switchNode->getChildren(); 8738a1ca2cdSRiver Riddle for (auto &it : children) 8748a1ca2cdSRiver Riddle foldSwitchToBool(it.second); 8758a1ca2cdSRiver Riddle 8768a1ca2cdSRiver Riddle // If the node only contains one child, collapse it into a boolean predicate 8778a1ca2cdSRiver Riddle // node. 8788a1ca2cdSRiver Riddle if (children.size() == 1) { 8791cef577bSMehdi Amini auto *childIt = children.begin(); 8808a1ca2cdSRiver Riddle node = std::make_unique<BoolNode>( 8818a1ca2cdSRiver Riddle node->getPosition(), node->getQuestion(), childIt->first, 8828a1ca2cdSRiver Riddle std::move(childIt->second), std::move(node->getFailureNode())); 8838a1ca2cdSRiver Riddle } 8848a1ca2cdSRiver Riddle } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) { 8858a1ca2cdSRiver Riddle foldSwitchToBool(boolNode->getSuccessNode()); 8868a1ca2cdSRiver Riddle } 8878a1ca2cdSRiver Riddle 8888a1ca2cdSRiver Riddle foldSwitchToBool(node->getFailureNode()); 8898a1ca2cdSRiver Riddle } 8908a1ca2cdSRiver Riddle 8918a1ca2cdSRiver Riddle /// Insert an exit node at the end of the failure path of the `root`. 8928a1ca2cdSRiver Riddle static void insertExitNode(std::unique_ptr<MatcherNode> *root) { 8938a1ca2cdSRiver Riddle while (*root) 8948a1ca2cdSRiver Riddle root = &(*root)->getFailureNode(); 8958a1ca2cdSRiver Riddle *root = std::make_unique<ExitNode>(); 8968a1ca2cdSRiver Riddle } 8978a1ca2cdSRiver Riddle 8988ec28af8SMatthias Gehre /// Sorts the range begin/end with the partial order given by cmp. 
8998ec28af8SMatthias Gehre template <typename Iterator, typename Compare> 9008ec28af8SMatthias Gehre static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) { 9018ec28af8SMatthias Gehre while (begin != end) { 9028ec28af8SMatthias Gehre // Cannot compute sortBeforeOthers in the predicate of stable_partition 9038ec28af8SMatthias Gehre // because stable_partition will not keep the [begin, end) range intact 9048ec28af8SMatthias Gehre // while it runs. 9058ec28af8SMatthias Gehre llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers; 9068ec28af8SMatthias Gehre for (auto i = begin; i != end; ++i) { 9078ec28af8SMatthias Gehre if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); })) 9088ec28af8SMatthias Gehre sortBeforeOthers.insert(*i); 9098ec28af8SMatthias Gehre } 9108ec28af8SMatthias Gehre 9118ec28af8SMatthias Gehre auto const next = std::stable_partition(begin, end, [&](auto const &a) { 9128ec28af8SMatthias Gehre return sortBeforeOthers.contains(a); 9138ec28af8SMatthias Gehre }); 9148ec28af8SMatthias Gehre assert(next != begin && "not a partial ordering"); 9158ec28af8SMatthias Gehre begin = next; 9168ec28af8SMatthias Gehre } 9178ec28af8SMatthias Gehre } 9188ec28af8SMatthias Gehre 9198ec28af8SMatthias Gehre /// Returns true if 'b' depends on a result of 'a'. 
9208ec28af8SMatthias Gehre static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) { 9218ec28af8SMatthias Gehre auto *cqa = dyn_cast<ConstraintQuestion>(a->question); 9228ec28af8SMatthias Gehre if (!cqa) 9238ec28af8SMatthias Gehre return false; 9248ec28af8SMatthias Gehre 9258ec28af8SMatthias Gehre auto positionDependsOnA = [&](Position *p) { 9268ec28af8SMatthias Gehre auto *cp = dyn_cast<ConstraintPosition>(p); 9278ec28af8SMatthias Gehre return cp && cp->getQuestion() == cqa; 9288ec28af8SMatthias Gehre }; 9298ec28af8SMatthias Gehre 9308ec28af8SMatthias Gehre if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) { 9318ec28af8SMatthias Gehre // Does any argument of b use a? 9328ec28af8SMatthias Gehre return llvm::any_of(cqb->getArgs(), positionDependsOnA); 9338ec28af8SMatthias Gehre } 9348ec28af8SMatthias Gehre if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) { 9358ec28af8SMatthias Gehre return positionDependsOnA(b->position) || 9368ec28af8SMatthias Gehre positionDependsOnA(equalTo->getValue()); 9378ec28af8SMatthias Gehre } 9388ec28af8SMatthias Gehre return positionDependsOnA(b->position); 9398ec28af8SMatthias Gehre } 9408ec28af8SMatthias Gehre 9418a1ca2cdSRiver Riddle /// Given a module containing PDL pattern operations, generate a matcher tree 9428a1ca2cdSRiver Riddle /// using the patterns within the given module and return the root matcher node. 9438a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> 9448a1ca2cdSRiver Riddle MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 9458a1ca2cdSRiver Riddle DenseMap<Value, Position *> &valueToPosition) { 946a76ee58fSStanislav Funiak // The set of predicates contained within the pattern operations of the 947a76ee58fSStanislav Funiak // module. 
948a76ee58fSStanislav Funiak struct PatternPredicates { 949a76ee58fSStanislav Funiak PatternPredicates(pdl::PatternOp pattern, Value root, 950a76ee58fSStanislav Funiak std::vector<PositionalPredicate> predicates) 951a76ee58fSStanislav Funiak : pattern(pattern), root(root), predicates(std::move(predicates)) {} 952a76ee58fSStanislav Funiak 953a76ee58fSStanislav Funiak /// A pattern. 954a76ee58fSStanislav Funiak pdl::PatternOp pattern; 955a76ee58fSStanislav Funiak 956a76ee58fSStanislav Funiak /// A root of the pattern chosen among the candidate roots in pdl.rewrite. 957a76ee58fSStanislav Funiak Value root; 958a76ee58fSStanislav Funiak 959a76ee58fSStanislav Funiak /// The extracted predicates for this pattern and root. 960a76ee58fSStanislav Funiak std::vector<PositionalPredicate> predicates; 961a76ee58fSStanislav Funiak }; 962a76ee58fSStanislav Funiak 963a76ee58fSStanislav Funiak SmallVector<PatternPredicates, 16> patternsAndPredicates; 9648a1ca2cdSRiver Riddle for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { 9658a1ca2cdSRiver Riddle std::vector<PositionalPredicate> predicateList; 966a76ee58fSStanislav Funiak Value root = 9678a1ca2cdSRiver Riddle buildPredicateList(pattern, builder, predicateList, valueToPosition); 968a76ee58fSStanislav Funiak patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList)); 9698a1ca2cdSRiver Riddle } 9708a1ca2cdSRiver Riddle 9718a1ca2cdSRiver Riddle // Associate a pattern result with each unique predicate. 
9728a1ca2cdSRiver Riddle DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued; 9738a1ca2cdSRiver Riddle for (auto &patternAndPredList : patternsAndPredicates) { 974a76ee58fSStanislav Funiak for (auto &predicate : patternAndPredList.predicates) { 9758a1ca2cdSRiver Riddle auto it = uniqued.insert(predicate); 976a76ee58fSStanislav Funiak it.first->patternToAnswer.try_emplace(patternAndPredList.pattern, 9778a1ca2cdSRiver Riddle predicate.answer); 978138803e0SStanislav Funiak // Mark the insertion order (0-based indexing). 979138803e0SStanislav Funiak if (it.second) 980138803e0SStanislav Funiak it.first->id = uniqued.size() - 1; 9818a1ca2cdSRiver Riddle } 9828a1ca2cdSRiver Riddle } 9838a1ca2cdSRiver Riddle 9848a1ca2cdSRiver Riddle // Associate each pattern to a set of its ordered predicates for later lookup. 9858a1ca2cdSRiver Riddle std::vector<OrderedPredicateList> lists; 9868a1ca2cdSRiver Riddle lists.reserve(patternsAndPredicates.size()); 9878a1ca2cdSRiver Riddle for (auto &patternAndPredList : patternsAndPredicates) { 988a76ee58fSStanislav Funiak OrderedPredicateList list(patternAndPredList.pattern, 989a76ee58fSStanislav Funiak patternAndPredList.root); 990a76ee58fSStanislav Funiak for (auto &predicate : patternAndPredList.predicates) { 9918a1ca2cdSRiver Riddle OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); 9928a1ca2cdSRiver Riddle list.predicates.insert(orderedPredicate); 9938a1ca2cdSRiver Riddle 9948a1ca2cdSRiver Riddle // Increment the primary sum for each reference to a particular predicate. 9958a1ca2cdSRiver Riddle ++orderedPredicate->primary; 9968a1ca2cdSRiver Riddle } 9978a1ca2cdSRiver Riddle lists.push_back(std::move(list)); 9988a1ca2cdSRiver Riddle } 9998a1ca2cdSRiver Riddle 10008a1ca2cdSRiver Riddle // For a particular pattern, get the total primary sum and add it to the 10018a1ca2cdSRiver Riddle // secondary sum of each predicate. 
Square the primary sums to emphasize 10028a1ca2cdSRiver Riddle // shared predicates within rather than across patterns. 10038a1ca2cdSRiver Riddle for (auto &list : lists) { 10048a1ca2cdSRiver Riddle unsigned total = 0; 10058a1ca2cdSRiver Riddle for (auto *predicate : list.predicates) 10068a1ca2cdSRiver Riddle total += predicate->primary * predicate->primary; 10078a1ca2cdSRiver Riddle for (auto *predicate : list.predicates) 10088a1ca2cdSRiver Riddle predicate->secondary += total; 10098a1ca2cdSRiver Riddle } 10108a1ca2cdSRiver Riddle 10118a1ca2cdSRiver Riddle // Sort the set of predicates now that the cost primary and secondary sums 10128a1ca2cdSRiver Riddle // have been computed. 10138a1ca2cdSRiver Riddle std::vector<OrderedPredicate *> ordered; 10148a1ca2cdSRiver Riddle ordered.reserve(uniqued.size()); 10158a1ca2cdSRiver Riddle for (auto &ip : uniqued) 10168a1ca2cdSRiver Riddle ordered.push_back(&ip); 1017138803e0SStanislav Funiak llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) { 1018138803e0SStanislav Funiak return *lhs < *rhs; 1019138803e0SStanislav Funiak }); 10208a1ca2cdSRiver Riddle 10218ec28af8SMatthias Gehre // Mostly keep the now established order, but also ensure that 10228ec28af8SMatthias Gehre // ConstraintQuestions come after the results they use. 10238ec28af8SMatthias Gehre stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn); 10248ec28af8SMatthias Gehre 10258a1ca2cdSRiver Riddle // Build the matchers for each of the pattern predicate lists. 10268a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> root; 10278a1ca2cdSRiver Riddle for (OrderedPredicateList &list : lists) 10288a1ca2cdSRiver Riddle propagatePattern(root, list, ordered.begin(), ordered.end()); 10298a1ca2cdSRiver Riddle 10308a1ca2cdSRiver Riddle // Collapse the graph and insert the exit node. 
foldSwitchToBool(root);
insertExitNode(&root);
return root;
}

//===----------------------------------------------------------------------===//
// MatcherNode
//===----------------------------------------------------------------------===//

/// Base constructor shared by every matcher node kind defined in this file.
///
/// \param matcherTypeID Identifies the concrete node kind; each subclass below
///        passes its own `TypeID::get<...>()`. NOTE(review): presumably this is
///        what the header's kind-based casting helpers inspect — confirm there.
/// \param p The position this node operates on (SuccessNode passes null).
/// \param q The question asked at `p` (SuccessNode passes null).
/// \param failureNode Node whose ownership is transferred to this node and
///        stored in the member of the same name.
MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
                         std::unique_ptr<MatcherNode> failureNode)
    : position(p), question(q),
      // The parameter shadows the member of the same name; inside this
      // initializer `failureNode` names the parameter, so the argument is
      // moved into the member.
      failureNode(std::move(failureNode)), matcherTypeID(matcherTypeID) {}

//===----------------------------------------------------------------------===//
// BoolNode
//===----------------------------------------------------------------------===//

/// Constructs a BoolNode. `position`, `question`, and `failureNode` are
/// forwarded to the MatcherNode base together with BoolNode's TypeID; the
/// expected `answer` and the owned `successNode` are stored on this node.
/// NOTE(review): `successNode` is presumably followed when `question` yields
/// `answer`, and `failureNode` otherwise — confirm against the code generator.
BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
                   std::unique_ptr<MatcherNode> successNode,
                   std::unique_ptr<MatcherNode> failureNode)
    : MatcherNode(TypeID::get<BoolNode>(), position, question,
                  std::move(failureNode)),
      answer(answer), successNode(std::move(successNode)) {}

//===----------------------------------------------------------------------===//
// SuccessNode
//===----------------------------------------------------------------------===//

/// Constructs a SuccessNode recording the matched `pattern` and its `root`
/// value. A success node has no position or question, so both are passed to
/// the MatcherNode base as null; only `failureNode` is forwarded.
SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
                         std::unique_ptr<MatcherNode> failureNode)
    : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
                  /*question=*/nullptr, std::move(failureNode)),
      pattern(pattern), root(root) {}

//===----------------------------------------------------------------------===//
// SwitchNode
//===----------------------------------------------------------------------===//

/// Constructs a SwitchNode that asks `question` at `position`. No failure node
/// is passed here, so the MatcherNode constructor's default argument for it
/// (declared in the header, not shown in this file) is used.
SwitchNode::SwitchNode(Position *position, Qualifier *question)
    : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}