1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace mlir { 24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP 25 #include "mlir/Conversion/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 using namespace mlir::pdl_to_pdl_interp; 30 31 //===----------------------------------------------------------------------===// 32 // PatternLowering 33 //===----------------------------------------------------------------------===// 34 35 namespace { 36 /// This class generators operations within the PDL Interpreter dialect from a 37 /// given module containing PDL pattern operations. 38 struct PatternLowering { 39 public: 40 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule, 41 DenseMap<Operation *, PDLPatternConfigSet *> *configMap); 42 43 /// Generate code for matching and rewriting based on the pattern operations 44 /// within the module. 45 void lower(ModuleOp module); 46 47 private: 48 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 49 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 50 51 /// Generate interpreter operations for the tree rooted at the given matcher 52 /// node, in the specified region. 53 Block *generateMatcher(MatcherNode &node, Region ®ion); 54 55 /// Get or create an access to the provided positional value in the current 56 /// block. This operation may mutate the provided block pointer if nested 57 /// regions (i.e., pdl_interp.iterate) are required. 58 Value getValueAt(Block *¤tBlock, Position *pos); 59 60 /// Create the interpreter predicate operations. This operation may mutate the 61 /// provided current block pointer if nested regions (iterates) are required. 62 void generate(BoolNode *boolNode, Block *¤tBlock, Value val); 63 64 /// Create the interpreter switch / predicate operations, with several case 65 /// destinations. This operation never mutates the provided current block 66 /// pointer, because the switch operation does not need Values beyond `val`. 67 void generate(SwitchNode *switchNode, Block *currentBlock, Value val); 68 69 /// Create the interpreter operations to record a successful pattern match 70 /// using the contained root operation. This operation may mutate the current 71 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. 72 void generate(SuccessNode *successNode, Block *¤tBlock); 73 74 /// Generate a rewriter function for the given pattern operation, and returns 75 /// a reference to that function. 76 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 77 SmallVectorImpl<Position *> &usedMatchValues); 78 79 /// Generate the rewriter code for the given operation. 80 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 81 DenseMap<Value, Value> &rewriteValues, 82 function_ref<Value(Value)> mapRewriteValue); 83 void generateRewriter(pdl::AttributeOp attrOp, 84 DenseMap<Value, Value> &rewriteValues, 85 function_ref<Value(Value)> mapRewriteValue); 86 void generateRewriter(pdl::EraseOp eraseOp, 87 DenseMap<Value, Value> &rewriteValues, 88 function_ref<Value(Value)> mapRewriteValue); 89 void generateRewriter(pdl::OperationOp operationOp, 90 DenseMap<Value, Value> &rewriteValues, 91 function_ref<Value(Value)> mapRewriteValue); 92 void generateRewriter(pdl::RangeOp rangeOp, 93 DenseMap<Value, Value> &rewriteValues, 94 function_ref<Value(Value)> mapRewriteValue); 95 void generateRewriter(pdl::ReplaceOp replaceOp, 96 DenseMap<Value, Value> &rewriteValues, 97 function_ref<Value(Value)> mapRewriteValue); 98 void generateRewriter(pdl::ResultOp resultOp, 99 DenseMap<Value, Value> &rewriteValues, 100 function_ref<Value(Value)> mapRewriteValue); 101 void generateRewriter(pdl::ResultsOp resultOp, 102 DenseMap<Value, Value> &rewriteValues, 103 function_ref<Value(Value)> mapRewriteValue); 104 void generateRewriter(pdl::TypeOp typeOp, 105 DenseMap<Value, Value> &rewriteValues, 106 function_ref<Value(Value)> mapRewriteValue); 107 void generateRewriter(pdl::TypesOp typeOp, 108 DenseMap<Value, Value> &rewriteValues, 109 function_ref<Value(Value)> mapRewriteValue); 110 111 /// Generate the values used for resolving the result types of an operation 112 /// created within a dag rewriter region. If the result types of the operation 113 /// should be inferred, `hasInferredResultTypes` is set to true. 114 void generateOperationResultTypeRewriter( 115 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, 116 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, 117 bool &hasInferredResultTypes); 118 119 /// A builder to use when generating interpreter operations. 120 OpBuilder builder; 121 122 /// The matcher function used for all match related logic within PDL patterns. 123 pdl_interp::FuncOp matcherFunc; 124 125 /// The rewriter module containing the all rewrite related logic within PDL 126 /// patterns. 127 ModuleOp rewriterModule; 128 129 /// The symbol table of the rewriter module used for insertion. 130 SymbolTable rewriterSymbolTable; 131 132 /// A scoped map connecting a position with the corresponding interpreter 133 /// value. 134 ValueMap values; 135 136 /// A stack of blocks used as the failure destination for matcher nodes that 137 /// don't have an explicit failure path. 138 SmallVector<Block *, 8> failureBlockStack; 139 140 /// A mapping between values defined in a pattern match, and the corresponding 141 /// positional value. 142 DenseMap<Value, Position *> valueToPosition; 143 144 /// The set of operation values whose whose location will be used for newly 145 /// generated operations. 146 SetVector<Value> locOps; 147 148 /// A mapping between pattern operations and the corresponding configuration 149 /// set. 150 DenseMap<Operation *, PDLPatternConfigSet *> *configMap; 151 }; 152 } // namespace 153 154 PatternLowering::PatternLowering( 155 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule, 156 DenseMap<Operation *, PDLPatternConfigSet *> *configMap) 157 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 158 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule), 159 configMap(configMap) {} 160 161 void PatternLowering::lower(ModuleOp module) { 162 PredicateUniquer predicateUniquer; 163 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 164 165 // Define top-level scope for the arguments to the matcher function. 166 ValueMapScope topLevelValueScope(values); 167 168 // Insert the root operation, i.e. argument to the matcher, at the root 169 // position. 170 Block *matcherEntryBlock = &matcherFunc.front(); 171 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 172 173 // Generate a root matcher node from the provided PDL module. 174 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 175 module, predicateBuilder, valueToPosition); 176 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); 177 assert(failureBlockStack.empty() && "failed to empty the stack"); 178 179 // After generation, merged the first matched block into the entry. 180 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 181 firstMatcherBlock->getOperations()); 182 firstMatcherBlock->erase(); 183 } 184 185 Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { 186 // Push a new scope for the values used by this matcher. 187 Block *block = ®ion.emplaceBlock(); 188 ValueMapScope scope(values); 189 190 // If this is the return node, simply insert the corresponding interpreter 191 // finalize. 192 if (isa<ExitNode>(node)) { 193 builder.setInsertionPointToEnd(block); 194 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 195 return block; 196 } 197 198 // Get the next block in the match sequence. 199 // This is intentionally executed first, before we get the value for the 200 // position associated with the node, so that we preserve an "there exist" 201 // semantics: if getting a value requires an upward traversal (going from a 202 // value to its consumers), we want to perform the check on all the consumers 203 // before we pass control to the failure node. 204 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 205 Block *failureBlock; 206 if (failureNode) { 207 failureBlock = generateMatcher(*failureNode, region); 208 failureBlockStack.push_back(failureBlock); 209 } else { 210 assert(!failureBlockStack.empty() && "expected valid failure block"); 211 failureBlock = failureBlockStack.back(); 212 } 213 214 // If this node contains a position, get the corresponding value for this 215 // block. 216 Block *currentBlock = block; 217 Position *position = node.getPosition(); 218 Value val = position ? getValueAt(currentBlock, position) : Value(); 219 220 // If this value corresponds to an operation, record that we are going to use 221 // its location as part of a fused location. 222 bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); 223 if (isOperationValue) 224 locOps.insert(val); 225 226 // Dispatch to the correct method based on derived node type. 227 TypeSwitch<MatcherNode *>(&node) 228 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) { 229 this->generate(derivedNode, currentBlock, val); 230 }) 231 .Case([&](SuccessNode *successNode) { 232 generate(successNode, currentBlock); 233 }); 234 235 // Pop all the failure blocks that were inserted due to nesting of 236 // pdl_interp.iterate. 237 while (failureBlockStack.back() != failureBlock) { 238 failureBlockStack.pop_back(); 239 assert(!failureBlockStack.empty() && "unable to locate failure block"); 240 } 241 242 // Pop the new failure block. 243 if (failureNode) 244 failureBlockStack.pop_back(); 245 246 if (isOperationValue) 247 locOps.remove(val); 248 249 return block; 250 } 251 252 Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { 253 if (Value val = values.lookup(pos)) 254 return val; 255 256 // Get the value for the parent position. 257 Value parentVal; 258 if (Position *parent = pos->getParent()) 259 parentVal = getValueAt(currentBlock, parent); 260 261 // TODO: Use a location from the position. 262 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc(); 263 builder.setInsertionPointToEnd(currentBlock); 264 Value value; 265 switch (pos->getKind()) { 266 case Predicates::OperationPos: { 267 auto *operationPos = cast<OperationPosition>(pos); 268 if (operationPos->isOperandDefiningOp()) 269 // Standard (downward) traversal which directly follows the defining op. 270 value = builder.create<pdl_interp::GetDefiningOpOp>( 271 loc, builder.getType<pdl::OperationType>(), parentVal); 272 else 273 // A passthrough operation position. 274 value = parentVal; 275 break; 276 } 277 case Predicates::UsersPos: { 278 auto *usersPos = cast<UsersPosition>(pos); 279 280 // The first operation retrieves the representative value of a range. 281 // This applies only when the parent is a range of values and we were 282 // requested to use a representative value (e.g., upward traversal). 283 if (parentVal.getType().isa<pdl::RangeType>() && 284 usersPos->useRepresentative()) 285 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); 286 else 287 value = parentVal; 288 289 // The second operation retrieves the users. 290 value = builder.create<pdl_interp::GetUsersOp>(loc, value); 291 break; 292 } 293 case Predicates::ForEachPos: { 294 assert(!failureBlockStack.empty() && "expected valid failure block"); 295 auto foreach = builder.create<pdl_interp::ForEachOp>( 296 loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); 297 value = foreach.getLoopVariable(); 298 299 // Create the continuation block. 300 Block *continueBlock = builder.createBlock(&foreach.getRegion()); 301 builder.create<pdl_interp::ContinueOp>(loc); 302 failureBlockStack.push_back(continueBlock); 303 304 currentBlock = &foreach.getRegion().front(); 305 break; 306 } 307 case Predicates::OperandPos: { 308 auto *operandPos = cast<OperandPosition>(pos); 309 value = builder.create<pdl_interp::GetOperandOp>( 310 loc, builder.getType<pdl::ValueType>(), parentVal, 311 operandPos->getOperandNumber()); 312 break; 313 } 314 case Predicates::OperandGroupPos: { 315 auto *operandPos = cast<OperandGroupPosition>(pos); 316 Type valueTy = builder.getType<pdl::ValueType>(); 317 value = builder.create<pdl_interp::GetOperandsOp>( 318 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 319 parentVal, operandPos->getOperandGroupNumber()); 320 break; 321 } 322 case Predicates::AttributePos: { 323 auto *attrPos = cast<AttributePosition>(pos); 324 value = builder.create<pdl_interp::GetAttributeOp>( 325 loc, builder.getType<pdl::AttributeType>(), parentVal, 326 attrPos->getName().strref()); 327 break; 328 } 329 case Predicates::TypePos: { 330 if (parentVal.getType().isa<pdl::AttributeType>()) 331 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 332 else 333 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 334 break; 335 } 336 case Predicates::ResultPos: { 337 auto *resPos = cast<ResultPosition>(pos); 338 value = builder.create<pdl_interp::GetResultOp>( 339 loc, builder.getType<pdl::ValueType>(), parentVal, 340 resPos->getResultNumber()); 341 break; 342 } 343 case Predicates::ResultGroupPos: { 344 auto *resPos = cast<ResultGroupPosition>(pos); 345 Type valueTy = builder.getType<pdl::ValueType>(); 346 value = builder.create<pdl_interp::GetResultsOp>( 347 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 348 parentVal, resPos->getResultGroupNumber()); 349 break; 350 } 351 case Predicates::AttributeLiteralPos: { 352 auto *attrPos = cast<AttributeLiteralPosition>(pos); 353 value = 354 builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue()); 355 break; 356 } 357 case Predicates::TypeLiteralPos: { 358 auto *typePos = cast<TypeLiteralPosition>(pos); 359 Attribute rawTypeAttr = typePos->getValue(); 360 if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>()) 361 value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr); 362 else 363 value = builder.create<pdl_interp::CreateTypesOp>( 364 loc, rawTypeAttr.cast<ArrayAttr>()); 365 break; 366 } 367 default: 368 llvm_unreachable("Generating unknown Position getter"); 369 break; 370 } 371 372 values.insert(pos, value); 373 return value; 374 } 375 376 void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, 377 Value val) { 378 Location loc = val.getLoc(); 379 Qualifier *question = boolNode->getQuestion(); 380 Qualifier *answer = boolNode->getAnswer(); 381 Region *region = currentBlock->getParent(); 382 383 // Execute the getValue queries first, so that we create success 384 // matcher in the correct (possibly nested) region. 385 SmallVector<Value> args; 386 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) { 387 args = {getValueAt(currentBlock, equalToQuestion->getValue())}; 388 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) { 389 for (Position *position : cstQuestion->getArgs()) 390 args.push_back(getValueAt(currentBlock, position)); 391 } 392 393 // Generate the matcher in the current (potentially nested) region 394 // and get the failure successor. 395 Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); 396 Block *failure = failureBlockStack.back(); 397 398 // Finally, create the predicate. 399 builder.setInsertionPointToEnd(currentBlock); 400 Predicates::Kind kind = question->getKind(); 401 switch (kind) { 402 case Predicates::IsNotNullQuestion: 403 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); 404 break; 405 case Predicates::OperationNameQuestion: { 406 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 407 builder.create<pdl_interp::CheckOperationNameOp>( 408 loc, val, opNameAnswer->getValue().getStringRef(), success, failure); 409 break; 410 } 411 case Predicates::TypeQuestion: { 412 auto *ans = cast<TypeAnswer>(answer); 413 if (val.getType().isa<pdl::RangeType>()) 414 builder.create<pdl_interp::CheckTypesOp>( 415 loc, val, ans->getValue().cast<ArrayAttr>(), success, failure); 416 else 417 builder.create<pdl_interp::CheckTypeOp>( 418 loc, val, ans->getValue().cast<TypeAttr>(), success, failure); 419 break; 420 } 421 case Predicates::AttributeQuestion: { 422 auto *ans = cast<AttributeAnswer>(answer); 423 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 424 success, failure); 425 break; 426 } 427 case Predicates::OperandCountAtLeastQuestion: 428 case Predicates::OperandCountQuestion: 429 builder.create<pdl_interp::CheckOperandCountOp>( 430 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 431 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 432 success, failure); 433 break; 434 case Predicates::ResultCountAtLeastQuestion: 435 case Predicates::ResultCountQuestion: 436 builder.create<pdl_interp::CheckResultCountOp>( 437 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 438 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 439 success, failure); 440 break; 441 case Predicates::EqualToQuestion: { 442 bool trueAnswer = isa<TrueAnswer>(answer); 443 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), 444 trueAnswer ? success : failure, 445 trueAnswer ? failure : success); 446 break; 447 } 448 case Predicates::ConstraintQuestion: { 449 auto *cstQuestion = cast<ConstraintQuestion>(question); 450 builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(), 451 args, success, failure); 452 break; 453 } 454 default: 455 llvm_unreachable("Generating unknown Predicate operation"); 456 } 457 } 458 459 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 460 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 461 llvm::MapVector<Qualifier *, Block *> &dests) { 462 std::vector<ValT> values; 463 std::vector<Block *> blocks; 464 values.reserve(dests.size()); 465 blocks.reserve(dests.size()); 466 for (const auto &it : dests) { 467 blocks.push_back(it.second); 468 values.push_back(cast<PredT>(it.first)->getValue()); 469 } 470 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 471 } 472 473 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, 474 Value val) { 475 Qualifier *question = switchNode->getQuestion(); 476 Region *region = currentBlock->getParent(); 477 Block *defaultDest = failureBlockStack.back(); 478 479 // If the switch question is not an exact answer, i.e. for the `at_least` 480 // cases, we generate a special block sequence. 481 Predicates::Kind kind = question->getKind(); 482 if (kind == Predicates::OperandCountAtLeastQuestion || 483 kind == Predicates::ResultCountAtLeastQuestion) { 484 // Order the children such that the cases are in reverse numerical order. 485 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( 486 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 487 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 488 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 489 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 490 }); 491 492 // Build the destination for each child using the next highest child as a 493 // a failure destination. This essentially creates the following control 494 // flow: 495 // 496 // if (operand_count < 1) 497 // goto failure 498 // if (child1.match()) 499 // ... 500 // 501 // if (operand_count < 2) 502 // goto failure 503 // if (child2.match()) 504 // ... 505 // 506 // failure: 507 // ... 508 // 509 failureBlockStack.push_back(defaultDest); 510 Location loc = val.getLoc(); 511 for (unsigned idx : sortedChildren) { 512 auto &child = switchNode->getChild(idx); 513 Block *childBlock = generateMatcher(*child.second, *region); 514 Block *predicateBlock = builder.createBlock(childBlock); 515 builder.setInsertionPointToEnd(predicateBlock); 516 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue(); 517 switch (kind) { 518 case Predicates::OperandCountAtLeastQuestion: 519 builder.create<pdl_interp::CheckOperandCountOp>( 520 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 521 break; 522 case Predicates::ResultCountAtLeastQuestion: 523 builder.create<pdl_interp::CheckResultCountOp>( 524 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 525 break; 526 default: 527 llvm_unreachable("Generating invalid AtLeast operation"); 528 } 529 failureBlockStack.back() = predicateBlock; 530 } 531 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 532 currentBlock->getOperations().splice(currentBlock->end(), 533 firstPredicateBlock->getOperations()); 534 firstPredicateBlock->erase(); 535 return; 536 } 537 538 // Otherwise, generate each of the children and generate an interpreter 539 // switch. 540 llvm::MapVector<Qualifier *, Block *> children; 541 for (auto &it : switchNode->getChildren()) 542 children.insert({it.first, generateMatcher(*it.second, *region)}); 543 builder.setInsertionPointToEnd(currentBlock); 544 545 switch (question->getKind()) { 546 case Predicates::OperandCountQuestion: 547 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 548 int32_t>(val, defaultDest, builder, children); 549 case Predicates::ResultCountQuestion: 550 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 551 int32_t>(val, defaultDest, builder, children); 552 case Predicates::OperationNameQuestion: 553 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 554 OperationNameAnswer>(val, defaultDest, builder, 555 children); 556 case Predicates::TypeQuestion: 557 if (val.getType().isa<pdl::RangeType>()) { 558 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 559 val, defaultDest, builder, children); 560 } 561 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 562 val, defaultDest, builder, children); 563 case Predicates::AttributeQuestion: 564 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 565 val, defaultDest, builder, children); 566 default: 567 llvm_unreachable("Generating unknown switch predicate."); 568 } 569 } 570 571 void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { 572 pdl::PatternOp pattern = successNode->getPattern(); 573 Value root = successNode->getRoot(); 574 575 // Generate a rewriter for the pattern this success node represents, and track 576 // any values used from the match region. 577 SmallVector<Position *, 8> usedMatchValues; 578 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 579 580 // Process any values used in the rewrite that are defined in the match. 581 std::vector<Value> mappedMatchValues; 582 mappedMatchValues.reserve(usedMatchValues.size()); 583 for (Position *position : usedMatchValues) 584 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 585 586 // Collect the set of operations generated by the rewriter. 587 SmallVector<StringRef, 4> generatedOps; 588 for (auto op : 589 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>()) 590 generatedOps.push_back(*op.getOpName()); 591 ArrayAttr generatedOpsAttr; 592 if (!generatedOps.empty()) 593 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 594 595 // Grab the root kind if present. 596 StringAttr rootKindAttr; 597 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>()) 598 if (Optional<StringRef> rootKind = rootOp.getOpName()) 599 rootKindAttr = builder.getStringAttr(*rootKind); 600 601 builder.setInsertionPointToEnd(currentBlock); 602 auto matchOp = builder.create<pdl_interp::RecordMatchOp>( 603 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 604 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(), 605 failureBlockStack.back()); 606 607 // Set the config of the lowered match to the parent pattern. 608 if (configMap) 609 configMap->try_emplace(matchOp, configMap->lookup(pattern)); 610 } 611 612 SymbolRefAttr PatternLowering::generateRewriter( 613 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 614 builder.setInsertionPointToEnd(rewriterModule.getBody()); 615 auto rewriterFunc = builder.create<pdl_interp::FuncOp>( 616 pattern.getLoc(), "pdl_generated_rewriter", 617 builder.getFunctionType(std::nullopt, std::nullopt)); 618 rewriterSymbolTable.insert(rewriterFunc); 619 620 // Generate the rewriter function body. 621 builder.setInsertionPointToEnd(&rewriterFunc.front()); 622 623 // Map an input operand of the pattern to a generated interpreter value. 624 DenseMap<Value, Value> rewriteValues; 625 auto mapRewriteValue = [&](Value oldValue) { 626 Value &newValue = rewriteValues[oldValue]; 627 if (newValue) 628 return newValue; 629 630 // Prefer materializing constants directly when possible. 631 Operation *oldOp = oldValue.getDefiningOp(); 632 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 633 if (Attribute value = attrOp.getValueAttr()) { 634 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 635 attrOp.getLoc(), value); 636 } 637 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 638 if (TypeAttr type = typeOp.getConstantTypeAttr()) { 639 return newValue = builder.create<pdl_interp::CreateTypeOp>( 640 typeOp.getLoc(), type); 641 } 642 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 643 if (ArrayAttr type = typeOp.getConstantTypesAttr()) { 644 return newValue = builder.create<pdl_interp::CreateTypesOp>( 645 typeOp.getLoc(), typeOp.getType(), type); 646 } 647 } 648 649 // Otherwise, add this as an input to the rewriter. 650 Position *inputPos = valueToPosition.lookup(oldValue); 651 assert(inputPos && "expected value to be a pattern input"); 652 usedMatchValues.push_back(inputPos); 653 return newValue = rewriterFunc.front().addArgument(oldValue.getType(), 654 oldValue.getLoc()); 655 }; 656 657 // If this is a custom rewriter, simply dispatch to the registered rewrite 658 // method. 659 pdl::RewriteOp rewriter = pattern.getRewriter(); 660 if (StringAttr rewriteName = rewriter.getNameAttr()) { 661 SmallVector<Value> args; 662 if (rewriter.getRoot()) 663 args.push_back(mapRewriteValue(rewriter.getRoot())); 664 auto mappedArgs = 665 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue); 666 args.append(mappedArgs.begin(), mappedArgs.end()); 667 builder.create<pdl_interp::ApplyRewriteOp>( 668 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); 669 } else { 670 // Otherwise this is a dag rewriter defined using PDL operations. 671 for (Operation &rewriteOp : *rewriter.getBody()) { 672 llvm::TypeSwitch<Operation *>(&rewriteOp) 673 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 674 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp, 675 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) { 676 this->generateRewriter(op, rewriteValues, mapRewriteValue); 677 }); 678 } 679 } 680 681 // Update the signature of the rewrite function. 682 rewriterFunc.setType(builder.getFunctionType( 683 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 684 /*results=*/std::nullopt)); 685 686 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 687 return SymbolRefAttr::get( 688 builder.getContext(), 689 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 690 SymbolRefAttr::get(rewriterFunc)); 691 } 692 693 void PatternLowering::generateRewriter( 694 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 695 function_ref<Value(Value)> mapRewriteValue) { 696 SmallVector<Value, 2> arguments; 697 for (Value argument : rewriteOp.getArgs()) 698 arguments.push_back(mapRewriteValue(argument)); 699 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 700 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(), 701 arguments); 702 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) 703 rewriteValues[std::get<0>(it)] = std::get<1>(it); 704 } 705 706 void PatternLowering::generateRewriter( 707 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 708 function_ref<Value(Value)> mapRewriteValue) { 709 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 710 attrOp.getLoc(), attrOp.getValueAttr()); 711 rewriteValues[attrOp] = newAttr; 712 } 713 714 void PatternLowering::generateRewriter( 715 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 716 function_ref<Value(Value)> mapRewriteValue) { 717 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 718 mapRewriteValue(eraseOp.getOpValue())); 719 } 720 721 void PatternLowering::generateRewriter( 722 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 723 function_ref<Value(Value)> mapRewriteValue) { 724 SmallVector<Value, 4> operands; 725 for (Value operand : operationOp.getOperandValues()) 726 operands.push_back(mapRewriteValue(operand)); 727 728 SmallVector<Value, 4> attributes; 729 for (Value attr : operationOp.getAttributeValues()) 730 attributes.push_back(mapRewriteValue(attr)); 731 732 bool hasInferredResultTypes = false; 733 SmallVector<Value, 2> types; 734 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types, 735 rewriteValues, hasInferredResultTypes); 736 737 // Create the new operation. 738 Location loc = operationOp.getLoc(); 739 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 740 loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands, 741 attributes, operationOp.getAttributeValueNames()); 742 rewriteValues[operationOp.getOp()] = createdOp; 743 744 // Generate accesses for any results that have their types constrained. 745 // Handle the case where there is a single range representing all of the 746 // result types. 747 OperandRange resultTys = operationOp.getTypeValues(); 748 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { 749 Value &type = rewriteValues[resultTys[0]]; 750 if (!type) { 751 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 752 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 753 } 754 return; 755 } 756 757 // Otherwise, populate the individual results. 758 bool seenVariableLength = false; 759 Type valueTy = builder.getType<pdl::ValueType>(); 760 Type valueRangeTy = pdl::RangeType::get(valueTy); 761 for (const auto &it : llvm::enumerate(resultTys)) { 762 Value &type = rewriteValues[it.value()]; 763 if (type) 764 continue; 765 bool isVariadic = it.value().getType().isa<pdl::RangeType>(); 766 seenVariableLength |= isVariadic; 767 768 // After a variable length result has been seen, we need to use result 769 // groups because the exact index of the result is not statically known. 770 Value resultVal; 771 if (seenVariableLength) 772 resultVal = builder.create<pdl_interp::GetResultsOp>( 773 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 774 else 775 resultVal = builder.create<pdl_interp::GetResultOp>( 776 loc, valueTy, createdOp, it.index()); 777 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 778 } 779 } 780 781 void PatternLowering::generateRewriter( 782 pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues, 783 function_ref<Value(Value)> mapRewriteValue) { 784 SmallVector<Value, 4> replOperands; 785 for (Value operand : rangeOp.getArguments()) 786 replOperands.push_back(mapRewriteValue(operand)); 787 rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>( 788 rangeOp.getLoc(), rangeOp.getType(), replOperands); 789 } 790 791 void PatternLowering::generateRewriter( 792 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 793 function_ref<Value(Value)> mapRewriteValue) { 794 SmallVector<Value, 4> replOperands; 795 796 // If the replacement was another operation, get its results. `pdl` allows 797 // for using an operation for simplicitly, but the interpreter isn't as 798 // user facing. 799 if (Value replOp = replaceOp.getReplOperation()) { 800 // Don't use replace if we know the replaced operation has no results. 801 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>(); 802 if (!opOp || !opOp.getTypeValues().empty()) { 803 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 804 replOp.getLoc(), mapRewriteValue(replOp))); 805 } 806 } else { 807 for (Value operand : replaceOp.getReplValues()) 808 replOperands.push_back(mapRewriteValue(operand)); 809 } 810 811 // If there are no replacement values, just create an erase instead. 812 if (replOperands.empty()) { 813 builder.create<pdl_interp::EraseOp>( 814 replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue())); 815 return; 816 } 817 818 builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(), 819 mapRewriteValue(replaceOp.getOpValue()), 820 replOperands); 821 } 822 823 void PatternLowering::generateRewriter( 824 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 825 function_ref<Value(Value)> mapRewriteValue) { 826 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 827 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 828 mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); 829 } 830 831 void PatternLowering::generateRewriter( 832 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 833 function_ref<Value(Value)> mapRewriteValue) { 834 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 835 resultOp.getLoc(), resultOp.getType(), 836 mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); 837 } 838 839 void PatternLowering::generateRewriter( 840 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 841 function_ref<Value(Value)> mapRewriteValue) { 842 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 843 // type. 844 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) { 845 rewriteValues[typeOp] = 846 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 847 } 848 } 849 850 void PatternLowering::generateRewriter( 851 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 852 function_ref<Value(Value)> mapRewriteValue) { 853 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 854 // type. 855 if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) { 856 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 857 typeOp.getLoc(), typeOp.getType(), typeAttr); 858 } 859 } 860 861 void PatternLowering::generateOperationResultTypeRewriter( 862 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, 863 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, 864 bool &hasInferredResultTypes) { 865 Block *rewriterBlock = op->getBlock(); 866 867 // Try to handle resolution for each of the result types individually. This is 868 // preferred over type inferrence because it will allow for us to use existing 869 // types directly, as opposed to trying to rebuild the type list. 870 OperandRange resultTypeValues = op.getTypeValues(); 871 auto tryResolveResultTypes = [&] { 872 types.reserve(resultTypeValues.size()); 873 for (const auto &it : llvm::enumerate(resultTypeValues)) { 874 Value resultType = it.value(); 875 876 // Check for an already translated value. 877 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 878 types.push_back(existingRewriteValue); 879 continue; 880 } 881 882 // Check for an input from the matcher. 883 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 884 types.push_back(mapRewriteValue(resultType)); 885 continue; 886 } 887 888 // Otherwise, we couldn't infer the result types. Bail out here to see if 889 // we can infer the types for this operation from another way. 890 types.clear(); 891 return failure(); 892 } 893 return success(); 894 }; 895 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes())) 896 return; 897 898 // Otherwise, check if the operation has type inference support itself. 899 if (op.hasTypeInference()) { 900 hasInferredResultTypes = true; 901 return; 902 } 903 904 // Look for an operation that was replaced by `op`. The result types will be 905 // inferred from the results that were replaced. 906 for (OpOperand &use : op.getOp().getUses()) { 907 // Check that the use corresponds to a ReplaceOp and that it is the 908 // replacement value, not the operation being replaced. 909 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 910 if (!replOpUser || use.getOperandNumber() == 0) 911 continue; 912 // Make sure the replaced operation was defined before this one. PDL 913 // rewrites only have single block regions, so if the op isn't in the 914 // rewriter block (i.e. the current block of the operation) we already know 915 // it dominates (i.e. it's in the matcher). 916 Value replOpVal = replOpUser.getOpValue(); 917 Operation *replacedOp = replOpVal.getDefiningOp(); 918 if (replacedOp->getBlock() == rewriterBlock && 919 !replacedOp->isBeforeInBlock(op)) 920 continue; 921 922 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 923 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 924 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 925 replacedOp->getLoc(), replacedOpResults)); 926 return; 927 } 928 929 // If the types could not be inferred from any context and there weren't any 930 // explicit result types, assume the user actually meant for the operation to 931 // have no results. 932 if (resultTypeValues.empty()) 933 return; 934 935 // The verifier asserts that the result types of each pdl.getOperation can be 936 // inferred. If we reach here, there is a bug either in the logic above or 937 // in the verifier for pdl.getOperation. 938 op->emitOpError() << "unable to infer result type for operation"; 939 llvm_unreachable("unable to infer result type for operation"); 940 } 941 942 //===----------------------------------------------------------------------===// 943 // Conversion Pass 944 //===----------------------------------------------------------------------===// 945 946 namespace { 947 struct PDLToPDLInterpPass 948 : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 949 PDLToPDLInterpPass() = default; 950 PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default; 951 PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap) 952 : configMap(&configMap) {} 953 void runOnOperation() final; 954 955 /// A map containing the configuration for each pattern. 956 DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr; 957 }; 958 } // namespace 959 960 /// Convert the given module containing PDL pattern operations into a PDL 961 /// Interpreter operations. 962 void PDLToPDLInterpPass::runOnOperation() { 963 ModuleOp module = getOperation(); 964 965 // Create the main matcher function This function contains all of the match 966 // related functionality from patterns in the module. 967 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 968 auto matcherFunc = builder.create<pdl_interp::FuncOp>( 969 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 970 builder.getFunctionType(builder.getType<pdl::OperationType>(), 971 /*results=*/std::nullopt), 972 /*attrs=*/std::nullopt); 973 974 // Create a nested module to hold the functions invoked for rewriting the IR 975 // after a successful match. 976 ModuleOp rewriterModule = builder.create<ModuleOp>( 977 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 978 979 // Generate the code for the patterns within the module. 980 PatternLowering generator(matcherFunc, rewriterModule, configMap); 981 generator.lower(module); 982 983 // After generation, delete all of the pattern operations. 984 for (pdl::PatternOp pattern : 985 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) { 986 // Drop the now dead config mappings. 987 if (configMap) 988 configMap->erase(pattern); 989 990 pattern.erase(); 991 } 992 } 993 994 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 995 return std::make_unique<PDLToPDLInterpPass>(); 996 } 997 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass( 998 DenseMap<Operation *, PDLPatternConfigSet *> &configMap) { 999 return std::make_unique<PDLToPDLInterpPass>(configMap); 1000 } 1001