1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PredicateTree.h" 10 #include "RootOrdering.h" 11 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/Interfaces/InferTypeOpInterface.h" 17 #include "llvm/ADT/MapVector.h" 18 #include "llvm/ADT/TypeSwitch.h" 19 #include "llvm/Support/Debug.h" 20 #include <queue> 21 22 #define DEBUG_TYPE "pdl-predicate-tree" 23 24 using namespace mlir; 25 using namespace mlir::pdl_to_pdl_interp; 26 27 //===----------------------------------------------------------------------===// 28 // Predicate List Building 29 //===----------------------------------------------------------------------===// 30 31 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 32 Value val, PredicateBuilder &builder, 33 DenseMap<Value, Position *> &inputs, 34 Position *pos); 35 36 /// Compares the depths of two positions. 37 static bool comparePosDepth(Position *lhs, Position *rhs) { 38 return lhs->getOperationDepth() < rhs->getOperationDepth(); 39 } 40 41 /// Returns the number of non-range elements within `values`. 42 static unsigned getNumNonRangeValues(ValueRange values) { 43 return llvm::count_if(values.getTypes(), 44 [](Type type) { return !type.isa<pdl::RangeType>(); }); 45 } 46 47 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 48 Value val, PredicateBuilder &builder, 49 DenseMap<Value, Position *> &inputs, 50 AttributePosition *pos) { 51 assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type"); 52 pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp()); 53 predList.emplace_back(pos, builder.getIsNotNull()); 54 55 // If the attribute has a type or value, add a constraint. 56 if (Value type = attr.type()) 57 getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); 58 else if (Attribute value = attr.valueAttr()) 59 predList.emplace_back(pos, builder.getAttributeConstraint(value)); 60 } 61 62 /// Collect all of the predicates for the given operand position. 63 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList, 64 Value val, PredicateBuilder &builder, 65 DenseMap<Value, Position *> &inputs, 66 Position *pos) { 67 Type valueType = val.getType(); 68 bool isVariadic = valueType.isa<pdl::RangeType>(); 69 70 // If this is a typed operand, add a type constraint. 71 TypeSwitch<Operation *>(val.getDefiningOp()) 72 .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) { 73 // Prevent traversal into a null value if the operand has a proper 74 // index. 75 if (std::is_same<pdl::OperandOp, decltype(op)>::value || 76 cast<OperandGroupPosition>(pos)->getOperandGroupNumber()) 77 predList.emplace_back(pos, builder.getIsNotNull()); 78 79 if (Value type = op.type()) 80 getTreePredicates(predList, type, builder, inputs, 81 builder.getType(pos)); 82 }) 83 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) { 84 Optional<unsigned> index = op.index(); 85 86 // Prevent traversal into a null value if the result has a proper index. 87 if (index) 88 predList.emplace_back(pos, builder.getIsNotNull()); 89 90 // Get the parent operation of this operand. 91 OperationPosition *parentPos = builder.getOperandDefiningOp(pos); 92 predList.emplace_back(parentPos, builder.getIsNotNull()); 93 94 // Ensure that the operands match the corresponding results of the 95 // parent operation. 96 Position *resultPos = nullptr; 97 if (std::is_same<pdl::ResultOp, decltype(op)>::value) 98 resultPos = builder.getResult(parentPos, *index); 99 else 100 resultPos = builder.getResultGroup(parentPos, index, isVariadic); 101 predList.emplace_back(resultPos, builder.getEqualTo(pos)); 102 103 // Collect the predicates of the parent operation. 104 getTreePredicates(predList, op.parent(), builder, inputs, 105 (Position *)parentPos); 106 }); 107 } 108 109 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 110 Value val, PredicateBuilder &builder, 111 DenseMap<Value, Position *> &inputs, 112 OperationPosition *pos, 113 Optional<unsigned> ignoreOperand = llvm::None) { 114 assert(val.getType().isa<pdl::OperationType>() && "expected operation"); 115 pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp()); 116 OperationPosition *opPos = cast<OperationPosition>(pos); 117 118 // Ensure getDefiningOp returns a non-null operation. 119 if (!opPos->isRoot()) 120 predList.emplace_back(pos, builder.getIsNotNull()); 121 122 // Check that this is the correct root operation. 123 if (Optional<StringRef> opName = op.name()) 124 predList.emplace_back(pos, builder.getOperationName(*opName)); 125 126 // Check that the operation has the proper number of operands. If there are 127 // any variable length operands, we check a minimum instead of an exact count. 128 OperandRange operands = op.operands(); 129 unsigned minOperands = getNumNonRangeValues(operands); 130 if (minOperands != operands.size()) { 131 if (minOperands) 132 predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands)); 133 } else { 134 predList.emplace_back(pos, builder.getOperandCount(minOperands)); 135 } 136 137 // Check that the operation has the proper number of results. If there are 138 // any variable length results, we check a minimum instead of an exact count. 139 OperandRange types = op.types(); 140 unsigned minResults = getNumNonRangeValues(types); 141 if (minResults == types.size()) 142 predList.emplace_back(pos, builder.getResultCount(types.size())); 143 else if (minResults) 144 predList.emplace_back(pos, builder.getResultCountAtLeast(minResults)); 145 146 // Recurse into any attributes, operands, or results. 147 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 148 getTreePredicates( 149 predList, std::get<1>(it), builder, inputs, 150 builder.getAttribute(opPos, 151 std::get<0>(it).cast<StringAttr>().getValue())); 152 } 153 154 // Process the operands and results of the operation. For all values up to 155 // the first variable length value, we use the concrete operand/result 156 // number. After that, we use the "group" given that we can't know the 157 // concrete indices until runtime. If there is only one variadic operand 158 // group, we treat it as all of the operands/results of the operation. 159 /// Operands. 160 if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) { 161 getTreePredicates(predList, operands.front(), builder, inputs, 162 builder.getAllOperands(opPos)); 163 } else { 164 bool foundVariableLength = false; 165 for (auto operandIt : llvm::enumerate(operands)) { 166 bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>(); 167 foundVariableLength |= isVariadic; 168 169 // Ignore the specified operand, usually because this position was 170 // visited in an upward traversal via an iterative choice. 171 if (ignoreOperand && *ignoreOperand == operandIt.index()) 172 continue; 173 174 Position *pos = 175 foundVariableLength 176 ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic) 177 : builder.getOperand(opPos, operandIt.index()); 178 getTreePredicates(predList, operandIt.value(), builder, inputs, pos); 179 } 180 } 181 /// Results. 182 if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) { 183 getTreePredicates(predList, types.front(), builder, inputs, 184 builder.getType(builder.getAllResults(opPos))); 185 } else { 186 bool foundVariableLength = false; 187 for (auto &resultIt : llvm::enumerate(types)) { 188 bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>(); 189 foundVariableLength |= isVariadic; 190 191 auto *resultPos = 192 foundVariableLength 193 ? builder.getResultGroup(pos, resultIt.index(), isVariadic) 194 : builder.getResult(pos, resultIt.index()); 195 predList.emplace_back(resultPos, builder.getIsNotNull()); 196 getTreePredicates(predList, resultIt.value(), builder, inputs, 197 builder.getType(resultPos)); 198 } 199 } 200 } 201 202 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 203 Value val, PredicateBuilder &builder, 204 DenseMap<Value, Position *> &inputs, 205 TypePosition *pos) { 206 // Check for a constraint on a constant type. 207 if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) { 208 if (Attribute type = typeOp.typeAttr()) 209 predList.emplace_back(pos, builder.getTypeConstraint(type)); 210 } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) { 211 if (Attribute typeAttr = typeOp.typesAttr()) 212 predList.emplace_back(pos, builder.getTypeConstraint(typeAttr)); 213 } 214 } 215 216 /// Collect the tree predicates anchored at the given value. 217 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 218 Value val, PredicateBuilder &builder, 219 DenseMap<Value, Position *> &inputs, 220 Position *pos) { 221 // Make sure this input value is accessible to the rewrite. 222 auto it = inputs.try_emplace(val, pos); 223 if (!it.second) { 224 // If this is an input value that has been visited in the tree, add a 225 // constraint to ensure that both instances refer to the same value. 226 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp, 227 pdl::TypeOp>(val.getDefiningOp())) { 228 auto minMaxPositions = 229 std::minmax(pos, it.first->second, comparePosDepth); 230 predList.emplace_back(minMaxPositions.second, 231 builder.getEqualTo(minMaxPositions.first)); 232 } 233 return; 234 } 235 236 TypeSwitch<Position *>(pos) 237 .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) { 238 getTreePredicates(predList, val, builder, inputs, pos); 239 }) 240 .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) { 241 getOperandTreePredicates(predList, val, builder, inputs, pos); 242 }) 243 .Default([](auto *) { llvm_unreachable("unexpected position kind"); }); 244 } 245 246 /// Collect all of the predicates related to constraints within the given 247 /// pattern operation. 248 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, 249 std::vector<PositionalPredicate> &predList, 250 PredicateBuilder &builder, 251 DenseMap<Value, Position *> &inputs) { 252 OperandRange arguments = op.args(); 253 ArrayAttr parameters = op.constParamsAttr(); 254 255 std::vector<Position *> allPositions; 256 allPositions.reserve(arguments.size()); 257 for (Value arg : arguments) 258 allPositions.push_back(inputs.lookup(arg)); 259 260 // Push the constraint to the furthest position. 261 Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), 262 comparePosDepth); 263 PredicateBuilder::Predicate pred = 264 builder.getConstraint(op.name(), std::move(allPositions), parameters); 265 predList.emplace_back(pos, pred); 266 } 267 268 static void getResultPredicates(pdl::ResultOp op, 269 std::vector<PositionalPredicate> &predList, 270 PredicateBuilder &builder, 271 DenseMap<Value, Position *> &inputs) { 272 Position *&resultPos = inputs[op]; 273 if (resultPos) 274 return; 275 276 // Ensure that the result isn't null. 277 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent())); 278 resultPos = builder.getResult(parentPos, op.index()); 279 predList.emplace_back(resultPos, builder.getIsNotNull()); 280 } 281 282 static void getResultPredicates(pdl::ResultsOp op, 283 std::vector<PositionalPredicate> &predList, 284 PredicateBuilder &builder, 285 DenseMap<Value, Position *> &inputs) { 286 Position *&resultPos = inputs[op]; 287 if (resultPos) 288 return; 289 290 // Ensure that the result isn't null if the result has an index. 291 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent())); 292 bool isVariadic = op.getType().isa<pdl::RangeType>(); 293 Optional<unsigned> index = op.index(); 294 resultPos = builder.getResultGroup(parentPos, index, isVariadic); 295 if (index) 296 predList.emplace_back(resultPos, builder.getIsNotNull()); 297 } 298 299 /// Collect all of the predicates that cannot be determined via walking the 300 /// tree. 301 static void getNonTreePredicates(pdl::PatternOp pattern, 302 std::vector<PositionalPredicate> &predList, 303 PredicateBuilder &builder, 304 DenseMap<Value, Position *> &inputs) { 305 for (Operation &op : pattern.body().getOps()) { 306 TypeSwitch<Operation *>(&op) 307 .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) { 308 getConstraintPredicates(constraintOp, predList, builder, inputs); 309 }) 310 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) { 311 getResultPredicates(resultOp, predList, builder, inputs); 312 }); 313 } 314 } 315 316 namespace { 317 318 /// An op accepting a value at an optional index. 319 struct OpIndex { 320 Value parent; 321 Optional<unsigned> index; 322 }; 323 324 /// The parent and operand index of each operation for each root, stored 325 /// as a nested map [root][operation]. 326 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>; 327 328 } // namespace 329 330 /// Given a pattern, determines the set of roots present in this pattern. 331 /// These are the operations whose results are not consumed by other operations. 332 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) { 333 // First, collect all the operations that are used as operands 334 // to other operations. These are not roots by default. 335 DenseSet<Value> used; 336 for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) { 337 for (Value operand : operationOp.operands()) 338 TypeSwitch<Operation *>(operand.getDefiningOp()) 339 .Case<pdl::ResultOp, pdl::ResultsOp>( 340 [&used](auto resultOp) { used.insert(resultOp.parent()); }); 341 } 342 343 // Remove the specified root from the use set, so that we can 344 // always select it as a root, even if it is used by other operations. 345 if (Value root = pattern.getRewriter().root()) 346 used.erase(root); 347 348 // Finally, collect all the unused operations. 349 SmallVector<Value> roots; 350 for (Value operationOp : pattern.body().getOps<pdl::OperationOp>()) 351 if (!used.contains(operationOp)) 352 roots.push_back(operationOp); 353 354 return roots; 355 } 356 357 /// Given a list of candidate roots, builds the cost graph for connecting them. 358 /// The graph is formed by traversing the DAG of operations starting from each 359 /// root and marking the depth of each connector value (operand). Then we join 360 /// the candidate roots based on the common connector values, taking the one 361 /// with the minimum depth. Along the way, we compute, for each candidate root, 362 /// a mapping from each operation (in the DAG underneath this root) to its 363 /// parent operation and the corresponding operand index. 364 static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph, 365 ParentMaps &parentMaps) { 366 367 // The entry of a queue. The entry consists of the following items: 368 // * the value in the DAG underneath the root; 369 // * the parent of the value; 370 // * the operand index of the value in its parent; 371 // * the depth of the visited value. 372 struct Entry { 373 Entry(Value value, Value parent, Optional<unsigned> index, unsigned depth) 374 : value(value), parent(parent), index(index), depth(depth) {} 375 376 Value value; 377 Value parent; 378 Optional<unsigned> index; 379 unsigned depth; 380 }; 381 382 // A root of a value and its depth (distance from root to the value). 383 struct RootDepth { 384 Value root; 385 unsigned depth = 0; 386 }; 387 388 // Map from candidate connector values to their roots and depths. Using a 389 // small vector with 1 entry because most values belong to a single root. 390 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths; 391 392 // Perform a breadth-first traversal of the op DAG rooted at each root. 393 for (Value root : roots) { 394 // The queue of visited values. A value may be present multiple times in 395 // the queue, for multiple parents. We only accept the first occurrence, 396 // which is guaranteed to have the lowest depth. 397 std::queue<Entry> toVisit; 398 toVisit.emplace(root, Value(), 0, 0); 399 400 // The map from value to its parent for the current root. 401 DenseMap<Value, OpIndex> &parentMap = parentMaps[root]; 402 403 while (!toVisit.empty()) { 404 Entry entry = toVisit.front(); 405 toVisit.pop(); 406 // Skip if already visited. 407 if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second) 408 continue; 409 410 // Mark the root and depth of the value. 411 connectorsRootsDepths[entry.value].push_back({root, entry.depth}); 412 413 // Traverse the operands of an operation and result ops. 414 // We intentionally do not traverse attributes and types, because those 415 // are expensive to join on. 416 TypeSwitch<Operation *>(entry.value.getDefiningOp()) 417 .Case<pdl::OperationOp>([&](auto operationOp) { 418 OperandRange operands = operationOp.operands(); 419 // Special case when we pass all the operands in one range. 420 // For those, the index is empty. 421 if (operands.size() == 1 && 422 operands[0].getType().isa<pdl::RangeType>()) { 423 toVisit.emplace(operands[0], entry.value, llvm::None, 424 entry.depth + 1); 425 return; 426 } 427 428 // Default case: visit all the operands. 429 for (auto p : llvm::enumerate(operationOp.operands())) 430 toVisit.emplace(p.value(), entry.value, p.index(), 431 entry.depth + 1); 432 }) 433 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) { 434 toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(), 435 entry.depth); 436 }); 437 } 438 } 439 440 // Now build the cost graph. 441 // This is simply a minimum over all depths for the target root. 442 unsigned nextID = 0; 443 for (const auto &connectorRootsDepths : connectorsRootsDepths) { 444 Value value = connectorRootsDepths.first; 445 ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second; 446 // If there is only one root for this value, this will not trigger 447 // any edges in the cost graph (a perf optimization). 448 if (rootsDepths.size() == 1) 449 continue; 450 451 for (const RootDepth &p : rootsDepths) { 452 for (const RootDepth &q : rootsDepths) { 453 if (&p == &q) 454 continue; 455 // Insert or retrieve the property of edge from p to q. 456 RootOrderingCost &cost = graph[q.root][p.root]; 457 if (!cost.connector /* new edge */ || cost.cost.first > q.depth) { 458 if (!cost.connector) 459 cost.cost.second = nextID++; 460 cost.cost.first = q.depth; 461 cost.connector = value; 462 } 463 } 464 } 465 } 466 467 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) && 468 "the pattern contains a candidate root disconnected from the others"); 469 } 470 471 /// Visit a node during upward traversal. 472 void visitUpward(std::vector<PositionalPredicate> &predList, OpIndex opIndex, 473 PredicateBuilder &builder, 474 DenseMap<Value, Position *> &valueToPosition, Position *&pos, 475 bool &first) { 476 Value value = opIndex.parent; 477 TypeSwitch<Operation *>(value.getDefiningOp()) 478 .Case<pdl::OperationOp>([&](auto operationOp) { 479 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 480 OperationPosition *opPos = builder.getUsersOp(pos, opIndex.index); 481 482 // Guard against traversing back to where we came from. 483 if (first) { 484 Position *parent = pos->getParent(); 485 predList.emplace_back(opPos, builder.getNotEqualTo(parent)); 486 first = false; 487 } 488 489 // Guard against duplicate upward visits. These are not possible, 490 // because if this value was already visited, it would have been 491 // cheaper to start the traversal at this value rather than at the 492 // `connector`, violating the optimality of our spanning tree. 493 bool inserted = valueToPosition.try_emplace(value, opPos).second; 494 (void)inserted; 495 assert(inserted && "duplicate upward visit"); 496 497 // Obtain the tree predicates at the current value. 498 getTreePredicates(predList, value, builder, valueToPosition, opPos, 499 opIndex.index); 500 501 // Update the position 502 pos = opPos; 503 }) 504 .Case<pdl::ResultOp>([&](auto resultOp) { 505 // Traverse up an individual result. 506 auto *opPos = dyn_cast<OperationPosition>(pos); 507 assert(opPos && "operations and results must be interleaved"); 508 pos = builder.getResult(opPos, *opIndex.index); 509 }) 510 .Case<pdl::ResultsOp>([&](auto resultOp) { 511 // Traverse up a group of results. 512 auto *opPos = dyn_cast<OperationPosition>(pos); 513 assert(opPos && "operations and results must be interleaved"); 514 bool isVariadic = value.getType().isa<pdl::RangeType>(); 515 if (opIndex.index) 516 pos = builder.getResultGroup(opPos, opIndex.index, isVariadic); 517 else 518 pos = builder.getAllResults(opPos); 519 }); 520 } 521 522 /// Given a pattern operation, build the set of matcher predicates necessary to 523 /// match this pattern. 524 static Value buildPredicateList(pdl::PatternOp pattern, 525 PredicateBuilder &builder, 526 std::vector<PositionalPredicate> &predList, 527 DenseMap<Value, Position *> &valueToPosition) { 528 SmallVector<Value> roots = detectRoots(pattern); 529 530 // Build the root ordering graph and compute the parent maps. 531 RootOrderingGraph graph; 532 ParentMaps parentMaps; 533 buildCostGraph(roots, graph, parentMaps); 534 LLVM_DEBUG({ 535 llvm::dbgs() << "Graph:\n"; 536 for (auto &target : graph) { 537 llvm::dbgs() << " * " << target.first << "\n"; 538 for (auto &source : target.second) { 539 RootOrderingCost c = source.second; 540 llvm::dbgs() << " <- " << source.first << ": " << c.cost.first 541 << ":" << c.cost.second << " via " << c.connector.getLoc() 542 << "\n"; 543 } 544 } 545 }); 546 547 // Solve the optimal branching problem for each candidate root, or use the 548 // provided one. 549 Value bestRoot = pattern.getRewriter().root(); 550 OptimalBranching::EdgeList bestEdges; 551 if (!bestRoot) { 552 unsigned bestCost = 0; 553 LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n"); 554 for (Value root : roots) { 555 OptimalBranching solver(graph, root); 556 unsigned cost = solver.solve(); 557 LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n"); 558 if (!bestRoot || bestCost > cost) { 559 bestCost = cost; 560 bestRoot = root; 561 bestEdges = solver.preOrderTraversal(roots); 562 } 563 } 564 } else { 565 OptimalBranching solver(graph, bestRoot); 566 solver.solve(); 567 bestEdges = solver.preOrderTraversal(roots); 568 } 569 570 LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n"); 571 LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n"); 572 573 // The best root is the starting point for the traversal. Get the tree 574 // predicates for the DAG rooted at bestRoot. 575 getTreePredicates(predList, bestRoot, builder, valueToPosition, 576 builder.getRoot()); 577 578 // Traverse the selected optimal branching. For all edges in order, traverse 579 // up starting from the connector, until the candidate root is reached, and 580 // call getTreePredicates at every node along the way. 581 for (const std::pair<Value, Value> &edge : bestEdges) { 582 Value target = edge.first; 583 Value source = edge.second; 584 585 // Check if we already visited the target root. This happens in two cases: 586 // 1) the initial root (bestRoot); 587 // 2) a root that is dominated by (contained in the subtree rooted at) an 588 // already visited root. 589 if (valueToPosition.count(target)) 590 continue; 591 592 // Determine the connector. 593 Value connector = graph[target][source].connector; 594 assert(connector && "invalid edge"); 595 LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n"); 596 DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target); 597 Position *pos = valueToPosition.lookup(connector); 598 assert(pos && "The value has not been traversed yet"); 599 bool first = true; 600 601 // Traverse from the connector upwards towards the target root. 602 for (Value value = connector; value != target;) { 603 OpIndex opIndex = parentMap.lookup(value); 604 assert(opIndex.parent && "missing parent"); 605 visitUpward(predList, opIndex, builder, valueToPosition, pos, first); 606 value = opIndex.parent; 607 } 608 } 609 610 getNonTreePredicates(pattern, predList, builder, valueToPosition); 611 612 return bestRoot; 613 } 614 615 //===----------------------------------------------------------------------===// 616 // Pattern Predicate Tree Merging 617 //===----------------------------------------------------------------------===// 618 619 namespace { 620 621 /// This class represents a specific predicate applied to a position, and 622 /// provides hashing and ordering operators. This class allows for computing a 623 /// frequence sum and ordering predicates based on a cost model. 624 struct OrderedPredicate { 625 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip) 626 : position(ip.first), question(ip.second) {} 627 OrderedPredicate(const PositionalPredicate &ip) 628 : position(ip.position), question(ip.question) {} 629 630 /// The position this predicate is applied to. 631 Position *position; 632 633 /// The question that is applied by this predicate onto the position. 634 Qualifier *question; 635 636 /// The first and second order benefit sums. 637 /// The primary sum is the number of occurrences of this predicate among all 638 /// of the patterns. 639 unsigned primary = 0; 640 /// The secondary sum is a squared summation of the primary sum of all of the 641 /// predicates within each pattern that contains this predicate. This allows 642 /// for favoring predicates that are more commonly shared within a pattern, as 643 /// opposed to those shared across patterns. 644 unsigned secondary = 0; 645 646 /// A map between a pattern operation and the answer to the predicate question 647 /// within that pattern. 648 DenseMap<Operation *, Qualifier *> patternToAnswer; 649 650 /// Returns true if this predicate is ordered before `rhs`, based on the cost 651 /// model. 652 bool operator<(const OrderedPredicate &rhs) const { 653 // Sort by: 654 // * higher first and secondary order sums 655 // * lower depth 656 // * lower position dependency 657 // * lower predicate dependency 658 auto *rhsPos = rhs.position; 659 return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(), 660 rhsPos->getKind(), rhs.question->getKind()) > 661 std::make_tuple(rhs.primary, rhs.secondary, 662 position->getOperationDepth(), position->getKind(), 663 question->getKind()); 664 } 665 }; 666 667 /// A DenseMapInfo for OrderedPredicate based solely on the position and 668 /// question. 669 struct OrderedPredicateDenseInfo { 670 using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>; 671 672 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); } 673 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); } 674 static bool isEqual(const OrderedPredicate &lhs, 675 const OrderedPredicate &rhs) { 676 return lhs.position == rhs.position && lhs.question == rhs.question; 677 } 678 static unsigned getHashValue(const OrderedPredicate &p) { 679 return llvm::hash_combine(p.position, p.question); 680 } 681 }; 682 683 /// This class wraps a set of ordered predicates that are used within a specific 684 /// pattern operation. 685 struct OrderedPredicateList { 686 OrderedPredicateList(pdl::PatternOp pattern, Value root) 687 : pattern(pattern), root(root) {} 688 689 pdl::PatternOp pattern; 690 Value root; 691 DenseSet<OrderedPredicate *> predicates; 692 }; 693 } // namespace 694 695 /// Returns true if the given matcher refers to the same predicate as the given 696 /// ordered predicate. This means that the position and questions of the two 697 /// match. 698 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { 699 return node->getPosition() == predicate->position && 700 node->getQuestion() == predicate->question; 701 } 702 703 /// Get or insert a child matcher for the given parent switch node, given a 704 /// predicate and parent pattern. 705 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, 706 OrderedPredicate *predicate, 707 pdl::PatternOp pattern) { 708 assert(isSamePredicate(node, predicate) && 709 "expected matcher to equal the given predicate"); 710 711 auto it = predicate->patternToAnswer.find(pattern); 712 assert(it != predicate->patternToAnswer.end() && 713 "expected pattern to exist in predicate"); 714 return node->getChildren().insert({it->second, nullptr}).first->second; 715 } 716 717 /// Build the matcher CFG by "pushing" patterns through by sorted predicate 718 /// order. A pattern will traverse as far as possible using common predicates 719 /// and then either diverge from the CFG or reach the end of a branch and start 720 /// creating new nodes. 721 static void propagatePattern(std::unique_ptr<MatcherNode> &node, 722 OrderedPredicateList &list, 723 std::vector<OrderedPredicate *>::iterator current, 724 std::vector<OrderedPredicate *>::iterator end) { 725 if (current == end) { 726 // We've hit the end of a pattern, so create a successful result node. 727 node = 728 std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node)); 729 730 // If the pattern doesn't contain this predicate, ignore it. 731 } else if (list.predicates.find(*current) == list.predicates.end()) { 732 propagatePattern(node, list, std::next(current), end); 733 734 // If the current matcher node is invalid, create a new one for this 735 // position and continue propagation. 736 } else if (!node) { 737 // Create a new node at this position and continue 738 node = std::make_unique<SwitchNode>((*current)->position, 739 (*current)->question); 740 propagatePattern( 741 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 742 list, std::next(current), end); 743 744 // If the matcher has already been created, and it is for this predicate we 745 // continue propagation to the child. 746 } else if (isSamePredicate(node.get(), *current)) { 747 propagatePattern( 748 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 749 list, std::next(current), end); 750 751 // If the matcher doesn't match the current predicate, insert a branch as 752 // the common set of matchers has diverged. 753 } else { 754 propagatePattern(node->getFailureNode(), list, current, end); 755 } 756 } 757 758 /// Fold any switch nodes nested under `node` to boolean nodes when possible. 759 /// `node` is updated in-place if it is a switch. 760 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) { 761 if (!node) 762 return; 763 764 if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) { 765 SwitchNode::ChildMapT &children = switchNode->getChildren(); 766 for (auto &it : children) 767 foldSwitchToBool(it.second); 768 769 // If the node only contains one child, collapse it into a boolean predicate 770 // node. 771 if (children.size() == 1) { 772 auto childIt = children.begin(); 773 node = std::make_unique<BoolNode>( 774 node->getPosition(), node->getQuestion(), childIt->first, 775 std::move(childIt->second), std::move(node->getFailureNode())); 776 } 777 } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) { 778 foldSwitchToBool(boolNode->getSuccessNode()); 779 } 780 781 foldSwitchToBool(node->getFailureNode()); 782 } 783 784 /// Insert an exit node at the end of the failure path of the `root`. 785 static void insertExitNode(std::unique_ptr<MatcherNode> *root) { 786 while (*root) 787 root = &(*root)->getFailureNode(); 788 *root = std::make_unique<ExitNode>(); 789 } 790 791 /// Given a module containing PDL pattern operations, generate a matcher tree 792 /// using the patterns within the given module and return the root matcher node. 793 std::unique_ptr<MatcherNode> 794 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 795 DenseMap<Value, Position *> &valueToPosition) { 796 // The set of predicates contained within the pattern operations of the 797 // module. 798 struct PatternPredicates { 799 PatternPredicates(pdl::PatternOp pattern, Value root, 800 std::vector<PositionalPredicate> predicates) 801 : pattern(pattern), root(root), predicates(std::move(predicates)) {} 802 803 /// A pattern. 804 pdl::PatternOp pattern; 805 806 /// A root of the pattern chosen among the candidate roots in pdl.rewrite. 807 Value root; 808 809 /// The extracted predicates for this pattern and root. 810 std::vector<PositionalPredicate> predicates; 811 }; 812 813 SmallVector<PatternPredicates, 16> patternsAndPredicates; 814 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { 815 std::vector<PositionalPredicate> predicateList; 816 Value root = 817 buildPredicateList(pattern, builder, predicateList, valueToPosition); 818 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList)); 819 } 820 821 // Associate a pattern result with each unique predicate. 822 DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued; 823 for (auto &patternAndPredList : patternsAndPredicates) { 824 for (auto &predicate : patternAndPredList.predicates) { 825 auto it = uniqued.insert(predicate); 826 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern, 827 predicate.answer); 828 } 829 } 830 831 // Associate each pattern to a set of its ordered predicates for later lookup. 832 std::vector<OrderedPredicateList> lists; 833 lists.reserve(patternsAndPredicates.size()); 834 for (auto &patternAndPredList : patternsAndPredicates) { 835 OrderedPredicateList list(patternAndPredList.pattern, 836 patternAndPredList.root); 837 for (auto &predicate : patternAndPredList.predicates) { 838 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); 839 list.predicates.insert(orderedPredicate); 840 841 // Increment the primary sum for each reference to a particular predicate. 842 ++orderedPredicate->primary; 843 } 844 lists.push_back(std::move(list)); 845 } 846 847 // For a particular pattern, get the total primary sum and add it to the 848 // secondary sum of each predicate. Square the primary sums to emphasize 849 // shared predicates within rather than across patterns. 850 for (auto &list : lists) { 851 unsigned total = 0; 852 for (auto *predicate : list.predicates) 853 total += predicate->primary * predicate->primary; 854 for (auto *predicate : list.predicates) 855 predicate->secondary += total; 856 } 857 858 // Sort the set of predicates now that the cost primary and secondary sums 859 // have been computed. 860 std::vector<OrderedPredicate *> ordered; 861 ordered.reserve(uniqued.size()); 862 for (auto &ip : uniqued) 863 ordered.push_back(&ip); 864 std::stable_sort( 865 ordered.begin(), ordered.end(), 866 [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; }); 867 868 // Build the matchers for each of the pattern predicate lists. 869 std::unique_ptr<MatcherNode> root; 870 for (OrderedPredicateList &list : lists) 871 propagatePattern(root, list, ordered.begin(), ordered.end()); 872 873 // Collapse the graph and insert the exit node. 874 foldSwitchToBool(root); 875 insertExitNode(&root); 876 return root; 877 } 878 879 //===----------------------------------------------------------------------===// 880 // MatcherNode 881 //===----------------------------------------------------------------------===// 882 883 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q, 884 std::unique_ptr<MatcherNode> failureNode) 885 : position(p), question(q), failureNode(std::move(failureNode)), 886 matcherTypeID(matcherTypeID) {} 887 888 //===----------------------------------------------------------------------===// 889 // BoolNode 890 //===----------------------------------------------------------------------===// 891 892 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer, 893 std::unique_ptr<MatcherNode> successNode, 894 std::unique_ptr<MatcherNode> failureNode) 895 : MatcherNode(TypeID::get<BoolNode>(), position, question, 896 std::move(failureNode)), 897 answer(answer), successNode(std::move(successNode)) {} 898 899 //===----------------------------------------------------------------------===// 900 // SuccessNode 901 //===----------------------------------------------------------------------===// 902 903 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root, 904 std::unique_ptr<MatcherNode> failureNode) 905 : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr, 906 /*question=*/nullptr, std::move(failureNode)), 907 pattern(pattern), root(root) {} 908 909 //===----------------------------------------------------------------------===// 910 // SwitchNode 911 //===----------------------------------------------------------------------===// 912 913 SwitchNode::SwitchNode(Position *position, Qualifier *question) 914 : MatcherNode(TypeID::get<SwitchNode>(), position, question) {} 915