1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PredicateTree.h" 10 #include "RootOrdering.h" 11 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/Interfaces/InferTypeOpInterface.h" 17 #include "llvm/ADT/MapVector.h" 18 #include "llvm/ADT/SmallPtrSet.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 #include "llvm/Support/Debug.h" 21 #include <queue> 22 23 #define DEBUG_TYPE "pdl-predicate-tree" 24 25 using namespace mlir; 26 using namespace mlir::pdl_to_pdl_interp; 27 28 //===----------------------------------------------------------------------===// 29 // Predicate List Building 30 //===----------------------------------------------------------------------===// 31 32 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 33 Value val, PredicateBuilder &builder, 34 DenseMap<Value, Position *> &inputs, 35 Position *pos); 36 37 /// Compares the depths of two positions. 38 static bool comparePosDepth(Position *lhs, Position *rhs) { 39 return lhs->getOperationDepth() < rhs->getOperationDepth(); 40 } 41 42 /// Returns the number of non-range elements within `values`. 43 static unsigned getNumNonRangeValues(ValueRange values) { 44 return llvm::count_if(values.getTypes(), 45 [](Type type) { return !isa<pdl::RangeType>(type); }); 46 } 47 48 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 49 Value val, PredicateBuilder &builder, 50 DenseMap<Value, Position *> &inputs, 51 AttributePosition *pos) { 52 assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type"); 53 predList.emplace_back(pos, builder.getIsNotNull()); 54 55 if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) { 56 // If the attribute has a type or value, add a constraint. 57 if (Value type = attr.getValueType()) 58 getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); 59 else if (Attribute value = attr.getValueAttr()) 60 predList.emplace_back(pos, builder.getAttributeConstraint(value)); 61 } 62 } 63 64 /// Collect all of the predicates for the given operand position. 65 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList, 66 Value val, PredicateBuilder &builder, 67 DenseMap<Value, Position *> &inputs, 68 Position *pos) { 69 Type valueType = val.getType(); 70 bool isVariadic = isa<pdl::RangeType>(valueType); 71 72 // If this is a typed operand, add a type constraint. 73 TypeSwitch<Operation *>(val.getDefiningOp()) 74 .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) { 75 // Prevent traversal into a null value if the operand has a proper 76 // index. 77 if (std::is_same<pdl::OperandOp, decltype(op)>::value || 78 cast<OperandGroupPosition>(pos)->getOperandGroupNumber()) 79 predList.emplace_back(pos, builder.getIsNotNull()); 80 81 if (Value type = op.getValueType()) 82 getTreePredicates(predList, type, builder, inputs, 83 builder.getType(pos)); 84 }) 85 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) { 86 std::optional<unsigned> index = op.getIndex(); 87 88 // Prevent traversal into a null value if the result has a proper index. 89 if (index) 90 predList.emplace_back(pos, builder.getIsNotNull()); 91 92 // Get the parent operation of this operand. 93 OperationPosition *parentPos = builder.getOperandDefiningOp(pos); 94 predList.emplace_back(parentPos, builder.getIsNotNull()); 95 96 // Ensure that the operands match the corresponding results of the 97 // parent operation. 98 Position *resultPos = nullptr; 99 if (std::is_same<pdl::ResultOp, decltype(op)>::value) 100 resultPos = builder.getResult(parentPos, *index); 101 else 102 resultPos = builder.getResultGroup(parentPos, index, isVariadic); 103 predList.emplace_back(resultPos, builder.getEqualTo(pos)); 104 105 // Collect the predicates of the parent operation. 106 getTreePredicates(predList, op.getParent(), builder, inputs, 107 (Position *)parentPos); 108 }); 109 } 110 111 static void 112 getTreePredicates(std::vector<PositionalPredicate> &predList, Value val, 113 PredicateBuilder &builder, 114 DenseMap<Value, Position *> &inputs, OperationPosition *pos, 115 std::optional<unsigned> ignoreOperand = std::nullopt) { 116 assert(isa<pdl::OperationType>(val.getType()) && "expected operation"); 117 pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp()); 118 OperationPosition *opPos = cast<OperationPosition>(pos); 119 120 // Ensure getDefiningOp returns a non-null operation. 121 if (!opPos->isRoot()) 122 predList.emplace_back(pos, builder.getIsNotNull()); 123 124 // Check that this is the correct root operation. 125 if (std::optional<StringRef> opName = op.getOpName()) 126 predList.emplace_back(pos, builder.getOperationName(*opName)); 127 128 // Check that the operation has the proper number of operands. If there are 129 // any variable length operands, we check a minimum instead of an exact count. 130 OperandRange operands = op.getOperandValues(); 131 unsigned minOperands = getNumNonRangeValues(operands); 132 if (minOperands != operands.size()) { 133 if (minOperands) 134 predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands)); 135 } else { 136 predList.emplace_back(pos, builder.getOperandCount(minOperands)); 137 } 138 139 // Check that the operation has the proper number of results. If there are 140 // any variable length results, we check a minimum instead of an exact count. 141 OperandRange types = op.getTypeValues(); 142 unsigned minResults = getNumNonRangeValues(types); 143 if (minResults == types.size()) 144 predList.emplace_back(pos, builder.getResultCount(types.size())); 145 else if (minResults) 146 predList.emplace_back(pos, builder.getResultCountAtLeast(minResults)); 147 148 // Recurse into any attributes, operands, or results. 149 for (auto [attrName, attr] : 150 llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) { 151 getTreePredicates( 152 predList, attr, builder, inputs, 153 builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue())); 154 } 155 156 // Process the operands and results of the operation. For all values up to 157 // the first variable length value, we use the concrete operand/result 158 // number. After that, we use the "group" given that we can't know the 159 // concrete indices until runtime. If there is only one variadic operand 160 // group, we treat it as all of the operands/results of the operation. 161 /// Operands. 162 if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) { 163 // Ignore the operands if we are performing an upward traversal (in that 164 // case, they have already been visited). 165 if (opPos->isRoot() || opPos->isOperandDefiningOp()) 166 getTreePredicates(predList, operands.front(), builder, inputs, 167 builder.getAllOperands(opPos)); 168 } else { 169 bool foundVariableLength = false; 170 for (const auto &operandIt : llvm::enumerate(operands)) { 171 bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType()); 172 foundVariableLength |= isVariadic; 173 174 // Ignore the specified operand, usually because this position was 175 // visited in an upward traversal via an iterative choice. 176 if (ignoreOperand && *ignoreOperand == operandIt.index()) 177 continue; 178 179 Position *pos = 180 foundVariableLength 181 ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic) 182 : builder.getOperand(opPos, operandIt.index()); 183 getTreePredicates(predList, operandIt.value(), builder, inputs, pos); 184 } 185 } 186 /// Results. 187 if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) { 188 getTreePredicates(predList, types.front(), builder, inputs, 189 builder.getType(builder.getAllResults(opPos))); 190 return; 191 } 192 193 bool foundVariableLength = false; 194 for (auto [idx, typeValue] : llvm::enumerate(types)) { 195 bool isVariadic = isa<pdl::RangeType>(typeValue.getType()); 196 foundVariableLength |= isVariadic; 197 198 auto *resultPos = foundVariableLength 199 ? builder.getResultGroup(pos, idx, isVariadic) 200 : builder.getResult(pos, idx); 201 predList.emplace_back(resultPos, builder.getIsNotNull()); 202 getTreePredicates(predList, typeValue, builder, inputs, 203 builder.getType(resultPos)); 204 } 205 } 206 207 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 208 Value val, PredicateBuilder &builder, 209 DenseMap<Value, Position *> &inputs, 210 TypePosition *pos) { 211 // Check for a constraint on a constant type. 212 if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) { 213 if (Attribute type = typeOp.getConstantTypeAttr()) 214 predList.emplace_back(pos, builder.getTypeConstraint(type)); 215 } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) { 216 if (Attribute typeAttr = typeOp.getConstantTypesAttr()) 217 predList.emplace_back(pos, builder.getTypeConstraint(typeAttr)); 218 } 219 } 220 221 /// Collect the tree predicates anchored at the given value. 222 static void getTreePredicates(std::vector<PositionalPredicate> &predList, 223 Value val, PredicateBuilder &builder, 224 DenseMap<Value, Position *> &inputs, 225 Position *pos) { 226 // Make sure this input value is accessible to the rewrite. 227 auto it = inputs.try_emplace(val, pos); 228 if (!it.second) { 229 // If this is an input value that has been visited in the tree, add a 230 // constraint to ensure that both instances refer to the same value. 231 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp, 232 pdl::TypeOp>(val.getDefiningOp())) { 233 auto minMaxPositions = 234 std::minmax(pos, it.first->second, comparePosDepth); 235 predList.emplace_back(minMaxPositions.second, 236 builder.getEqualTo(minMaxPositions.first)); 237 } 238 return; 239 } 240 241 TypeSwitch<Position *>(pos) 242 .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) { 243 getTreePredicates(predList, val, builder, inputs, pos); 244 }) 245 .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) { 246 getOperandTreePredicates(predList, val, builder, inputs, pos); 247 }) 248 .Default([](auto *) { llvm_unreachable("unexpected position kind"); }); 249 } 250 251 static void getAttributePredicates(pdl::AttributeOp op, 252 std::vector<PositionalPredicate> &predList, 253 PredicateBuilder &builder, 254 DenseMap<Value, Position *> &inputs) { 255 Position *&attrPos = inputs[op]; 256 if (attrPos) 257 return; 258 Attribute value = op.getValueAttr(); 259 assert(value && "expected non-tree `pdl.attribute` to contain a value"); 260 attrPos = builder.getAttributeLiteral(value); 261 } 262 263 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, 264 std::vector<PositionalPredicate> &predList, 265 PredicateBuilder &builder, 266 DenseMap<Value, Position *> &inputs) { 267 OperandRange arguments = op.getArgs(); 268 269 std::vector<Position *> allPositions; 270 allPositions.reserve(arguments.size()); 271 for (Value arg : arguments) 272 allPositions.push_back(inputs.lookup(arg)); 273 274 // Push the constraint to the furthest position. 275 Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), 276 comparePosDepth); 277 ResultRange results = op.getResults(); 278 PredicateBuilder::Predicate pred = builder.getConstraint( 279 op.getName(), allPositions, SmallVector<Type>(results.getTypes()), 280 op.getIsNegated()); 281 282 // For each result register a position so it can be used later 283 for (auto [i, result] : llvm::enumerate(results)) { 284 ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first); 285 ConstraintPosition *pos = builder.getConstraintPosition(q, i); 286 auto [it, inserted] = inputs.try_emplace(result, pos); 287 // If this is an input value that has been visited in the tree, add a 288 // constraint to ensure that both instances refer to the same value. 289 if (!inserted) { 290 Position *first = pos; 291 Position *second = it->second; 292 if (comparePosDepth(second, first)) 293 std::tie(second, first) = std::make_pair(first, second); 294 295 predList.emplace_back(second, builder.getEqualTo(first)); 296 } 297 } 298 predList.emplace_back(pos, pred); 299 } 300 301 static void getResultPredicates(pdl::ResultOp op, 302 std::vector<PositionalPredicate> &predList, 303 PredicateBuilder &builder, 304 DenseMap<Value, Position *> &inputs) { 305 Position *&resultPos = inputs[op]; 306 if (resultPos) 307 return; 308 309 // Ensure that the result isn't null. 310 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent())); 311 resultPos = builder.getResult(parentPos, op.getIndex()); 312 predList.emplace_back(resultPos, builder.getIsNotNull()); 313 } 314 315 static void getResultPredicates(pdl::ResultsOp op, 316 std::vector<PositionalPredicate> &predList, 317 PredicateBuilder &builder, 318 DenseMap<Value, Position *> &inputs) { 319 Position *&resultPos = inputs[op]; 320 if (resultPos) 321 return; 322 323 // Ensure that the result isn't null if the result has an index. 324 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent())); 325 bool isVariadic = isa<pdl::RangeType>(op.getType()); 326 std::optional<unsigned> index = op.getIndex(); 327 resultPos = builder.getResultGroup(parentPos, index, isVariadic); 328 if (index) 329 predList.emplace_back(resultPos, builder.getIsNotNull()); 330 } 331 332 static void getTypePredicates(Value typeValue, 333 function_ref<Attribute()> typeAttrFn, 334 PredicateBuilder &builder, 335 DenseMap<Value, Position *> &inputs) { 336 Position *&typePos = inputs[typeValue]; 337 if (typePos) 338 return; 339 Attribute typeAttr = typeAttrFn(); 340 assert(typeAttr && 341 "expected non-tree `pdl.type`/`pdl.types` to contain a value"); 342 typePos = builder.getTypeLiteral(typeAttr); 343 } 344 345 /// Collect all of the predicates that cannot be determined via walking the 346 /// tree. 347 static void getNonTreePredicates(pdl::PatternOp pattern, 348 std::vector<PositionalPredicate> &predList, 349 PredicateBuilder &builder, 350 DenseMap<Value, Position *> &inputs) { 351 for (Operation &op : pattern.getBodyRegion().getOps()) { 352 TypeSwitch<Operation *>(&op) 353 .Case([&](pdl::AttributeOp attrOp) { 354 getAttributePredicates(attrOp, predList, builder, inputs); 355 }) 356 .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) { 357 getConstraintPredicates(constraintOp, predList, builder, inputs); 358 }) 359 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) { 360 getResultPredicates(resultOp, predList, builder, inputs); 361 }) 362 .Case([&](pdl::TypeOp typeOp) { 363 getTypePredicates( 364 typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder, 365 inputs); 366 }) 367 .Case([&](pdl::TypesOp typeOp) { 368 getTypePredicates( 369 typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder, 370 inputs); 371 }); 372 } 373 } 374 375 namespace { 376 377 /// An op accepting a value at an optional index. 378 struct OpIndex { 379 Value parent; 380 std::optional<unsigned> index; 381 }; 382 383 /// The parent and operand index of each operation for each root, stored 384 /// as a nested map [root][operation]. 385 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>; 386 387 } // namespace 388 389 /// Given a pattern, determines the set of roots present in this pattern. 390 /// These are the operations whose results are not consumed by other operations. 391 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) { 392 // First, collect all the operations that are used as operands 393 // to other operations. These are not roots by default. 394 DenseSet<Value> used; 395 for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) { 396 for (Value operand : operationOp.getOperandValues()) 397 TypeSwitch<Operation *>(operand.getDefiningOp()) 398 .Case<pdl::ResultOp, pdl::ResultsOp>( 399 [&used](auto resultOp) { used.insert(resultOp.getParent()); }); 400 } 401 402 // Remove the specified root from the use set, so that we can 403 // always select it as a root, even if it is used by other operations. 404 if (Value root = pattern.getRewriter().getRoot()) 405 used.erase(root); 406 407 // Finally, collect all the unused operations. 408 SmallVector<Value> roots; 409 for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) 410 if (!used.contains(operationOp)) 411 roots.push_back(operationOp); 412 413 return roots; 414 } 415 416 /// Given a list of candidate roots, builds the cost graph for connecting them. 417 /// The graph is formed by traversing the DAG of operations starting from each 418 /// root and marking the depth of each connector value (operand). Then we join 419 /// the candidate roots based on the common connector values, taking the one 420 /// with the minimum depth. Along the way, we compute, for each candidate root, 421 /// a mapping from each operation (in the DAG underneath this root) to its 422 /// parent operation and the corresponding operand index. 423 static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph, 424 ParentMaps &parentMaps) { 425 426 // The entry of a queue. The entry consists of the following items: 427 // * the value in the DAG underneath the root; 428 // * the parent of the value; 429 // * the operand index of the value in its parent; 430 // * the depth of the visited value. 431 struct Entry { 432 Entry(Value value, Value parent, std::optional<unsigned> index, 433 unsigned depth) 434 : value(value), parent(parent), index(index), depth(depth) {} 435 436 Value value; 437 Value parent; 438 std::optional<unsigned> index; 439 unsigned depth; 440 }; 441 442 // A root of a value and its depth (distance from root to the value). 443 struct RootDepth { 444 Value root; 445 unsigned depth = 0; 446 }; 447 448 // Map from candidate connector values to their roots and depths. Using a 449 // small vector with 1 entry because most values belong to a single root. 450 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths; 451 452 // Perform a breadth-first traversal of the op DAG rooted at each root. 453 for (Value root : roots) { 454 // The queue of visited values. A value may be present multiple times in 455 // the queue, for multiple parents. We only accept the first occurrence, 456 // which is guaranteed to have the lowest depth. 457 std::queue<Entry> toVisit; 458 toVisit.emplace(root, Value(), 0, 0); 459 460 // The map from value to its parent for the current root. 461 DenseMap<Value, OpIndex> &parentMap = parentMaps[root]; 462 463 while (!toVisit.empty()) { 464 Entry entry = toVisit.front(); 465 toVisit.pop(); 466 // Skip if already visited. 467 if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second) 468 continue; 469 470 // Mark the root and depth of the value. 471 connectorsRootsDepths[entry.value].push_back({root, entry.depth}); 472 473 // Traverse the operands of an operation and result ops. 474 // We intentionally do not traverse attributes and types, because those 475 // are expensive to join on. 476 TypeSwitch<Operation *>(entry.value.getDefiningOp()) 477 .Case<pdl::OperationOp>([&](auto operationOp) { 478 OperandRange operands = operationOp.getOperandValues(); 479 // Special case when we pass all the operands in one range. 480 // For those, the index is empty. 481 if (operands.size() == 1 && 482 isa<pdl::RangeType>(operands[0].getType())) { 483 toVisit.emplace(operands[0], entry.value, std::nullopt, 484 entry.depth + 1); 485 return; 486 } 487 488 // Default case: visit all the operands. 489 for (const auto &p : 490 llvm::enumerate(operationOp.getOperandValues())) 491 toVisit.emplace(p.value(), entry.value, p.index(), 492 entry.depth + 1); 493 }) 494 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) { 495 toVisit.emplace(resultOp.getParent(), entry.value, 496 resultOp.getIndex(), entry.depth); 497 }); 498 } 499 } 500 501 // Now build the cost graph. 502 // This is simply a minimum over all depths for the target root. 503 unsigned nextID = 0; 504 for (const auto &connectorRootsDepths : connectorsRootsDepths) { 505 Value value = connectorRootsDepths.first; 506 ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second; 507 // If there is only one root for this value, this will not trigger 508 // any edges in the cost graph (a perf optimization). 509 if (rootsDepths.size() == 1) 510 continue; 511 512 for (const RootDepth &p : rootsDepths) { 513 for (const RootDepth &q : rootsDepths) { 514 if (&p == &q) 515 continue; 516 // Insert or retrieve the property of edge from p to q. 517 RootOrderingEntry &entry = graph[q.root][p.root]; 518 if (!entry.connector /* new edge */ || entry.cost.first > q.depth) { 519 if (!entry.connector) 520 entry.cost.second = nextID++; 521 entry.cost.first = q.depth; 522 entry.connector = value; 523 } 524 } 525 } 526 } 527 528 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) && 529 "the pattern contains a candidate root disconnected from the others"); 530 } 531 532 /// Returns true if the operand at the given index needs to be queried using an 533 /// operand group, i.e., if it is variadic itself or follows a variadic operand. 534 static bool useOperandGroup(pdl::OperationOp op, unsigned index) { 535 OperandRange operands = op.getOperandValues(); 536 assert(index < operands.size() && "operand index out of range"); 537 for (unsigned i = 0; i <= index; ++i) 538 if (isa<pdl::RangeType>(operands[i].getType())) 539 return true; 540 return false; 541 } 542 543 /// Visit a node during upward traversal. 544 static void visitUpward(std::vector<PositionalPredicate> &predList, 545 OpIndex opIndex, PredicateBuilder &builder, 546 DenseMap<Value, Position *> &valueToPosition, 547 Position *&pos, unsigned rootID) { 548 Value value = opIndex.parent; 549 TypeSwitch<Operation *>(value.getDefiningOp()) 550 .Case<pdl::OperationOp>([&](auto operationOp) { 551 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 552 553 // Get users and iterate over them. 554 Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true); 555 Position *foreachPos = builder.getForEach(usersPos, rootID); 556 OperationPosition *opPos = builder.getPassthroughOp(foreachPos); 557 558 // Compare the operand(s) of the user against the input value(s). 559 Position *operandPos; 560 if (!opIndex.index) { 561 // We are querying all the operands of the operation. 562 operandPos = builder.getAllOperands(opPos); 563 } else if (useOperandGroup(operationOp, *opIndex.index)) { 564 // We are querying an operand group. 565 Type type = operationOp.getOperandValues()[*opIndex.index].getType(); 566 bool variadic = isa<pdl::RangeType>(type); 567 operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic); 568 } else { 569 // We are querying an individual operand. 570 operandPos = builder.getOperand(opPos, *opIndex.index); 571 } 572 predList.emplace_back(operandPos, builder.getEqualTo(pos)); 573 574 // Guard against duplicate upward visits. These are not possible, 575 // because if this value was already visited, it would have been 576 // cheaper to start the traversal at this value rather than at the 577 // `connector`, violating the optimality of our spanning tree. 578 bool inserted = valueToPosition.try_emplace(value, opPos).second; 579 (void)inserted; 580 assert(inserted && "duplicate upward visit"); 581 582 // Obtain the tree predicates at the current value. 583 getTreePredicates(predList, value, builder, valueToPosition, opPos, 584 opIndex.index); 585 586 // Update the position 587 pos = opPos; 588 }) 589 .Case<pdl::ResultOp>([&](auto resultOp) { 590 // Traverse up an individual result. 591 auto *opPos = dyn_cast<OperationPosition>(pos); 592 assert(opPos && "operations and results must be interleaved"); 593 pos = builder.getResult(opPos, *opIndex.index); 594 595 // Insert the result position in case we have not visited it yet. 596 valueToPosition.try_emplace(value, pos); 597 }) 598 .Case<pdl::ResultsOp>([&](auto resultOp) { 599 // Traverse up a group of results. 600 auto *opPos = dyn_cast<OperationPosition>(pos); 601 assert(opPos && "operations and results must be interleaved"); 602 bool isVariadic = isa<pdl::RangeType>(value.getType()); 603 if (opIndex.index) 604 pos = builder.getResultGroup(opPos, opIndex.index, isVariadic); 605 else 606 pos = builder.getAllResults(opPos); 607 608 // Insert the result position in case we have not visited it yet. 609 valueToPosition.try_emplace(value, pos); 610 }); 611 } 612 613 /// Given a pattern operation, build the set of matcher predicates necessary to 614 /// match this pattern. 615 static Value buildPredicateList(pdl::PatternOp pattern, 616 PredicateBuilder &builder, 617 std::vector<PositionalPredicate> &predList, 618 DenseMap<Value, Position *> &valueToPosition) { 619 SmallVector<Value> roots = detectRoots(pattern); 620 621 // Build the root ordering graph and compute the parent maps. 622 RootOrderingGraph graph; 623 ParentMaps parentMaps; 624 buildCostGraph(roots, graph, parentMaps); 625 LLVM_DEBUG({ 626 llvm::dbgs() << "Graph:\n"; 627 for (auto &target : graph) { 628 llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first 629 << "\n"; 630 for (auto &source : target.second) { 631 RootOrderingEntry &entry = source.second; 632 llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first 633 << ":" << entry.cost.second << " via " 634 << entry.connector.getLoc() << "\n"; 635 } 636 } 637 }); 638 639 // Solve the optimal branching problem for each candidate root, or use the 640 // provided one. 641 Value bestRoot = pattern.getRewriter().getRoot(); 642 OptimalBranching::EdgeList bestEdges; 643 if (!bestRoot) { 644 unsigned bestCost = 0; 645 LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n"); 646 for (Value root : roots) { 647 OptimalBranching solver(graph, root); 648 unsigned cost = solver.solve(); 649 LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n"); 650 if (!bestRoot || bestCost > cost) { 651 bestCost = cost; 652 bestRoot = root; 653 bestEdges = solver.preOrderTraversal(roots); 654 } 655 } 656 } else { 657 OptimalBranching solver(graph, bestRoot); 658 solver.solve(); 659 bestEdges = solver.preOrderTraversal(roots); 660 } 661 662 // Print the best solution. 663 LLVM_DEBUG({ 664 llvm::dbgs() << "Best tree:\n"; 665 for (const std::pair<Value, Value> &edge : bestEdges) { 666 llvm::dbgs() << " * " << edge.first; 667 if (edge.second) 668 llvm::dbgs() << " <- " << edge.second; 669 llvm::dbgs() << "\n"; 670 } 671 }); 672 673 LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n"); 674 LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n"); 675 676 // The best root is the starting point for the traversal. Get the tree 677 // predicates for the DAG rooted at bestRoot. 678 getTreePredicates(predList, bestRoot, builder, valueToPosition, 679 builder.getRoot()); 680 681 // Traverse the selected optimal branching. For all edges in order, traverse 682 // up starting from the connector, until the candidate root is reached, and 683 // call getTreePredicates at every node along the way. 684 for (const auto &it : llvm::enumerate(bestEdges)) { 685 Value target = it.value().first; 686 Value source = it.value().second; 687 688 // Check if we already visited the target root. This happens in two cases: 689 // 1) the initial root (bestRoot); 690 // 2) a root that is dominated by (contained in the subtree rooted at) an 691 // already visited root. 692 if (valueToPosition.count(target)) 693 continue; 694 695 // Determine the connector. 696 Value connector = graph[target][source].connector; 697 assert(connector && "invalid edge"); 698 LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n"); 699 DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target); 700 Position *pos = valueToPosition.lookup(connector); 701 assert(pos && "connector has not been traversed yet"); 702 703 // Traverse from the connector upwards towards the target root. 704 for (Value value = connector; value != target;) { 705 OpIndex opIndex = parentMap.lookup(value); 706 assert(opIndex.parent && "missing parent"); 707 visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index()); 708 value = opIndex.parent; 709 } 710 } 711 712 getNonTreePredicates(pattern, predList, builder, valueToPosition); 713 714 return bestRoot; 715 } 716 717 //===----------------------------------------------------------------------===// 718 // Pattern Predicate Tree Merging 719 //===----------------------------------------------------------------------===// 720 721 namespace { 722 723 /// This class represents a specific predicate applied to a position, and 724 /// provides hashing and ordering operators. This class allows for computing a 725 /// frequence sum and ordering predicates based on a cost model. 726 struct OrderedPredicate { 727 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip) 728 : position(ip.first), question(ip.second) {} 729 OrderedPredicate(const PositionalPredicate &ip) 730 : position(ip.position), question(ip.question) {} 731 732 /// The position this predicate is applied to. 733 Position *position; 734 735 /// The question that is applied by this predicate onto the position. 736 Qualifier *question; 737 738 /// The first and second order benefit sums. 739 /// The primary sum is the number of occurrences of this predicate among all 740 /// of the patterns. 741 unsigned primary = 0; 742 /// The secondary sum is a squared summation of the primary sum of all of the 743 /// predicates within each pattern that contains this predicate. This allows 744 /// for favoring predicates that are more commonly shared within a pattern, as 745 /// opposed to those shared across patterns. 746 unsigned secondary = 0; 747 748 /// The tie breaking ID, used to preserve a deterministic (insertion) order 749 /// among all the predicates with the same priority, depth, and position / 750 /// predicate dependency. 751 unsigned id = 0; 752 753 /// A map between a pattern operation and the answer to the predicate question 754 /// within that pattern. 755 DenseMap<Operation *, Qualifier *> patternToAnswer; 756 757 /// Returns true if this predicate is ordered before `rhs`, based on the cost 758 /// model. 759 bool operator<(const OrderedPredicate &rhs) const { 760 // Sort by: 761 // * higher first and secondary order sums 762 // * lower depth 763 // * lower position dependency 764 // * lower predicate dependency 765 // * lower tie breaking ID 766 auto *rhsPos = rhs.position; 767 return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(), 768 rhsPos->getKind(), rhs.question->getKind(), rhs.id) > 769 std::make_tuple(rhs.primary, rhs.secondary, 770 position->getOperationDepth(), position->getKind(), 771 question->getKind(), id); 772 } 773 }; 774 775 /// A DenseMapInfo for OrderedPredicate based solely on the position and 776 /// question. 777 struct OrderedPredicateDenseInfo { 778 using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>; 779 780 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); } 781 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); } 782 static bool isEqual(const OrderedPredicate &lhs, 783 const OrderedPredicate &rhs) { 784 return lhs.position == rhs.position && lhs.question == rhs.question; 785 } 786 static unsigned getHashValue(const OrderedPredicate &p) { 787 return llvm::hash_combine(p.position, p.question); 788 } 789 }; 790 791 /// This class wraps a set of ordered predicates that are used within a specific 792 /// pattern operation. 793 struct OrderedPredicateList { 794 OrderedPredicateList(pdl::PatternOp pattern, Value root) 795 : pattern(pattern), root(root) {} 796 797 pdl::PatternOp pattern; 798 Value root; 799 DenseSet<OrderedPredicate *> predicates; 800 }; 801 } // namespace 802 803 /// Returns true if the given matcher refers to the same predicate as the given 804 /// ordered predicate. This means that the position and questions of the two 805 /// match. 806 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { 807 return node->getPosition() == predicate->position && 808 node->getQuestion() == predicate->question; 809 } 810 811 /// Get or insert a child matcher for the given parent switch node, given a 812 /// predicate and parent pattern. 813 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, 814 OrderedPredicate *predicate, 815 pdl::PatternOp pattern) { 816 assert(isSamePredicate(node, predicate) && 817 "expected matcher to equal the given predicate"); 818 819 auto it = predicate->patternToAnswer.find(pattern); 820 assert(it != predicate->patternToAnswer.end() && 821 "expected pattern to exist in predicate"); 822 return node->getChildren().insert({it->second, nullptr}).first->second; 823 } 824 825 /// Build the matcher CFG by "pushing" patterns through by sorted predicate 826 /// order. A pattern will traverse as far as possible using common predicates 827 /// and then either diverge from the CFG or reach the end of a branch and start 828 /// creating new nodes. 829 static void propagatePattern(std::unique_ptr<MatcherNode> &node, 830 OrderedPredicateList &list, 831 std::vector<OrderedPredicate *>::iterator current, 832 std::vector<OrderedPredicate *>::iterator end) { 833 if (current == end) { 834 // We've hit the end of a pattern, so create a successful result node. 835 node = 836 std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node)); 837 838 // If the pattern doesn't contain this predicate, ignore it. 839 } else if (!list.predicates.contains(*current)) { 840 propagatePattern(node, list, std::next(current), end); 841 842 // If the current matcher node is invalid, create a new one for this 843 // position and continue propagation. 844 } else if (!node) { 845 // Create a new node at this position and continue 846 node = std::make_unique<SwitchNode>((*current)->position, 847 (*current)->question); 848 propagatePattern( 849 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 850 list, std::next(current), end); 851 852 // If the matcher has already been created, and it is for this predicate we 853 // continue propagation to the child. 854 } else if (isSamePredicate(node.get(), *current)) { 855 propagatePattern( 856 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 857 list, std::next(current), end); 858 859 // If the matcher doesn't match the current predicate, insert a branch as 860 // the common set of matchers has diverged. 861 } else { 862 propagatePattern(node->getFailureNode(), list, current, end); 863 } 864 } 865 866 /// Fold any switch nodes nested under `node` to boolean nodes when possible. 867 /// `node` is updated in-place if it is a switch. 868 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) { 869 if (!node) 870 return; 871 872 if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) { 873 SwitchNode::ChildMapT &children = switchNode->getChildren(); 874 for (auto &it : children) 875 foldSwitchToBool(it.second); 876 877 // If the node only contains one child, collapse it into a boolean predicate 878 // node. 879 if (children.size() == 1) { 880 auto *childIt = children.begin(); 881 node = std::make_unique<BoolNode>( 882 node->getPosition(), node->getQuestion(), childIt->first, 883 std::move(childIt->second), std::move(node->getFailureNode())); 884 } 885 } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) { 886 foldSwitchToBool(boolNode->getSuccessNode()); 887 } 888 889 foldSwitchToBool(node->getFailureNode()); 890 } 891 892 /// Insert an exit node at the end of the failure path of the `root`. 893 static void insertExitNode(std::unique_ptr<MatcherNode> *root) { 894 while (*root) 895 root = &(*root)->getFailureNode(); 896 *root = std::make_unique<ExitNode>(); 897 } 898 899 /// Sorts the range begin/end with the partial order given by cmp. 900 template <typename Iterator, typename Compare> 901 static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) { 902 while (begin != end) { 903 // Cannot compute sortBeforeOthers in the predicate of stable_partition 904 // because stable_partition will not keep the [begin, end) range intact 905 // while it runs. 906 llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers; 907 for (auto i = begin; i != end; ++i) { 908 if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); })) 909 sortBeforeOthers.insert(*i); 910 } 911 912 auto const next = std::stable_partition(begin, end, [&](auto const &a) { 913 return sortBeforeOthers.contains(a); 914 }); 915 assert(next != begin && "not a partial ordering"); 916 begin = next; 917 } 918 } 919 920 /// Returns true if 'b' depends on a result of 'a'. 921 static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) { 922 auto *cqa = dyn_cast<ConstraintQuestion>(a->question); 923 if (!cqa) 924 return false; 925 926 auto positionDependsOnA = [&](Position *p) { 927 auto *cp = dyn_cast<ConstraintPosition>(p); 928 return cp && cp->getQuestion() == cqa; 929 }; 930 931 if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) { 932 // Does any argument of b use a? 933 return llvm::any_of(cqb->getArgs(), positionDependsOnA); 934 } 935 if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) { 936 return positionDependsOnA(b->position) || 937 positionDependsOnA(equalTo->getValue()); 938 } 939 return positionDependsOnA(b->position); 940 } 941 942 /// Given a module containing PDL pattern operations, generate a matcher tree 943 /// using the patterns within the given module and return the root matcher node. 944 std::unique_ptr<MatcherNode> 945 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 946 DenseMap<Value, Position *> &valueToPosition) { 947 // The set of predicates contained within the pattern operations of the 948 // module. 949 struct PatternPredicates { 950 PatternPredicates(pdl::PatternOp pattern, Value root, 951 std::vector<PositionalPredicate> predicates) 952 : pattern(pattern), root(root), predicates(std::move(predicates)) {} 953 954 /// A pattern. 955 pdl::PatternOp pattern; 956 957 /// A root of the pattern chosen among the candidate roots in pdl.rewrite. 958 Value root; 959 960 /// The extracted predicates for this pattern and root. 961 std::vector<PositionalPredicate> predicates; 962 }; 963 964 SmallVector<PatternPredicates, 16> patternsAndPredicates; 965 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { 966 std::vector<PositionalPredicate> predicateList; 967 Value root = 968 buildPredicateList(pattern, builder, predicateList, valueToPosition); 969 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList)); 970 } 971 972 // Associate a pattern result with each unique predicate. 973 DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued; 974 for (auto &patternAndPredList : patternsAndPredicates) { 975 for (auto &predicate : patternAndPredList.predicates) { 976 auto it = uniqued.insert(predicate); 977 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern, 978 predicate.answer); 979 // Mark the insertion order (0-based indexing). 980 if (it.second) 981 it.first->id = uniqued.size() - 1; 982 } 983 } 984 985 // Associate each pattern to a set of its ordered predicates for later lookup. 986 std::vector<OrderedPredicateList> lists; 987 lists.reserve(patternsAndPredicates.size()); 988 for (auto &patternAndPredList : patternsAndPredicates) { 989 OrderedPredicateList list(patternAndPredList.pattern, 990 patternAndPredList.root); 991 for (auto &predicate : patternAndPredList.predicates) { 992 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); 993 list.predicates.insert(orderedPredicate); 994 995 // Increment the primary sum for each reference to a particular predicate. 996 ++orderedPredicate->primary; 997 } 998 lists.push_back(std::move(list)); 999 } 1000 1001 // For a particular pattern, get the total primary sum and add it to the 1002 // secondary sum of each predicate. Square the primary sums to emphasize 1003 // shared predicates within rather than across patterns. 1004 for (auto &list : lists) { 1005 unsigned total = 0; 1006 for (auto *predicate : list.predicates) 1007 total += predicate->primary * predicate->primary; 1008 for (auto *predicate : list.predicates) 1009 predicate->secondary += total; 1010 } 1011 1012 // Sort the set of predicates now that the cost primary and secondary sums 1013 // have been computed. 1014 std::vector<OrderedPredicate *> ordered; 1015 ordered.reserve(uniqued.size()); 1016 for (auto &ip : uniqued) 1017 ordered.push_back(&ip); 1018 llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) { 1019 return *lhs < *rhs; 1020 }); 1021 1022 // Mostly keep the now established order, but also ensure that 1023 // ConstraintQuestions come after the results they use. 1024 stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn); 1025 1026 // Build the matchers for each of the pattern predicate lists. 1027 std::unique_ptr<MatcherNode> root; 1028 for (OrderedPredicateList &list : lists) 1029 propagatePattern(root, list, ordered.begin(), ordered.end()); 1030 1031 // Collapse the graph and insert the exit node. 1032 foldSwitchToBool(root); 1033 insertExitNode(&root); 1034 return root; 1035 } 1036 1037 //===----------------------------------------------------------------------===// 1038 // MatcherNode 1039 //===----------------------------------------------------------------------===// 1040 1041 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q, 1042 std::unique_ptr<MatcherNode> failureNode) 1043 : position(p), question(q), failureNode(std::move(failureNode)), 1044 matcherTypeID(matcherTypeID) {} 1045 1046 //===----------------------------------------------------------------------===// 1047 // BoolNode 1048 //===----------------------------------------------------------------------===// 1049 1050 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer, 1051 std::unique_ptr<MatcherNode> successNode, 1052 std::unique_ptr<MatcherNode> failureNode) 1053 : MatcherNode(TypeID::get<BoolNode>(), position, question, 1054 std::move(failureNode)), 1055 answer(answer), successNode(std::move(successNode)) {} 1056 1057 //===----------------------------------------------------------------------===// 1058 // SuccessNode 1059 //===----------------------------------------------------------------------===// 1060 1061 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root, 1062 std::unique_ptr<MatcherNode> failureNode) 1063 : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr, 1064 /*question=*/nullptr, std::move(failureNode)), 1065 pattern(pattern), root(root) {} 1066 1067 //===----------------------------------------------------------------------===// 1068 // SwitchNode 1069 //===----------------------------------------------------------------------===// 1070 1071 SwitchNode::SwitchNode(Position *position, Qualifier *question) 1072 : MatcherNode(TypeID::get<SwitchNode>(), position, question) {} 1073