1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PredicateTree.h" 10 #include "mlir/Dialect/PDL/IR/PDL.h" 11 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 12 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/Interfaces/InferTypeOpInterface.h" 15 #include "llvm/ADT/TypeSwitch.h" 16 17 using namespace mlir; 18 using namespace mlir::pdl_to_pdl_interp; 19 20 //===----------------------------------------------------------------------===// 21 // Predicate List Building 22 //===----------------------------------------------------------------------===// 23 24 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 25 Value val, PredicateBuilder &builder, 26 DenseMap<Value, Position *> &inputs, 27 Position *pos); 28 29 /// Compares the depths of two positions. 30 static bool comparePosDepth(Position *lhs, Position *rhs) { 31 return lhs->getIndex().size() < rhs->getIndex().size(); 32 } 33 34 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 35 Value val, PredicateBuilder &builder, 36 DenseMap<Value, Position *> &inputs, 37 AttributePosition *pos) { 38 assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type"); 39 pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp()); 40 predList.emplace_back(pos, builder.getIsNotNull()); 41 42 // If the attribute has a type or value, add a constraint. 43 if (Value type = attr.type()) 44 getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); 45 else if (Attribute value = attr.valueAttr()) 46 predList.emplace_back(pos, builder.getAttributeConstraint(value)); 47 } 48 49 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 50 Value val, PredicateBuilder &builder, 51 DenseMap<Value, Position *> &inputs, 52 OperandPosition *pos) { 53 assert(val.getType().isa<pdl::ValueType>() && "expected value type"); 54 55 // Prevent traversal into a null value. 56 predList.emplace_back(pos, builder.getIsNotNull()); 57 58 // If this is a typed operand, add a type constraint. 59 if (auto in = val.getDefiningOp<pdl::OperandOp>()) { 60 if (Value type = in.type()) 61 getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); 62 63 // Otherwise, recurse into a result node. 64 } else if (auto resultOp = val.getDefiningOp<pdl::ResultOp>()) { 65 OperationPosition *parentPos = builder.getParent(pos); 66 Position *resultPos = builder.getResult(parentPos, resultOp.index()); 67 predList.emplace_back(parentPos, builder.getIsNotNull()); 68 predList.emplace_back(resultPos, builder.getEqualTo(pos)); 69 getTreePredicates(predList, resultOp.parent(), builder, inputs, parentPos); 70 } 71 } 72 73 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 74 Value val, PredicateBuilder &builder, 75 DenseMap<Value, Position *> &inputs, 76 OperationPosition *pos) { 77 assert(val.getType().isa<pdl::OperationType>() && "expected operation"); 78 pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp()); 79 OperationPosition *opPos = cast<OperationPosition>(pos); 80 81 // Ensure getDefiningOp returns a non-null operation. 82 if (!opPos->isRoot()) 83 predList.emplace_back(pos, builder.getIsNotNull()); 84 85 // Check that this is the correct root operation. 86 if (Optional<StringRef> opName = op.name()) 87 predList.emplace_back(pos, builder.getOperationName(*opName)); 88 89 // Check that the operation has the proper number of operands and results. 90 OperandRange operands = op.operands(); 91 OperandRange types = op.types(); 92 predList.emplace_back(pos, builder.getOperandCount(operands.size())); 93 predList.emplace_back(pos, builder.getResultCount(types.size())); 94 95 // Recurse into any attributes, operands, or results. 96 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 97 getTreePredicates( 98 predList, std::get<1>(it), builder, inputs, 99 builder.getAttribute(opPos, 100 std::get<0>(it).cast<StringAttr>().getValue())); 101 } 102 for (auto operandIt : llvm::enumerate(operands)) { 103 getTreePredicates(predList, operandIt.value(), builder, inputs, 104 builder.getOperand(opPos, operandIt.index())); 105 } 106 for (auto &resultIt : llvm::enumerate(types)) { 107 auto *resultPos = builder.getResult(pos, resultIt.index()); 108 predList.emplace_back(resultPos, builder.getIsNotNull()); 109 getTreePredicates(predList, resultIt.value(), builder, inputs, 110 builder.getType(resultPos)); 111 } 112 } 113 114 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 115 Value val, PredicateBuilder &builder, 116 DenseMap<Value, Position *> &inputs, 117 TypePosition *pos) { 118 assert(val.getType().isa<pdl::TypeType>() && "expected value type"); 119 pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp()); 120 121 // Check for a constraint on a constant type. 122 if (Optional<Type> type = typeOp.type()) 123 predList.emplace_back(pos, builder.getTypeConstraint(*type)); 124 } 125 126 /// Collect the tree predicates anchored at the given value. 127 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 128 Value val, PredicateBuilder &builder, 129 DenseMap<Value, Position *> &inputs, 130 Position *pos) { 131 // Make sure this input value is accessible to the rewrite. 132 auto it = inputs.try_emplace(val, pos); 133 if (!it.second) { 134 // If this is an input value that has been visited in the tree, add a 135 // constraint to ensure that both instances refer to the same value. 136 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperationOp, pdl::TypeOp>( 137 val.getDefiningOp())) { 138 auto minMaxPositions = 139 std::minmax(pos, it.first->second, comparePosDepth); 140 predList.emplace_back(minMaxPositions.second, 141 builder.getEqualTo(minMaxPositions.first)); 142 } 143 return; 144 } 145 146 TypeSwitch<Position *>(pos) 147 .Case<AttributePosition, OperandPosition, OperationPosition, 148 TypePosition>([&](auto *derivedPos) { 149 getTreePredicates(predList, val, builder, inputs, derivedPos); 150 }) 151 .Default([](auto *) { llvm_unreachable("unexpected position kind"); }); 152 } 153 154 /// Collect all of the predicates related to constraints within the given 155 /// pattern operation. 156 static void getConstraintPredicates(pdl::ApplyConstraintOp op, 157 std::vector<PositionalPredicate> &predList, 158 PredicateBuilder &builder, 159 DenseMap<Value, Position *> &inputs) { 160 OperandRange arguments = op.args(); 161 ArrayAttr parameters = op.constParamsAttr(); 162 163 std::vector<Position *> allPositions; 164 allPositions.reserve(arguments.size()); 165 for (Value arg : arguments) 166 allPositions.push_back(inputs.lookup(arg)); 167 168 // Push the constraint to the furthest position. 169 Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), 170 comparePosDepth); 171 PredicateBuilder::Predicate pred = 172 builder.getConstraint(op.name(), std::move(allPositions), parameters); 173 predList.emplace_back(pos, pred); 174 } 175 176 static void getResultPredicates(pdl::ResultOp op, 177 std::vector<PositionalPredicate> &predList, 178 PredicateBuilder &builder, 179 DenseMap<Value, Position *> &inputs) { 180 Position *&resultPos = inputs[op]; 181 if (resultPos) 182 return; 183 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent())); 184 resultPos = builder.getResult(parentPos, op.index()); 185 predList.emplace_back(resultPos, builder.getIsNotNull()); 186 } 187 188 /// Collect all of the predicates that cannot be determined via walking the 189 /// tree. 190 static void getNonTreePredicates(pdl::PatternOp pattern, 191 std::vector<PositionalPredicate> &predList, 192 PredicateBuilder &builder, 193 DenseMap<Value, Position *> &inputs) { 194 for (Operation &op : pattern.body().getOps()) { 195 if (auto constraintOp = dyn_cast<pdl::ApplyConstraintOp>(&op)) 196 getConstraintPredicates(constraintOp, predList, builder, inputs); 197 else if (auto resultOp = dyn_cast<pdl::ResultOp>(&op)) 198 getResultPredicates(resultOp, predList, builder, inputs); 199 } 200 } 201 202 /// Given a pattern operation, build the set of matcher predicates necessary to 203 /// match this pattern. 204 static void buildPredicateList(pdl::PatternOp pattern, 205 PredicateBuilder &builder, 206 std::vector<PositionalPredicate> &predList, 207 DenseMap<Value, Position *> &valueToPosition) { 208 getTreePredicates(predList, pattern.getRewriter().root(), builder, 209 valueToPosition, builder.getRoot()); 210 getNonTreePredicates(pattern, predList, builder, valueToPosition); 211 } 212 213 //===----------------------------------------------------------------------===// 214 // Pattern Predicate Tree Merging 215 //===----------------------------------------------------------------------===// 216 217 namespace { 218 219 /// This class represents a specific predicate applied to a position, and 220 /// provides hashing and ordering operators. This class allows for computing a 221 /// frequence sum and ordering predicates based on a cost model. 222 struct OrderedPredicate { 223 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip) 224 : position(ip.first), question(ip.second) {} 225 OrderedPredicate(const PositionalPredicate &ip) 226 : position(ip.position), question(ip.question) {} 227 228 /// The position this predicate is applied to. 229 Position *position; 230 231 /// The question that is applied by this predicate onto the position. 232 Qualifier *question; 233 234 /// The first and second order benefit sums. 235 /// The primary sum is the number of occurrences of this predicate among all 236 /// of the patterns. 237 unsigned primary = 0; 238 /// The secondary sum is a squared summation of the primary sum of all of the 239 /// predicates within each pattern that contains this predicate. This allows 240 /// for favoring predicates that are more commonly shared within a pattern, as 241 /// opposed to those shared across patterns. 242 unsigned secondary = 0; 243 244 /// A map between a pattern operation and the answer to the predicate question 245 /// within that pattern. 246 DenseMap<Operation *, Qualifier *> patternToAnswer; 247 248 /// Returns true if this predicate is ordered before `rhs`, based on the cost 249 /// model. 250 bool operator<(const OrderedPredicate &rhs) const { 251 // Sort by: 252 // * higher first and secondary order sums 253 // * lower depth 254 // * lower position dependency 255 // * lower predicate dependency 256 auto *rhsPos = rhs.position; 257 return std::make_tuple(primary, secondary, rhsPos->getIndex().size(), 258 rhsPos->getKind(), rhs.question->getKind()) > 259 std::make_tuple(rhs.primary, rhs.secondary, 260 position->getIndex().size(), position->getKind(), 261 question->getKind()); 262 } 263 }; 264 265 /// A DenseMapInfo for OrderedPredicate based solely on the position and 266 /// question. 267 struct OrderedPredicateDenseInfo { 268 using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>; 269 270 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); } 271 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); } 272 static bool isEqual(const OrderedPredicate &lhs, 273 const OrderedPredicate &rhs) { 274 return lhs.position == rhs.position && lhs.question == rhs.question; 275 } 276 static unsigned getHashValue(const OrderedPredicate &p) { 277 return llvm::hash_combine(p.position, p.question); 278 } 279 }; 280 281 /// This class wraps a set of ordered predicates that are used within a specific 282 /// pattern operation. 283 struct OrderedPredicateList { 284 OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {} 285 286 pdl::PatternOp pattern; 287 DenseSet<OrderedPredicate *> predicates; 288 }; 289 } // end anonymous namespace 290 291 /// Returns true if the given matcher refers to the same predicate as the given 292 /// ordered predicate. This means that the position and questions of the two 293 /// match. 294 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { 295 return node->getPosition() == predicate->position && 296 node->getQuestion() == predicate->question; 297 } 298 299 /// Get or insert a child matcher for the given parent switch node, given a 300 /// predicate and parent pattern. 301 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, 302 OrderedPredicate *predicate, 303 pdl::PatternOp pattern) { 304 assert(isSamePredicate(node, predicate) && 305 "expected matcher to equal the given predicate"); 306 307 auto it = predicate->patternToAnswer.find(pattern); 308 assert(it != predicate->patternToAnswer.end() && 309 "expected pattern to exist in predicate"); 310 return node->getChildren().insert({it->second, nullptr}).first->second; 311 } 312 313 /// Build the matcher CFG by "pushing" patterns through by sorted predicate 314 /// order. A pattern will traverse as far as possible using common predicates 315 /// and then either diverge from the CFG or reach the end of a branch and start 316 /// creating new nodes. 317 static void propagatePattern(std::unique_ptr<MatcherNode> &node, 318 OrderedPredicateList &list, 319 std::vector<OrderedPredicate *>::iterator current, 320 std::vector<OrderedPredicate *>::iterator end) { 321 if (current == end) { 322 // We've hit the end of a pattern, so create a successful result node. 323 node = std::make_unique<SuccessNode>(list.pattern, std::move(node)); 324 325 // If the pattern doesn't contain this predicate, ignore it. 326 } else if (list.predicates.find(*current) == list.predicates.end()) { 327 propagatePattern(node, list, std::next(current), end); 328 329 // If the current matcher node is invalid, create a new one for this 330 // position and continue propagation. 331 } else if (!node) { 332 // Create a new node at this position and continue 333 node = std::make_unique<SwitchNode>((*current)->position, 334 (*current)->question); 335 propagatePattern( 336 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 337 list, std::next(current), end); 338 339 // If the matcher has already been created, and it is for this predicate we 340 // continue propagation to the child. 341 } else if (isSamePredicate(node.get(), *current)) { 342 propagatePattern( 343 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 344 list, std::next(current), end); 345 346 // If the matcher doesn't match the current predicate, insert a branch as 347 // the common set of matchers has diverged. 348 } else { 349 propagatePattern(node->getFailureNode(), list, current, end); 350 } 351 } 352 353 /// Fold any switch nodes nested under `node` to boolean nodes when possible. 354 /// `node` is updated in-place if it is a switch. 355 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) { 356 if (!node) 357 return; 358 359 if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) { 360 SwitchNode::ChildMapT &children = switchNode->getChildren(); 361 for (auto &it : children) 362 foldSwitchToBool(it.second); 363 364 // If the node only contains one child, collapse it into a boolean predicate 365 // node. 366 if (children.size() == 1) { 367 auto childIt = children.begin(); 368 node = std::make_unique<BoolNode>( 369 node->getPosition(), node->getQuestion(), childIt->first, 370 std::move(childIt->second), std::move(node->getFailureNode())); 371 } 372 } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) { 373 foldSwitchToBool(boolNode->getSuccessNode()); 374 } 375 376 foldSwitchToBool(node->getFailureNode()); 377 } 378 379 /// Insert an exit node at the end of the failure path of the `root`. 380 static void insertExitNode(std::unique_ptr<MatcherNode> *root) { 381 while (*root) 382 root = &(*root)->getFailureNode(); 383 *root = std::make_unique<ExitNode>(); 384 } 385 386 /// Given a module containing PDL pattern operations, generate a matcher tree 387 /// using the patterns within the given module and return the root matcher node. 388 std::unique_ptr<MatcherNode> 389 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 390 DenseMap<Value, Position *> &valueToPosition) { 391 // Collect the set of predicates contained within the pattern operations of 392 // the module. 393 SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16> 394 patternsAndPredicates; 395 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { 396 std::vector<PositionalPredicate> predicateList; 397 buildPredicateList(pattern, builder, predicateList, valueToPosition); 398 patternsAndPredicates.emplace_back(pattern, std::move(predicateList)); 399 } 400 401 // Associate a pattern result with each unique predicate. 402 DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued; 403 for (auto &patternAndPredList : patternsAndPredicates) { 404 for (auto &predicate : patternAndPredList.second) { 405 auto it = uniqued.insert(predicate); 406 it.first->patternToAnswer.try_emplace(patternAndPredList.first, 407 predicate.answer); 408 } 409 } 410 411 // Associate each pattern to a set of its ordered predicates for later lookup. 412 std::vector<OrderedPredicateList> lists; 413 lists.reserve(patternsAndPredicates.size()); 414 for (auto &patternAndPredList : patternsAndPredicates) { 415 OrderedPredicateList list(patternAndPredList.first); 416 for (auto &predicate : patternAndPredList.second) { 417 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); 418 list.predicates.insert(orderedPredicate); 419 420 // Increment the primary sum for each reference to a particular predicate. 421 ++orderedPredicate->primary; 422 } 423 lists.push_back(std::move(list)); 424 } 425 426 // For a particular pattern, get the total primary sum and add it to the 427 // secondary sum of each predicate. Square the primary sums to emphasize 428 // shared predicates within rather than across patterns. 429 for (auto &list : lists) { 430 unsigned total = 0; 431 for (auto *predicate : list.predicates) 432 total += predicate->primary * predicate->primary; 433 for (auto *predicate : list.predicates) 434 predicate->secondary += total; 435 } 436 437 // Sort the set of predicates now that the cost primary and secondary sums 438 // have been computed. 439 std::vector<OrderedPredicate *> ordered; 440 ordered.reserve(uniqued.size()); 441 for (auto &ip : uniqued) 442 ordered.push_back(&ip); 443 std::stable_sort( 444 ordered.begin(), ordered.end(), 445 [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; }); 446 447 // Build the matchers for each of the pattern predicate lists. 448 std::unique_ptr<MatcherNode> root; 449 for (OrderedPredicateList &list : lists) 450 propagatePattern(root, list, ordered.begin(), ordered.end()); 451 452 // Collapse the graph and insert the exit node. 453 foldSwitchToBool(root); 454 insertExitNode(&root); 455 return root; 456 } 457 458 //===----------------------------------------------------------------------===// 459 // MatcherNode 460 //===----------------------------------------------------------------------===// 461 462 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q, 463 std::unique_ptr<MatcherNode> failureNode) 464 : position(p), question(q), failureNode(std::move(failureNode)), 465 matcherTypeID(matcherTypeID) {} 466 467 //===----------------------------------------------------------------------===// 468 // BoolNode 469 //===----------------------------------------------------------------------===// 470 471 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer, 472 std::unique_ptr<MatcherNode> successNode, 473 std::unique_ptr<MatcherNode> failureNode) 474 : MatcherNode(TypeID::get<BoolNode>(), position, question, 475 std::move(failureNode)), 476 answer(answer), successNode(std::move(successNode)) {} 477 478 //===----------------------------------------------------------------------===// 479 // SuccessNode 480 //===----------------------------------------------------------------------===// 481 482 SuccessNode::SuccessNode(pdl::PatternOp pattern, 483 std::unique_ptr<MatcherNode> failureNode) 484 : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr, 485 /*question=*/nullptr, std::move(failureNode)), 486 pattern(pattern) {} 487 488 //===----------------------------------------------------------------------===// 489 // SwitchNode 490 //===----------------------------------------------------------------------===// 491 492 SwitchNode::SwitchNode(Position *position, Qualifier *question) 493 : MatcherNode(TypeID::get<SwitchNode>(), position, question) {} 494