1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PredicateTree.h" 10 #include "RootOrdering.h" 11 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/Interfaces/InferTypeOpInterface.h" 17 #include "llvm/ADT/MapVector.h" 18 #include "llvm/ADT/TypeSwitch.h" 19 #include "llvm/Support/Debug.h" 20 #include <queue> 21 22 #define DEBUG_TYPE "pdl-predicate-tree" 23 24 using namespace mlir; 25 using namespace mlir::pdl_to_pdl_interp; 26 27 //===----------------------------------------------------------------------===// 28 // Predicate List Building 29 //===----------------------------------------------------------------------===// 30 31 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 32 Value val, PredicateBuilder &builder, 33 DenseMap<Value, Position *> &inputs, 34 Position *pos); 35 36 /// Compares the depths of two positions. 37 static bool comparePosDepth(Position *lhs, Position *rhs) { 38 return lhs->getOperationDepth() < rhs->getOperationDepth(); 39 } 40 41 /// Returns the number of non-range elements within `values`. 42 static unsigned getNumNonRangeValues(ValueRange values) { 43 return llvm::count_if(values.getTypes(), 44 [](Type type) { return !type.isa<pdl::RangeType>(); }); 45 } 46 47 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 48 Value val, PredicateBuilder &builder, 49 DenseMap<Value, Position *> &inputs, 50 AttributePosition *pos) { 51 assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type"); 52 pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp()); 53 predList.emplace_back(pos, builder.getIsNotNull()); 54 55 // If the attribute has a type or value, add a constraint. 56 if (Value type = attr.getValueType()) 57 getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); 58 else if (Attribute value = attr.getValueAttr()) 59 predList.emplace_back(pos, builder.getAttributeConstraint(value)); 60 } 61 62 /// Collect all of the predicates for the given operand position. 63 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList, 64 Value val, PredicateBuilder &builder, 65 DenseMap<Value, Position *> &inputs, 66 Position *pos) { 67 Type valueType = val.getType(); 68 bool isVariadic = valueType.isa<pdl::RangeType>(); 69 70 // If this is a typed operand, add a type constraint. 71 TypeSwitch<Operation *>(val.getDefiningOp()) 72 .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) { 73 // Prevent traversal into a null value if the operand has a proper 74 // index. 75 if (std::is_same<pdl::OperandOp, decltype(op)>::value || 76 cast<OperandGroupPosition>(pos)->getOperandGroupNumber()) 77 predList.emplace_back(pos, builder.getIsNotNull()); 78 79 if (Value type = op.getValueType()) 80 getTreePredicates(predList, type, builder, inputs, 81 builder.getType(pos)); 82 }) 83 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) { 84 Optional<unsigned> index = op.getIndex(); 85 86 // Prevent traversal into a null value if the result has a proper index. 87 if (index) 88 predList.emplace_back(pos, builder.getIsNotNull()); 89 90 // Get the parent operation of this operand. 91 OperationPosition *parentPos = builder.getOperandDefiningOp(pos); 92 predList.emplace_back(parentPos, builder.getIsNotNull()); 93 94 // Ensure that the operands match the corresponding results of the 95 // parent operation. 96 Position *resultPos = nullptr; 97 if (std::is_same<pdl::ResultOp, decltype(op)>::value) 98 resultPos = builder.getResult(parentPos, *index); 99 else 100 resultPos = builder.getResultGroup(parentPos, index, isVariadic); 101 predList.emplace_back(resultPos, builder.getEqualTo(pos)); 102 103 // Collect the predicates of the parent operation. 104 getTreePredicates(predList, op.getParent(), builder, inputs, 105 (Position *)parentPos); 106 }); 107 } 108 109 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 110 Value val, PredicateBuilder &builder, 111 DenseMap<Value, Position *> &inputs, 112 OperationPosition *pos, 113 Optional<unsigned> ignoreOperand = llvm::None) { 114 assert(val.getType().isa<pdl::OperationType>() && "expected operation"); 115 pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp()); 116 OperationPosition *opPos = cast<OperationPosition>(pos); 117 118 // Ensure getDefiningOp returns a non-null operation. 119 if (!opPos->isRoot()) 120 predList.emplace_back(pos, builder.getIsNotNull()); 121 122 // Check that this is the correct root operation. 123 if (Optional<StringRef> opName = op.getOpName()) 124 predList.emplace_back(pos, builder.getOperationName(*opName)); 125 126 // Check that the operation has the proper number of operands. If there are 127 // any variable length operands, we check a minimum instead of an exact count. 128 OperandRange operands = op.getOperandValues(); 129 unsigned minOperands = getNumNonRangeValues(operands); 130 if (minOperands != operands.size()) { 131 if (minOperands) 132 predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands)); 133 } else { 134 predList.emplace_back(pos, builder.getOperandCount(minOperands)); 135 } 136 137 // Check that the operation has the proper number of results. If there are 138 // any variable length results, we check a minimum instead of an exact count. 139 OperandRange types = op.getTypeValues(); 140 unsigned minResults = getNumNonRangeValues(types); 141 if (minResults == types.size()) 142 predList.emplace_back(pos, builder.getResultCount(types.size())); 143 else if (minResults) 144 predList.emplace_back(pos, builder.getResultCountAtLeast(minResults)); 145 146 // Recurse into any attributes, operands, or results. 147 for (auto [attrName, attr] : 148 llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) { 149 getTreePredicates( 150 predList, attr, builder, inputs, 151 builder.getAttribute(opPos, attrName.cast<StringAttr>().getValue())); 152 } 153 154 // Process the operands and results of the operation. For all values up to 155 // the first variable length value, we use the concrete operand/result 156 // number. After that, we use the "group" given that we can't know the 157 // concrete indices until runtime. If there is only one variadic operand 158 // group, we treat it as all of the operands/results of the operation. 159 /// Operands. 160 if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) { 161 // Ignore the operands if we are performing an upward traversal (in that 162 // case, they have already been visited). 163 if (opPos->isRoot() || opPos->isOperandDefiningOp()) 164 getTreePredicates(predList, operands.front(), builder, inputs, 165 builder.getAllOperands(opPos)); 166 } else { 167 bool foundVariableLength = false; 168 for (const auto &operandIt : llvm::enumerate(operands)) { 169 bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>(); 170 foundVariableLength |= isVariadic; 171 172 // Ignore the specified operand, usually because this position was 173 // visited in an upward traversal via an iterative choice. 174 if (ignoreOperand && *ignoreOperand == operandIt.index()) 175 continue; 176 177 Position *pos = 178 foundVariableLength 179 ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic) 180 : builder.getOperand(opPos, operandIt.index()); 181 getTreePredicates(predList, operandIt.value(), builder, inputs, pos); 182 } 183 } 184 /// Results. 185 if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) { 186 getTreePredicates(predList, types.front(), builder, inputs, 187 builder.getType(builder.getAllResults(opPos))); 188 } else { 189 bool foundVariableLength = false; 190 for (auto &resultIt : llvm::enumerate(types)) { 191 bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>(); 192 foundVariableLength |= isVariadic; 193 194 auto *resultPos = 195 foundVariableLength 196 ? builder.getResultGroup(pos, resultIt.index(), isVariadic) 197 : builder.getResult(pos, resultIt.index()); 198 predList.emplace_back(resultPos, builder.getIsNotNull()); 199 getTreePredicates(predList, resultIt.value(), builder, inputs, 200 builder.getType(resultPos)); 201 } 202 } 203 } 204 205 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 206 Value val, PredicateBuilder &builder, 207 DenseMap<Value, Position *> &inputs, 208 TypePosition *pos) { 209 // Check for a constraint on a constant type. 210 if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) { 211 if (Attribute type = typeOp.getConstantTypeAttr()) 212 predList.emplace_back(pos, builder.getTypeConstraint(type)); 213 } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) { 214 if (Attribute typeAttr = typeOp.getConstantTypesAttr()) 215 predList.emplace_back(pos, builder.getTypeConstraint(typeAttr)); 216 } 217 } 218 219 /// Collect the tree predicates anchored at the given value. 220 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 221 Value val, PredicateBuilder &builder, 222 DenseMap<Value, Position *> &inputs, 223 Position *pos) { 224 // Make sure this input value is accessible to the rewrite. 225 auto it = inputs.try_emplace(val, pos); 226 if (!it.second) { 227 // If this is an input value that has been visited in the tree, add a 228 // constraint to ensure that both instances refer to the same value. 229 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp, 230 pdl::TypeOp>(val.getDefiningOp())) { 231 auto minMaxPositions = 232 std::minmax(pos, it.first->second, comparePosDepth); 233 predList.emplace_back(minMaxPositions.second, 234 builder.getEqualTo(minMaxPositions.first)); 235 } 236 return; 237 } 238 239 TypeSwitch<Position *>(pos) 240 .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) { 241 getTreePredicates(predList, val, builder, inputs, pos); 242 }) 243 .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) { 244 getOperandTreePredicates(predList, val, builder, inputs, pos); 245 }) 246 .Default([](auto *) { llvm_unreachable("unexpected position kind"); }); 247 } 248 249 static void getAttributePredicates(pdl::AttributeOp op, 250 std::vector<PositionalPredicate> &predList, 251 PredicateBuilder &builder, 252 DenseMap<Value, Position *> &inputs) { 253 Position *&attrPos = inputs[op]; 254 if (attrPos) 255 return; 256 Attribute value = op.getValueAttr(); 257 assert(value && "expected non-tree `pdl.attribute` to contain a value"); 258 attrPos = builder.getAttributeLiteral(value); 259 } 260 261 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, 262 std::vector<PositionalPredicate> &predList, 263 PredicateBuilder &builder, 264 DenseMap<Value, Position *> &inputs) { 265 OperandRange arguments = op.getArgs(); 266 267 std::vector<Position *> allPositions; 268 allPositions.reserve(arguments.size()); 269 for (Value arg : arguments) 270 allPositions.push_back(inputs.lookup(arg)); 271 272 // Push the constraint to the furthest position. 273 Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), 274 comparePosDepth); 275 PredicateBuilder::Predicate pred = 276 builder.getConstraint(op.getName(), allPositions); 277 predList.emplace_back(pos, pred); 278 } 279 280 static void getResultPredicates(pdl::ResultOp op, 281 std::vector<PositionalPredicate> &predList, 282 PredicateBuilder &builder, 283 DenseMap<Value, Position *> &inputs) { 284 Position *&resultPos = inputs[op]; 285 if (resultPos) 286 return; 287 288 // Ensure that the result isn't null. 289 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent())); 290 resultPos = builder.getResult(parentPos, op.getIndex()); 291 predList.emplace_back(resultPos, builder.getIsNotNull()); 292 } 293 294 static void getResultPredicates(pdl::ResultsOp op, 295 std::vector<PositionalPredicate> &predList, 296 PredicateBuilder &builder, 297 DenseMap<Value, Position *> &inputs) { 298 Position *&resultPos = inputs[op]; 299 if (resultPos) 300 return; 301 302 // Ensure that the result isn't null if the result has an index. 303 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent())); 304 bool isVariadic = op.getType().isa<pdl::RangeType>(); 305 Optional<unsigned> index = op.getIndex(); 306 resultPos = builder.getResultGroup(parentPos, index, isVariadic); 307 if (index) 308 predList.emplace_back(resultPos, builder.getIsNotNull()); 309 } 310 311 static void getTypePredicates(Value typeValue, 312 function_ref<Attribute()> typeAttrFn, 313 PredicateBuilder &builder, 314 DenseMap<Value, Position *> &inputs) { 315 Position *&typePos = inputs[typeValue]; 316 if (typePos) 317 return; 318 Attribute typeAttr = typeAttrFn(); 319 assert(typeAttr && 320 "expected non-tree `pdl.type`/`pdl.types` to contain a value"); 321 typePos = builder.getTypeLiteral(typeAttr); 322 } 323 324 /// Collect all of the predicates that cannot be determined via walking the 325 /// tree. 326 static void getNonTreePredicates(pdl::PatternOp pattern, 327 std::vector<PositionalPredicate> &predList, 328 PredicateBuilder &builder, 329 DenseMap<Value, Position *> &inputs) { 330 for (Operation &op : pattern.getBodyRegion().getOps()) { 331 TypeSwitch<Operation *>(&op) 332 .Case([&](pdl::AttributeOp attrOp) { 333 getAttributePredicates(attrOp, predList, builder, inputs); 334 }) 335 .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) { 336 getConstraintPredicates(constraintOp, predList, builder, inputs); 337 }) 338 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) { 339 getResultPredicates(resultOp, predList, builder, inputs); 340 }) 341 .Case([&](pdl::TypeOp typeOp) { 342 getTypePredicates( 343 typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder, 344 inputs); 345 }) 346 .Case([&](pdl::TypesOp typeOp) { 347 getTypePredicates( 348 typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder, 349 inputs); 350 }); 351 } 352 } 353 354 namespace { 355 356 /// An op accepting a value at an optional index. 357 struct OpIndex { 358 Value parent; 359 Optional<unsigned> index; 360 }; 361 362 /// The parent and operand index of each operation for each root, stored 363 /// as a nested map [root][operation]. 364 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>; 365 366 } // namespace 367 368 /// Given a pattern, determines the set of roots present in this pattern. 369 /// These are the operations whose results are not consumed by other operations. 370 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) { 371 // First, collect all the operations that are used as operands 372 // to other operations. These are not roots by default. 373 DenseSet<Value> used; 374 for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) { 375 for (Value operand : operationOp.getOperandValues()) 376 TypeSwitch<Operation *>(operand.getDefiningOp()) 377 .Case<pdl::ResultOp, pdl::ResultsOp>( 378 [&used](auto resultOp) { used.insert(resultOp.getParent()); }); 379 } 380 381 // Remove the specified root from the use set, so that we can 382 // always select it as a root, even if it is used by other operations. 383 if (Value root = pattern.getRewriter().getRoot()) 384 used.erase(root); 385 386 // Finally, collect all the unused operations. 387 SmallVector<Value> roots; 388 for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) 389 if (!used.contains(operationOp)) 390 roots.push_back(operationOp); 391 392 return roots; 393 } 394 395 /// Given a list of candidate roots, builds the cost graph for connecting them. 396 /// The graph is formed by traversing the DAG of operations starting from each 397 /// root and marking the depth of each connector value (operand). Then we join 398 /// the candidate roots based on the common connector values, taking the one 399 /// with the minimum depth. Along the way, we compute, for each candidate root, 400 /// a mapping from each operation (in the DAG underneath this root) to its 401 /// parent operation and the corresponding operand index. 402 static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph, 403 ParentMaps &parentMaps) { 404 405 // The entry of a queue. The entry consists of the following items: 406 // * the value in the DAG underneath the root; 407 // * the parent of the value; 408 // * the operand index of the value in its parent; 409 // * the depth of the visited value. 410 struct Entry { 411 Entry(Value value, Value parent, Optional<unsigned> index, unsigned depth) 412 : value(value), parent(parent), index(index), depth(depth) {} 413 414 Value value; 415 Value parent; 416 Optional<unsigned> index; 417 unsigned depth; 418 }; 419 420 // A root of a value and its depth (distance from root to the value). 421 struct RootDepth { 422 Value root; 423 unsigned depth = 0; 424 }; 425 426 // Map from candidate connector values to their roots and depths. Using a 427 // small vector with 1 entry because most values belong to a single root. 428 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths; 429 430 // Perform a breadth-first traversal of the op DAG rooted at each root. 431 for (Value root : roots) { 432 // The queue of visited values. A value may be present multiple times in 433 // the queue, for multiple parents. We only accept the first occurrence, 434 // which is guaranteed to have the lowest depth. 435 std::queue<Entry> toVisit; 436 toVisit.emplace(root, Value(), 0, 0); 437 438 // The map from value to its parent for the current root. 439 DenseMap<Value, OpIndex> &parentMap = parentMaps[root]; 440 441 while (!toVisit.empty()) { 442 Entry entry = toVisit.front(); 443 toVisit.pop(); 444 // Skip if already visited. 445 if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second) 446 continue; 447 448 // Mark the root and depth of the value. 449 connectorsRootsDepths[entry.value].push_back({root, entry.depth}); 450 451 // Traverse the operands of an operation and result ops. 452 // We intentionally do not traverse attributes and types, because those 453 // are expensive to join on. 454 TypeSwitch<Operation *>(entry.value.getDefiningOp()) 455 .Case<pdl::OperationOp>([&](auto operationOp) { 456 OperandRange operands = operationOp.getOperandValues(); 457 // Special case when we pass all the operands in one range. 458 // For those, the index is empty. 459 if (operands.size() == 1 && 460 operands[0].getType().isa<pdl::RangeType>()) { 461 toVisit.emplace(operands[0], entry.value, llvm::None, 462 entry.depth + 1); 463 return; 464 } 465 466 // Default case: visit all the operands. 467 for (const auto &p : 468 llvm::enumerate(operationOp.getOperandValues())) 469 toVisit.emplace(p.value(), entry.value, p.index(), 470 entry.depth + 1); 471 }) 472 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) { 473 toVisit.emplace(resultOp.getParent(), entry.value, 474 resultOp.getIndex(), entry.depth); 475 }); 476 } 477 } 478 479 // Now build the cost graph. 480 // This is simply a minimum over all depths for the target root. 481 unsigned nextID = 0; 482 for (const auto &connectorRootsDepths : connectorsRootsDepths) { 483 Value value = connectorRootsDepths.first; 484 ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second; 485 // If there is only one root for this value, this will not trigger 486 // any edges in the cost graph (a perf optimization). 487 if (rootsDepths.size() == 1) 488 continue; 489 490 for (const RootDepth &p : rootsDepths) { 491 for (const RootDepth &q : rootsDepths) { 492 if (&p == &q) 493 continue; 494 // Insert or retrieve the property of edge from p to q. 495 RootOrderingEntry &entry = graph[q.root][p.root]; 496 if (!entry.connector /* new edge */ || entry.cost.first > q.depth) { 497 if (!entry.connector) 498 entry.cost.second = nextID++; 499 entry.cost.first = q.depth; 500 entry.connector = value; 501 } 502 } 503 } 504 } 505 506 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) && 507 "the pattern contains a candidate root disconnected from the others"); 508 } 509 510 /// Returns true if the operand at the given index needs to be queried using an 511 /// operand group, i.e., if it is variadic itself or follows a variadic operand. 512 static bool useOperandGroup(pdl::OperationOp op, unsigned index) { 513 OperandRange operands = op.getOperandValues(); 514 assert(index < operands.size() && "operand index out of range"); 515 for (unsigned i = 0; i <= index; ++i) 516 if (operands[i].getType().isa<pdl::RangeType>()) 517 return true; 518 return false; 519 } 520 521 /// Visit a node during upward traversal. 522 static void visitUpward(std::vector<PositionalPredicate> &predList, 523 OpIndex opIndex, PredicateBuilder &builder, 524 DenseMap<Value, Position *> &valueToPosition, 525 Position *&pos, unsigned rootID) { 526 Value value = opIndex.parent; 527 TypeSwitch<Operation *>(value.getDefiningOp()) 528 .Case<pdl::OperationOp>([&](auto operationOp) { 529 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 530 531 // Get users and iterate over them. 532 Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true); 533 Position *foreachPos = builder.getForEach(usersPos, rootID); 534 OperationPosition *opPos = builder.getPassthroughOp(foreachPos); 535 536 // Compare the operand(s) of the user against the input value(s). 537 Position *operandPos; 538 if (!opIndex.index) { 539 // We are querying all the operands of the operation. 540 operandPos = builder.getAllOperands(opPos); 541 } else if (useOperandGroup(operationOp, *opIndex.index)) { 542 // We are querying an operand group. 543 Type type = operationOp.getOperandValues()[*opIndex.index].getType(); 544 bool variadic = type.isa<pdl::RangeType>(); 545 operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic); 546 } else { 547 // We are querying an individual operand. 548 operandPos = builder.getOperand(opPos, *opIndex.index); 549 } 550 predList.emplace_back(operandPos, builder.getEqualTo(pos)); 551 552 // Guard against duplicate upward visits. These are not possible, 553 // because if this value was already visited, it would have been 554 // cheaper to start the traversal at this value rather than at the 555 // `connector`, violating the optimality of our spanning tree. 556 bool inserted = valueToPosition.try_emplace(value, opPos).second; 557 (void)inserted; 558 assert(inserted && "duplicate upward visit"); 559 560 // Obtain the tree predicates at the current value. 561 getTreePredicates(predList, value, builder, valueToPosition, opPos, 562 opIndex.index); 563 564 // Update the position 565 pos = opPos; 566 }) 567 .Case<pdl::ResultOp>([&](auto resultOp) { 568 // Traverse up an individual result. 569 auto *opPos = dyn_cast<OperationPosition>(pos); 570 assert(opPos && "operations and results must be interleaved"); 571 pos = builder.getResult(opPos, *opIndex.index); 572 573 // Insert the result position in case we have not visited it yet. 574 valueToPosition.try_emplace(value, pos); 575 }) 576 .Case<pdl::ResultsOp>([&](auto resultOp) { 577 // Traverse up a group of results. 578 auto *opPos = dyn_cast<OperationPosition>(pos); 579 assert(opPos && "operations and results must be interleaved"); 580 bool isVariadic = value.getType().isa<pdl::RangeType>(); 581 if (opIndex.index) 582 pos = builder.getResultGroup(opPos, opIndex.index, isVariadic); 583 else 584 pos = builder.getAllResults(opPos); 585 586 // Insert the result position in case we have not visited it yet. 587 valueToPosition.try_emplace(value, pos); 588 }); 589 } 590 591 /// Given a pattern operation, build the set of matcher predicates necessary to 592 /// match this pattern. 593 static Value buildPredicateList(pdl::PatternOp pattern, 594 PredicateBuilder &builder, 595 std::vector<PositionalPredicate> &predList, 596 DenseMap<Value, Position *> &valueToPosition) { 597 SmallVector<Value> roots = detectRoots(pattern); 598 599 // Build the root ordering graph and compute the parent maps. 600 RootOrderingGraph graph; 601 ParentMaps parentMaps; 602 buildCostGraph(roots, graph, parentMaps); 603 LLVM_DEBUG({ 604 llvm::dbgs() << "Graph:\n"; 605 for (auto &target : graph) { 606 llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first 607 << "\n"; 608 for (auto &source : target.second) { 609 RootOrderingEntry &entry = source.second; 610 llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first 611 << ":" << entry.cost.second << " via " 612 << entry.connector.getLoc() << "\n"; 613 } 614 } 615 }); 616 617 // Solve the optimal branching problem for each candidate root, or use the 618 // provided one. 619 Value bestRoot = pattern.getRewriter().getRoot(); 620 OptimalBranching::EdgeList bestEdges; 621 if (!bestRoot) { 622 unsigned bestCost = 0; 623 LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n"); 624 for (Value root : roots) { 625 OptimalBranching solver(graph, root); 626 unsigned cost = solver.solve(); 627 LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n"); 628 if (!bestRoot || bestCost > cost) { 629 bestCost = cost; 630 bestRoot = root; 631 bestEdges = solver.preOrderTraversal(roots); 632 } 633 } 634 } else { 635 OptimalBranching solver(graph, bestRoot); 636 solver.solve(); 637 bestEdges = solver.preOrderTraversal(roots); 638 } 639 640 // Print the best solution. 641 LLVM_DEBUG({ 642 llvm::dbgs() << "Best tree:\n"; 643 for (const std::pair<Value, Value> &edge : bestEdges) { 644 llvm::dbgs() << " * " << edge.first; 645 if (edge.second) 646 llvm::dbgs() << " <- " << edge.second; 647 llvm::dbgs() << "\n"; 648 } 649 }); 650 651 LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n"); 652 LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n"); 653 654 // The best root is the starting point for the traversal. Get the tree 655 // predicates for the DAG rooted at bestRoot. 656 getTreePredicates(predList, bestRoot, builder, valueToPosition, 657 builder.getRoot()); 658 659 // Traverse the selected optimal branching. For all edges in order, traverse 660 // up starting from the connector, until the candidate root is reached, and 661 // call getTreePredicates at every node along the way. 662 for (const auto &it : llvm::enumerate(bestEdges)) { 663 Value target = it.value().first; 664 Value source = it.value().second; 665 666 // Check if we already visited the target root. This happens in two cases: 667 // 1) the initial root (bestRoot); 668 // 2) a root that is dominated by (contained in the subtree rooted at) an 669 // already visited root. 670 if (valueToPosition.count(target)) 671 continue; 672 673 // Determine the connector. 674 Value connector = graph[target][source].connector; 675 assert(connector && "invalid edge"); 676 LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n"); 677 DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target); 678 Position *pos = valueToPosition.lookup(connector); 679 assert(pos && "connector has not been traversed yet"); 680 681 // Traverse from the connector upwards towards the target root. 682 for (Value value = connector; value != target;) { 683 OpIndex opIndex = parentMap.lookup(value); 684 assert(opIndex.parent && "missing parent"); 685 visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index()); 686 value = opIndex.parent; 687 } 688 } 689 690 getNonTreePredicates(pattern, predList, builder, valueToPosition); 691 692 return bestRoot; 693 } 694 695 //===----------------------------------------------------------------------===// 696 // Pattern Predicate Tree Merging 697 //===----------------------------------------------------------------------===// 698 699 namespace { 700 701 /// This class represents a specific predicate applied to a position, and 702 /// provides hashing and ordering operators. This class allows for computing a 703 /// frequence sum and ordering predicates based on a cost model. 704 struct OrderedPredicate { 705 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip) 706 : position(ip.first), question(ip.second) {} 707 OrderedPredicate(const PositionalPredicate &ip) 708 : position(ip.position), question(ip.question) {} 709 710 /// The position this predicate is applied to. 711 Position *position; 712 713 /// The question that is applied by this predicate onto the position. 714 Qualifier *question; 715 716 /// The first and second order benefit sums. 717 /// The primary sum is the number of occurrences of this predicate among all 718 /// of the patterns. 719 unsigned primary = 0; 720 /// The secondary sum is a squared summation of the primary sum of all of the 721 /// predicates within each pattern that contains this predicate. This allows 722 /// for favoring predicates that are more commonly shared within a pattern, as 723 /// opposed to those shared across patterns. 724 unsigned secondary = 0; 725 726 /// The tie breaking ID, used to preserve a deterministic (insertion) order 727 /// among all the predicates with the same priority, depth, and position / 728 /// predicate dependency. 729 unsigned id = 0; 730 731 /// A map between a pattern operation and the answer to the predicate question 732 /// within that pattern. 733 DenseMap<Operation *, Qualifier *> patternToAnswer; 734 735 /// Returns true if this predicate is ordered before `rhs`, based on the cost 736 /// model. 737 bool operator<(const OrderedPredicate &rhs) const { 738 // Sort by: 739 // * higher first and secondary order sums 740 // * lower depth 741 // * lower position dependency 742 // * lower predicate dependency 743 // * lower tie breaking ID 744 auto *rhsPos = rhs.position; 745 return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(), 746 rhsPos->getKind(), rhs.question->getKind(), rhs.id) > 747 std::make_tuple(rhs.primary, rhs.secondary, 748 position->getOperationDepth(), position->getKind(), 749 question->getKind(), id); 750 } 751 }; 752 753 /// A DenseMapInfo for OrderedPredicate based solely on the position and 754 /// question. 755 struct OrderedPredicateDenseInfo { 756 using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>; 757 758 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); } 759 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); } 760 static bool isEqual(const OrderedPredicate &lhs, 761 const OrderedPredicate &rhs) { 762 return lhs.position == rhs.position && lhs.question == rhs.question; 763 } 764 static unsigned getHashValue(const OrderedPredicate &p) { 765 return llvm::hash_combine(p.position, p.question); 766 } 767 }; 768 769 /// This class wraps a set of ordered predicates that are used within a specific 770 /// pattern operation. 771 struct OrderedPredicateList { 772 OrderedPredicateList(pdl::PatternOp pattern, Value root) 773 : pattern(pattern), root(root) {} 774 775 pdl::PatternOp pattern; 776 Value root; 777 DenseSet<OrderedPredicate *> predicates; 778 }; 779 } // namespace 780 781 /// Returns true if the given matcher refers to the same predicate as the given 782 /// ordered predicate. This means that the position and questions of the two 783 /// match. 784 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { 785 return node->getPosition() == predicate->position && 786 node->getQuestion() == predicate->question; 787 } 788 789 /// Get or insert a child matcher for the given parent switch node, given a 790 /// predicate and parent pattern. 791 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, 792 OrderedPredicate *predicate, 793 pdl::PatternOp pattern) { 794 assert(isSamePredicate(node, predicate) && 795 "expected matcher to equal the given predicate"); 796 797 auto it = predicate->patternToAnswer.find(pattern); 798 assert(it != predicate->patternToAnswer.end() && 799 "expected pattern to exist in predicate"); 800 return node->getChildren().insert({it->second, nullptr}).first->second; 801 } 802 803 /// Build the matcher CFG by "pushing" patterns through by sorted predicate 804 /// order. A pattern will traverse as far as possible using common predicates 805 /// and then either diverge from the CFG or reach the end of a branch and start 806 /// creating new nodes. 807 static void propagatePattern(std::unique_ptr<MatcherNode> &node, 808 OrderedPredicateList &list, 809 std::vector<OrderedPredicate *>::iterator current, 810 std::vector<OrderedPredicate *>::iterator end) { 811 if (current == end) { 812 // We've hit the end of a pattern, so create a successful result node. 813 node = 814 std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node)); 815 816 // If the pattern doesn't contain this predicate, ignore it. 817 } else if (list.predicates.find(*current) == list.predicates.end()) { 818 propagatePattern(node, list, std::next(current), end); 819 820 // If the current matcher node is invalid, create a new one for this 821 // position and continue propagation. 822 } else if (!node) { 823 // Create a new node at this position and continue 824 node = std::make_unique<SwitchNode>((*current)->position, 825 (*current)->question); 826 propagatePattern( 827 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 828 list, std::next(current), end); 829 830 // If the matcher has already been created, and it is for this predicate we 831 // continue propagation to the child. 832 } else if (isSamePredicate(node.get(), *current)) { 833 propagatePattern( 834 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 835 list, std::next(current), end); 836 837 // If the matcher doesn't match the current predicate, insert a branch as 838 // the common set of matchers has diverged. 839 } else { 840 propagatePattern(node->getFailureNode(), list, current, end); 841 } 842 } 843 844 /// Fold any switch nodes nested under `node` to boolean nodes when possible. 845 /// `node` is updated in-place if it is a switch. 846 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) { 847 if (!node) 848 return; 849 850 if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) { 851 SwitchNode::ChildMapT &children = switchNode->getChildren(); 852 for (auto &it : children) 853 foldSwitchToBool(it.second); 854 855 // If the node only contains one child, collapse it into a boolean predicate 856 // node. 857 if (children.size() == 1) { 858 auto childIt = children.begin(); 859 node = std::make_unique<BoolNode>( 860 node->getPosition(), node->getQuestion(), childIt->first, 861 std::move(childIt->second), std::move(node->getFailureNode())); 862 } 863 } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) { 864 foldSwitchToBool(boolNode->getSuccessNode()); 865 } 866 867 foldSwitchToBool(node->getFailureNode()); 868 } 869 870 /// Insert an exit node at the end of the failure path of the `root`. 871 static void insertExitNode(std::unique_ptr<MatcherNode> *root) { 872 while (*root) 873 root = &(*root)->getFailureNode(); 874 *root = std::make_unique<ExitNode>(); 875 } 876 877 /// Given a module containing PDL pattern operations, generate a matcher tree 878 /// using the patterns within the given module and return the root matcher node. 879 std::unique_ptr<MatcherNode> 880 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 881 DenseMap<Value, Position *> &valueToPosition) { 882 // The set of predicates contained within the pattern operations of the 883 // module. 884 struct PatternPredicates { 885 PatternPredicates(pdl::PatternOp pattern, Value root, 886 std::vector<PositionalPredicate> predicates) 887 : pattern(pattern), root(root), predicates(std::move(predicates)) {} 888 889 /// A pattern. 890 pdl::PatternOp pattern; 891 892 /// A root of the pattern chosen among the candidate roots in pdl.rewrite. 893 Value root; 894 895 /// The extracted predicates for this pattern and root. 896 std::vector<PositionalPredicate> predicates; 897 }; 898 899 SmallVector<PatternPredicates, 16> patternsAndPredicates; 900 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { 901 std::vector<PositionalPredicate> predicateList; 902 Value root = 903 buildPredicateList(pattern, builder, predicateList, valueToPosition); 904 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList)); 905 } 906 907 // Associate a pattern result with each unique predicate. 908 DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued; 909 for (auto &patternAndPredList : patternsAndPredicates) { 910 for (auto &predicate : patternAndPredList.predicates) { 911 auto it = uniqued.insert(predicate); 912 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern, 913 predicate.answer); 914 // Mark the insertion order (0-based indexing). 915 if (it.second) 916 it.first->id = uniqued.size() - 1; 917 } 918 } 919 920 // Associate each pattern to a set of its ordered predicates for later lookup. 921 std::vector<OrderedPredicateList> lists; 922 lists.reserve(patternsAndPredicates.size()); 923 for (auto &patternAndPredList : patternsAndPredicates) { 924 OrderedPredicateList list(patternAndPredList.pattern, 925 patternAndPredList.root); 926 for (auto &predicate : patternAndPredList.predicates) { 927 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); 928 list.predicates.insert(orderedPredicate); 929 930 // Increment the primary sum for each reference to a particular predicate. 931 ++orderedPredicate->primary; 932 } 933 lists.push_back(std::move(list)); 934 } 935 936 // For a particular pattern, get the total primary sum and add it to the 937 // secondary sum of each predicate. Square the primary sums to emphasize 938 // shared predicates within rather than across patterns. 939 for (auto &list : lists) { 940 unsigned total = 0; 941 for (auto *predicate : list.predicates) 942 total += predicate->primary * predicate->primary; 943 for (auto *predicate : list.predicates) 944 predicate->secondary += total; 945 } 946 947 // Sort the set of predicates now that the cost primary and secondary sums 948 // have been computed. 949 std::vector<OrderedPredicate *> ordered; 950 ordered.reserve(uniqued.size()); 951 for (auto &ip : uniqued) 952 ordered.push_back(&ip); 953 llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) { 954 return *lhs < *rhs; 955 }); 956 957 // Build the matchers for each of the pattern predicate lists. 958 std::unique_ptr<MatcherNode> root; 959 for (OrderedPredicateList &list : lists) 960 propagatePattern(root, list, ordered.begin(), ordered.end()); 961 962 // Collapse the graph and insert the exit node. 963 foldSwitchToBool(root); 964 insertExitNode(&root); 965 return root; 966 } 967 968 //===----------------------------------------------------------------------===// 969 // MatcherNode 970 //===----------------------------------------------------------------------===// 971 972 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q, 973 std::unique_ptr<MatcherNode> failureNode) 974 : position(p), question(q), failureNode(std::move(failureNode)), 975 matcherTypeID(matcherTypeID) {} 976 977 //===----------------------------------------------------------------------===// 978 // BoolNode 979 //===----------------------------------------------------------------------===// 980 981 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer, 982 std::unique_ptr<MatcherNode> successNode, 983 std::unique_ptr<MatcherNode> failureNode) 984 : MatcherNode(TypeID::get<BoolNode>(), position, question, 985 std::move(failureNode)), 986 answer(answer), successNode(std::move(successNode)) {} 987 988 //===----------------------------------------------------------------------===// 989 // SuccessNode 990 //===----------------------------------------------------------------------===// 991 992 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root, 993 std::unique_ptr<MatcherNode> failureNode) 994 : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr, 995 /*question=*/nullptr, std::move(failureNode)), 996 pattern(pattern), root(root) {} 997 998 //===----------------------------------------------------------------------===// 999 // SwitchNode 1000 //===----------------------------------------------------------------------===// 1001 1002 SwitchNode::SwitchNode(Position *position, Qualifier *question) 1003 : MatcherNode(TypeID::get<SwitchNode>(), position, question) {} 1004