1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PredicateTree.h" 10 #include "mlir/Dialect/PDL/IR/PDL.h" 11 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 12 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/Interfaces/InferTypeOpInterface.h" 15 16 using namespace mlir; 17 using namespace mlir::pdl_to_pdl_interp; 18 19 //===----------------------------------------------------------------------===// 20 // Predicate List Building 21 //===----------------------------------------------------------------------===// 22 23 /// Compares the depths of two positions. 24 static bool comparePosDepth(Position *lhs, Position *rhs) { 25 return lhs->getIndex().size() < rhs->getIndex().size(); 26 } 27 28 /// Collect the tree predicates anchored at the given value. 29 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 30 Value val, PredicateBuilder &builder, 31 DenseMap<Value, Position *> &inputs, 32 Position *pos) { 33 // Make sure this input value is accessible to the rewrite. 34 auto it = inputs.try_emplace(val, pos); 35 36 // If this is an input value that has been visited in the tree, add a 37 // constraint to ensure that both instances refer to the same value. 38 if (!it.second && 39 isa<pdl::AttributeOp, pdl::InputOp, pdl::TypeOp>(val.getDefiningOp())) { 40 auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth); 41 predList.emplace_back(minMaxPositions.second, 42 builder.getEqualTo(minMaxPositions.first)); 43 return; 44 } 45 46 // Check for a per-position predicate to apply. 47 switch (pos->getKind()) { 48 case Predicates::AttributePos: { 49 assert(val.getType().isa<pdl::AttributeType>() && 50 "expected attribute type"); 51 pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp()); 52 predList.emplace_back(pos, builder.getIsNotNull()); 53 54 // If the attribute has a type, add a type constraint. 55 if (Value type = attr.type()) { 56 getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); 57 58 // Check for a constant value of the attribute. 59 } else if (Optional<Attribute> value = attr.value()) { 60 predList.emplace_back(pos, builder.getAttributeConstraint(*value)); 61 } 62 break; 63 } 64 case Predicates::OperandPos: { 65 assert(val.getType().isa<pdl::ValueType>() && "expected value type"); 66 67 // Prevent traversal into a null value. 68 predList.emplace_back(pos, builder.getIsNotNull()); 69 70 // If this is a typed input, add a type constraint. 71 if (auto in = val.getDefiningOp<pdl::InputOp>()) { 72 if (Value type = in.type()) { 73 getTreePredicates(predList, type, builder, inputs, 74 builder.getType(pos)); 75 } 76 77 // Otherwise, recurse into the parent node. 78 } else if (auto parentOp = val.getDefiningOp<pdl::OperationOp>()) { 79 getTreePredicates(predList, parentOp.op(), builder, inputs, 80 builder.getParent(cast<OperandPosition>(pos))); 81 } 82 break; 83 } 84 case Predicates::OperationPos: { 85 assert(val.getType().isa<pdl::OperationType>() && "expected operation"); 86 pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp()); 87 OperationPosition *opPos = cast<OperationPosition>(pos); 88 89 // Ensure getDefiningOp returns a non-null operation. 90 if (!opPos->isRoot()) 91 predList.emplace_back(pos, builder.getIsNotNull()); 92 93 // Check that this is the correct root operation. 94 if (Optional<StringRef> opName = op.name()) 95 predList.emplace_back(pos, builder.getOperationName(*opName)); 96 97 // Check that the operation has the proper number of operands and results. 98 OperandRange operands = op.operands(); 99 ResultRange results = op.results(); 100 predList.emplace_back(pos, builder.getOperandCount(operands.size())); 101 predList.emplace_back(pos, builder.getResultCount(results.size())); 102 103 // Recurse into any attributes, operands, or results. 104 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 105 getTreePredicates( 106 predList, std::get<1>(it), builder, inputs, 107 builder.getAttribute(opPos, 108 std::get<0>(it).cast<StringAttr>().getValue())); 109 } 110 for (auto operandIt : llvm::enumerate(operands)) 111 getTreePredicates(predList, operandIt.value(), builder, inputs, 112 builder.getOperand(opPos, operandIt.index())); 113 114 // Only recurse into results that are not referenced in the source tree. 115 for (auto resultIt : llvm::enumerate(results)) { 116 getTreePredicates(predList, resultIt.value(), builder, inputs, 117 builder.getResult(opPos, resultIt.index())); 118 } 119 break; 120 } 121 case Predicates::ResultPos: { 122 assert(val.getType().isa<pdl::ValueType>() && "expected value type"); 123 pdl::OperationOp parentOp = cast<pdl::OperationOp>(val.getDefiningOp()); 124 125 // Prevent traversing a null value. 126 predList.emplace_back(pos, builder.getIsNotNull()); 127 128 // Traverse the type constraint. 129 unsigned resultNo = cast<ResultPosition>(pos)->getResultNumber(); 130 getTreePredicates(predList, parentOp.types()[resultNo], builder, inputs, 131 builder.getType(pos)); 132 break; 133 } 134 case Predicates::TypePos: { 135 assert(val.getType().isa<pdl::TypeType>() && "expected value type"); 136 pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp()); 137 138 // Check for a constraint on a constant type. 139 if (Optional<Type> type = typeOp.type()) 140 predList.emplace_back(pos, builder.getTypeConstraint(*type)); 141 break; 142 } 143 default: 144 llvm_unreachable("unknown position kind"); 145 } 146 } 147 148 /// Collect all of the predicates related to constraints within the given 149 /// pattern operation. 150 static void collectConstraintPredicates( 151 pdl::PatternOp pattern, std::vector<PositionalPredicate> &predList, 152 PredicateBuilder &builder, DenseMap<Value, Position *> &inputs) { 153 for (auto op : pattern.body().getOps<pdl::ApplyConstraintOp>()) { 154 OperandRange arguments = op.args(); 155 ArrayAttr parameters = op.constParamsAttr(); 156 157 std::vector<Position *> allPositions; 158 allPositions.reserve(arguments.size()); 159 for (Value arg : arguments) 160 allPositions.push_back(inputs.lookup(arg)); 161 162 // Push the constraint to the furthest position. 163 Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), 164 comparePosDepth); 165 PredicateBuilder::Predicate pred = 166 builder.getConstraint(op.name(), std::move(allPositions), parameters); 167 predList.emplace_back(pos, pred); 168 } 169 } 170 171 /// Given a pattern operation, build the set of matcher predicates necessary to 172 /// match this pattern. 173 static void buildPredicateList(pdl::PatternOp pattern, 174 PredicateBuilder &builder, 175 std::vector<PositionalPredicate> &predList, 176 DenseMap<Value, Position *> &valueToPosition) { 177 getTreePredicates(predList, pattern.getRewriter().root(), builder, 178 valueToPosition, builder.getRoot()); 179 collectConstraintPredicates(pattern, predList, builder, valueToPosition); 180 } 181 182 //===----------------------------------------------------------------------===// 183 // Pattern Predicate Tree Merging 184 //===----------------------------------------------------------------------===// 185 186 namespace { 187 188 /// This class represents a specific predicate applied to a position, and 189 /// provides hashing and ordering operators. This class allows for computing a 190 /// frequence sum and ordering predicates based on a cost model. 191 struct OrderedPredicate { 192 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip) 193 : position(ip.first), question(ip.second) {} 194 OrderedPredicate(const PositionalPredicate &ip) 195 : position(ip.position), question(ip.question) {} 196 197 /// The position this predicate is applied to. 198 Position *position; 199 200 /// The question that is applied by this predicate onto the position. 201 Qualifier *question; 202 203 /// The first and second order benefit sums. 204 /// The primary sum is the number of occurrences of this predicate among all 205 /// of the patterns. 206 unsigned primary = 0; 207 /// The secondary sum is a squared summation of the primary sum of all of the 208 /// predicates within each pattern that contains this predicate. This allows 209 /// for favoring predicates that are more commonly shared within a pattern, as 210 /// opposed to those shared across patterns. 211 unsigned secondary = 0; 212 213 /// A map between a pattern operation and the answer to the predicate question 214 /// within that pattern. 215 DenseMap<Operation *, Qualifier *> patternToAnswer; 216 217 /// Returns true if this predicate is ordered before `other`, based on the 218 /// cost model. 219 bool operator<(const OrderedPredicate &other) const { 220 // Sort by: 221 // * first and secondary order sums 222 // * lower depth 223 // * position dependency 224 // * predicate dependency. 225 auto *otherPos = other.position; 226 return std::make_tuple(other.primary, other.secondary, 227 otherPos->getIndex().size(), otherPos->getKind(), 228 other.question->getKind()) > 229 std::make_tuple(primary, secondary, position->getIndex().size(), 230 position->getKind(), question->getKind()); 231 } 232 }; 233 234 /// A DenseMapInfo for OrderedPredicate based solely on the position and 235 /// question. 236 struct OrderedPredicateDenseInfo { 237 using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>; 238 239 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); } 240 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); } 241 static bool isEqual(const OrderedPredicate &lhs, 242 const OrderedPredicate &rhs) { 243 return lhs.position == rhs.position && lhs.question == rhs.question; 244 } 245 static unsigned getHashValue(const OrderedPredicate &p) { 246 return llvm::hash_combine(p.position, p.question); 247 } 248 }; 249 250 /// This class wraps a set of ordered predicates that are used within a specific 251 /// pattern operation. 252 struct OrderedPredicateList { 253 OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {} 254 255 pdl::PatternOp pattern; 256 DenseSet<OrderedPredicate *> predicates; 257 }; 258 } // end anonymous namespace 259 260 /// Returns true if the given matcher refers to the same predicate as the given 261 /// ordered predicate. This means that the position and questions of the two 262 /// match. 263 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { 264 return node->getPosition() == predicate->position && 265 node->getQuestion() == predicate->question; 266 } 267 268 /// Get or insert a child matcher for the given parent switch node, given a 269 /// predicate and parent pattern. 270 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, 271 OrderedPredicate *predicate, 272 pdl::PatternOp pattern) { 273 assert(isSamePredicate(node, predicate) && 274 "expected matcher to equal the given predicate"); 275 276 auto it = predicate->patternToAnswer.find(pattern); 277 assert(it != predicate->patternToAnswer.end() && 278 "expected pattern to exist in predicate"); 279 return node->getChildren().insert({it->second, nullptr}).first->second; 280 } 281 282 /// Build the matcher CFG by "pushing" patterns through by sorted predicate 283 /// order. A pattern will traverse as far as possible using common predicates 284 /// and then either diverge from the CFG or reach the end of a branch and start 285 /// creating new nodes. 286 static void propagatePattern(std::unique_ptr<MatcherNode> &node, 287 OrderedPredicateList &list, 288 std::vector<OrderedPredicate *>::iterator current, 289 std::vector<OrderedPredicate *>::iterator end) { 290 if (current == end) { 291 // We've hit the end of a pattern, so create a successful result node. 292 node = std::make_unique<SuccessNode>(list.pattern, std::move(node)); 293 294 // If the pattern doesn't contain this predicate, ignore it. 295 } else if (list.predicates.find(*current) == list.predicates.end()) { 296 propagatePattern(node, list, std::next(current), end); 297 298 // If the current matcher node is invalid, create a new one for this 299 // position and continue propagation. 300 } else if (!node) { 301 // Create a new node at this position and continue 302 node = std::make_unique<SwitchNode>((*current)->position, 303 (*current)->question); 304 propagatePattern( 305 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 306 list, std::next(current), end); 307 308 // If the matcher has already been created, and it is for this predicate we 309 // continue propagation to the child. 310 } else if (isSamePredicate(node.get(), *current)) { 311 propagatePattern( 312 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 313 list, std::next(current), end); 314 315 // If the matcher doesn't match the current predicate, insert a branch as 316 // the common set of matchers has diverged. 317 } else { 318 propagatePattern(node->getFailureNode(), list, current, end); 319 } 320 } 321 322 /// Fold any switch nodes nested under `node` to boolean nodes when possible. 323 /// `node` is updated in-place if it is a switch. 324 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) { 325 if (!node) 326 return; 327 328 if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) { 329 SwitchNode::ChildMapT &children = switchNode->getChildren(); 330 for (auto &it : children) 331 foldSwitchToBool(it.second); 332 333 // If the node only contains one child, collapse it into a boolean predicate 334 // node. 335 if (children.size() == 1) { 336 auto childIt = children.begin(); 337 node = std::make_unique<BoolNode>( 338 node->getPosition(), node->getQuestion(), childIt->first, 339 std::move(childIt->second), std::move(node->getFailureNode())); 340 } 341 } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) { 342 foldSwitchToBool(boolNode->getSuccessNode()); 343 } 344 345 foldSwitchToBool(node->getFailureNode()); 346 } 347 348 /// Insert an exit node at the end of the failure path of the `root`. 349 static void insertExitNode(std::unique_ptr<MatcherNode> *root) { 350 while (*root) 351 root = &(*root)->getFailureNode(); 352 *root = std::make_unique<ExitNode>(); 353 } 354 355 /// Given a module containing PDL pattern operations, generate a matcher tree 356 /// using the patterns within the given module and return the root matcher node. 357 std::unique_ptr<MatcherNode> 358 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 359 DenseMap<Value, Position *> &valueToPosition) { 360 // Collect the set of predicates contained within the pattern operations of 361 // the module. 362 SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16> 363 patternsAndPredicates; 364 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { 365 std::vector<PositionalPredicate> predicateList; 366 buildPredicateList(pattern, builder, predicateList, valueToPosition); 367 patternsAndPredicates.emplace_back(pattern, std::move(predicateList)); 368 } 369 370 // Associate a pattern result with each unique predicate. 371 DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued; 372 for (auto &patternAndPredList : patternsAndPredicates) { 373 for (auto &predicate : patternAndPredList.second) { 374 auto it = uniqued.insert(predicate); 375 it.first->patternToAnswer.try_emplace(patternAndPredList.first, 376 predicate.answer); 377 } 378 } 379 380 // Associate each pattern to a set of its ordered predicates for later lookup. 381 std::vector<OrderedPredicateList> lists; 382 lists.reserve(patternsAndPredicates.size()); 383 for (auto &patternAndPredList : patternsAndPredicates) { 384 OrderedPredicateList list(patternAndPredList.first); 385 for (auto &predicate : patternAndPredList.second) { 386 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); 387 list.predicates.insert(orderedPredicate); 388 389 // Increment the primary sum for each reference to a particular predicate. 390 ++orderedPredicate->primary; 391 } 392 lists.push_back(std::move(list)); 393 } 394 395 // For a particular pattern, get the total primary sum and add it to the 396 // secondary sum of each predicate. Square the primary sums to emphasize 397 // shared predicates within rather than across patterns. 398 for (auto &list : lists) { 399 unsigned total = 0; 400 for (auto *predicate : list.predicates) 401 total += predicate->primary * predicate->primary; 402 for (auto *predicate : list.predicates) 403 predicate->secondary += total; 404 } 405 406 // Sort the set of predicates now that the cost primary and secondary sums 407 // have been computed. 408 std::vector<OrderedPredicate *> ordered; 409 ordered.reserve(uniqued.size()); 410 for (auto &ip : uniqued) 411 ordered.push_back(&ip); 412 std::stable_sort( 413 ordered.begin(), ordered.end(), 414 [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; }); 415 416 // Build the matchers for each of the pattern predicate lists. 417 std::unique_ptr<MatcherNode> root; 418 for (OrderedPredicateList &list : lists) 419 propagatePattern(root, list, ordered.begin(), ordered.end()); 420 421 // Collapse the graph and insert the exit node. 422 foldSwitchToBool(root); 423 insertExitNode(&root); 424 return root; 425 } 426 427 //===----------------------------------------------------------------------===// 428 // MatcherNode 429 //===----------------------------------------------------------------------===// 430 431 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q, 432 std::unique_ptr<MatcherNode> failureNode) 433 : position(p), question(q), failureNode(std::move(failureNode)), 434 matcherTypeID(matcherTypeID) {} 435 436 //===----------------------------------------------------------------------===// 437 // BoolNode 438 //===----------------------------------------------------------------------===// 439 440 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer, 441 std::unique_ptr<MatcherNode> successNode, 442 std::unique_ptr<MatcherNode> failureNode) 443 : MatcherNode(TypeID::get<BoolNode>(), position, question, 444 std::move(failureNode)), 445 answer(answer), successNode(std::move(successNode)) {} 446 447 //===----------------------------------------------------------------------===// 448 // SuccessNode 449 //===----------------------------------------------------------------------===// 450 451 SuccessNode::SuccessNode(pdl::PatternOp pattern, 452 std::unique_ptr<MatcherNode> failureNode) 453 : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr, 454 /*question=*/nullptr, std::move(failureNode)), 455 pattern(pattern) {} 456 457 //===----------------------------------------------------------------------===// 458 // SwitchNode 459 //===----------------------------------------------------------------------===// 460 461 SwitchNode::SwitchNode(Position *position, Qualifier *question) 462 : MatcherNode(TypeID::get<SwitchNode>(), position, question) {} 463