1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace mlir { 24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP 25 #include "mlir/Conversion/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 using namespace mlir::pdl_to_pdl_interp; 30 31 //===----------------------------------------------------------------------===// 32 // PatternLowering 33 //===----------------------------------------------------------------------===// 34 35 namespace { 36 /// This class generators operations within the PDL Interpreter dialect from a 37 /// given module containing PDL pattern operations. 38 struct PatternLowering { 39 public: 40 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule); 41 42 /// Generate code for matching and rewriting based on the pattern operations 43 /// within the module. 44 void lower(ModuleOp module); 45 46 private: 47 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 48 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 49 50 /// Generate interpreter operations for the tree rooted at the given matcher 51 /// node, in the specified region. 52 Block *generateMatcher(MatcherNode &node, Region ®ion); 53 54 /// Get or create an access to the provided positional value in the current 55 /// block. This operation may mutate the provided block pointer if nested 56 /// regions (i.e., pdl_interp.iterate) are required. 57 Value getValueAt(Block *¤tBlock, Position *pos); 58 59 /// Create the interpreter predicate operations. This operation may mutate the 60 /// provided current block pointer if nested regions (iterates) are required. 61 void generate(BoolNode *boolNode, Block *¤tBlock, Value val); 62 63 /// Create the interpreter switch / predicate operations, with several case 64 /// destinations. This operation never mutates the provided current block 65 /// pointer, because the switch operation does not need Values beyond `val`. 66 void generate(SwitchNode *switchNode, Block *currentBlock, Value val); 67 68 /// Create the interpreter operations to record a successful pattern match 69 /// using the contained root operation. This operation may mutate the current 70 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. 71 void generate(SuccessNode *successNode, Block *¤tBlock); 72 73 /// Generate a rewriter function for the given pattern operation, and returns 74 /// a reference to that function. 75 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 76 SmallVectorImpl<Position *> &usedMatchValues); 77 78 /// Generate the rewriter code for the given operation. 79 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 80 DenseMap<Value, Value> &rewriteValues, 81 function_ref<Value(Value)> mapRewriteValue); 82 void generateRewriter(pdl::AttributeOp attrOp, 83 DenseMap<Value, Value> &rewriteValues, 84 function_ref<Value(Value)> mapRewriteValue); 85 void generateRewriter(pdl::EraseOp eraseOp, 86 DenseMap<Value, Value> &rewriteValues, 87 function_ref<Value(Value)> mapRewriteValue); 88 void generateRewriter(pdl::OperationOp operationOp, 89 DenseMap<Value, Value> &rewriteValues, 90 function_ref<Value(Value)> mapRewriteValue); 91 void generateRewriter(pdl::ReplaceOp replaceOp, 92 DenseMap<Value, Value> &rewriteValues, 93 function_ref<Value(Value)> mapRewriteValue); 94 void generateRewriter(pdl::ResultOp resultOp, 95 DenseMap<Value, Value> &rewriteValues, 96 function_ref<Value(Value)> mapRewriteValue); 97 void generateRewriter(pdl::ResultsOp resultOp, 98 DenseMap<Value, Value> &rewriteValues, 99 function_ref<Value(Value)> mapRewriteValue); 100 void generateRewriter(pdl::TypeOp typeOp, 101 DenseMap<Value, Value> &rewriteValues, 102 function_ref<Value(Value)> mapRewriteValue); 103 void generateRewriter(pdl::TypesOp typeOp, 104 DenseMap<Value, Value> &rewriteValues, 105 function_ref<Value(Value)> mapRewriteValue); 106 107 /// Generate the values used for resolving the result types of an operation 108 /// created within a dag rewriter region. If the result types of the operation 109 /// should be inferred, `hasInferredResultTypes` is set to true. 110 void generateOperationResultTypeRewriter( 111 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, 112 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, 113 bool &hasInferredResultTypes); 114 115 /// A builder to use when generating interpreter operations. 116 OpBuilder builder; 117 118 /// The matcher function used for all match related logic within PDL patterns. 119 pdl_interp::FuncOp matcherFunc; 120 121 /// The rewriter module containing the all rewrite related logic within PDL 122 /// patterns. 123 ModuleOp rewriterModule; 124 125 /// The symbol table of the rewriter module used for insertion. 126 SymbolTable rewriterSymbolTable; 127 128 /// A scoped map connecting a position with the corresponding interpreter 129 /// value. 130 ValueMap values; 131 132 /// A stack of blocks used as the failure destination for matcher nodes that 133 /// don't have an explicit failure path. 134 SmallVector<Block *, 8> failureBlockStack; 135 136 /// A mapping between values defined in a pattern match, and the corresponding 137 /// positional value. 138 DenseMap<Value, Position *> valueToPosition; 139 140 /// The set of operation values whose whose location will be used for newly 141 /// generated operations. 142 SetVector<Value> locOps; 143 }; 144 } // namespace 145 146 PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc, 147 ModuleOp rewriterModule) 148 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 149 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} 150 151 void PatternLowering::lower(ModuleOp module) { 152 PredicateUniquer predicateUniquer; 153 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 154 155 // Define top-level scope for the arguments to the matcher function. 156 ValueMapScope topLevelValueScope(values); 157 158 // Insert the root operation, i.e. argument to the matcher, at the root 159 // position. 160 Block *matcherEntryBlock = &matcherFunc.front(); 161 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 162 163 // Generate a root matcher node from the provided PDL module. 164 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 165 module, predicateBuilder, valueToPosition); 166 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); 167 assert(failureBlockStack.empty() && "failed to empty the stack"); 168 169 // After generation, merged the first matched block into the entry. 170 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 171 firstMatcherBlock->getOperations()); 172 firstMatcherBlock->erase(); 173 } 174 175 Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { 176 // Push a new scope for the values used by this matcher. 177 Block *block = ®ion.emplaceBlock(); 178 ValueMapScope scope(values); 179 180 // If this is the return node, simply insert the corresponding interpreter 181 // finalize. 182 if (isa<ExitNode>(node)) { 183 builder.setInsertionPointToEnd(block); 184 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 185 return block; 186 } 187 188 // Get the next block in the match sequence. 189 // This is intentionally executed first, before we get the value for the 190 // position associated with the node, so that we preserve an "there exist" 191 // semantics: if getting a value requires an upward traversal (going from a 192 // value to its consumers), we want to perform the check on all the consumers 193 // before we pass control to the failure node. 194 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 195 Block *failureBlock; 196 if (failureNode) { 197 failureBlock = generateMatcher(*failureNode, region); 198 failureBlockStack.push_back(failureBlock); 199 } else { 200 assert(!failureBlockStack.empty() && "expected valid failure block"); 201 failureBlock = failureBlockStack.back(); 202 } 203 204 // If this node contains a position, get the corresponding value for this 205 // block. 206 Block *currentBlock = block; 207 Position *position = node.getPosition(); 208 Value val = position ? getValueAt(currentBlock, position) : Value(); 209 210 // If this value corresponds to an operation, record that we are going to use 211 // its location as part of a fused location. 212 bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); 213 if (isOperationValue) 214 locOps.insert(val); 215 216 // Dispatch to the correct method based on derived node type. 217 TypeSwitch<MatcherNode *>(&node) 218 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) { 219 this->generate(derivedNode, currentBlock, val); 220 }) 221 .Case([&](SuccessNode *successNode) { 222 generate(successNode, currentBlock); 223 }); 224 225 // Pop all the failure blocks that were inserted due to nesting of 226 // pdl_interp.iterate. 227 while (failureBlockStack.back() != failureBlock) { 228 failureBlockStack.pop_back(); 229 assert(!failureBlockStack.empty() && "unable to locate failure block"); 230 } 231 232 // Pop the new failure block. 233 if (failureNode) 234 failureBlockStack.pop_back(); 235 236 if (isOperationValue) 237 locOps.remove(val); 238 239 return block; 240 } 241 242 Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { 243 if (Value val = values.lookup(pos)) 244 return val; 245 246 // Get the value for the parent position. 247 Value parentVal; 248 if (Position *parent = pos->getParent()) 249 parentVal = getValueAt(currentBlock, parent); 250 251 // TODO: Use a location from the position. 252 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc(); 253 builder.setInsertionPointToEnd(currentBlock); 254 Value value; 255 switch (pos->getKind()) { 256 case Predicates::OperationPos: { 257 auto *operationPos = cast<OperationPosition>(pos); 258 if (operationPos->isOperandDefiningOp()) 259 // Standard (downward) traversal which directly follows the defining op. 260 value = builder.create<pdl_interp::GetDefiningOpOp>( 261 loc, builder.getType<pdl::OperationType>(), parentVal); 262 else 263 // A passthrough operation position. 264 value = parentVal; 265 break; 266 } 267 case Predicates::UsersPos: { 268 auto *usersPos = cast<UsersPosition>(pos); 269 270 // The first operation retrieves the representative value of a range. 271 // This applies only when the parent is a range of values and we were 272 // requested to use a representative value (e.g., upward traversal). 273 if (parentVal.getType().isa<pdl::RangeType>() && 274 usersPos->useRepresentative()) 275 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); 276 else 277 value = parentVal; 278 279 // The second operation retrieves the users. 280 value = builder.create<pdl_interp::GetUsersOp>(loc, value); 281 break; 282 } 283 case Predicates::ForEachPos: { 284 assert(!failureBlockStack.empty() && "expected valid failure block"); 285 auto foreach = builder.create<pdl_interp::ForEachOp>( 286 loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); 287 value = foreach.getLoopVariable(); 288 289 // Create the continuation block. 290 Block *continueBlock = builder.createBlock(&foreach.getRegion()); 291 builder.create<pdl_interp::ContinueOp>(loc); 292 failureBlockStack.push_back(continueBlock); 293 294 currentBlock = &foreach.getRegion().front(); 295 break; 296 } 297 case Predicates::OperandPos: { 298 auto *operandPos = cast<OperandPosition>(pos); 299 value = builder.create<pdl_interp::GetOperandOp>( 300 loc, builder.getType<pdl::ValueType>(), parentVal, 301 operandPos->getOperandNumber()); 302 break; 303 } 304 case Predicates::OperandGroupPos: { 305 auto *operandPos = cast<OperandGroupPosition>(pos); 306 Type valueTy = builder.getType<pdl::ValueType>(); 307 value = builder.create<pdl_interp::GetOperandsOp>( 308 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 309 parentVal, operandPos->getOperandGroupNumber()); 310 break; 311 } 312 case Predicates::AttributePos: { 313 auto *attrPos = cast<AttributePosition>(pos); 314 value = builder.create<pdl_interp::GetAttributeOp>( 315 loc, builder.getType<pdl::AttributeType>(), parentVal, 316 attrPos->getName().strref()); 317 break; 318 } 319 case Predicates::TypePos: { 320 if (parentVal.getType().isa<pdl::AttributeType>()) 321 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 322 else 323 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 324 break; 325 } 326 case Predicates::ResultPos: { 327 auto *resPos = cast<ResultPosition>(pos); 328 value = builder.create<pdl_interp::GetResultOp>( 329 loc, builder.getType<pdl::ValueType>(), parentVal, 330 resPos->getResultNumber()); 331 break; 332 } 333 case Predicates::ResultGroupPos: { 334 auto *resPos = cast<ResultGroupPosition>(pos); 335 Type valueTy = builder.getType<pdl::ValueType>(); 336 value = builder.create<pdl_interp::GetResultsOp>( 337 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 338 parentVal, resPos->getResultGroupNumber()); 339 break; 340 } 341 case Predicates::AttributeLiteralPos: { 342 auto *attrPos = cast<AttributeLiteralPosition>(pos); 343 value = 344 builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue()); 345 break; 346 } 347 case Predicates::TypeLiteralPos: { 348 auto *typePos = cast<TypeLiteralPosition>(pos); 349 Attribute rawTypeAttr = typePos->getValue(); 350 if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>()) 351 value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr); 352 else 353 value = builder.create<pdl_interp::CreateTypesOp>( 354 loc, rawTypeAttr.cast<ArrayAttr>()); 355 break; 356 } 357 default: 358 llvm_unreachable("Generating unknown Position getter"); 359 break; 360 } 361 362 values.insert(pos, value); 363 return value; 364 } 365 366 void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, 367 Value val) { 368 Location loc = val.getLoc(); 369 Qualifier *question = boolNode->getQuestion(); 370 Qualifier *answer = boolNode->getAnswer(); 371 Region *region = currentBlock->getParent(); 372 373 // Execute the getValue queries first, so that we create success 374 // matcher in the correct (possibly nested) region. 375 SmallVector<Value> args; 376 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) { 377 args = {getValueAt(currentBlock, equalToQuestion->getValue())}; 378 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) { 379 for (Position *position : cstQuestion->getArgs()) 380 args.push_back(getValueAt(currentBlock, position)); 381 } 382 383 // Generate the matcher in the current (potentially nested) region 384 // and get the failure successor. 385 Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); 386 Block *failure = failureBlockStack.back(); 387 388 // Finally, create the predicate. 389 builder.setInsertionPointToEnd(currentBlock); 390 Predicates::Kind kind = question->getKind(); 391 switch (kind) { 392 case Predicates::IsNotNullQuestion: 393 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); 394 break; 395 case Predicates::OperationNameQuestion: { 396 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 397 builder.create<pdl_interp::CheckOperationNameOp>( 398 loc, val, opNameAnswer->getValue().getStringRef(), success, failure); 399 break; 400 } 401 case Predicates::TypeQuestion: { 402 auto *ans = cast<TypeAnswer>(answer); 403 if (val.getType().isa<pdl::RangeType>()) 404 builder.create<pdl_interp::CheckTypesOp>( 405 loc, val, ans->getValue().cast<ArrayAttr>(), success, failure); 406 else 407 builder.create<pdl_interp::CheckTypeOp>( 408 loc, val, ans->getValue().cast<TypeAttr>(), success, failure); 409 break; 410 } 411 case Predicates::AttributeQuestion: { 412 auto *ans = cast<AttributeAnswer>(answer); 413 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 414 success, failure); 415 break; 416 } 417 case Predicates::OperandCountAtLeastQuestion: 418 case Predicates::OperandCountQuestion: 419 builder.create<pdl_interp::CheckOperandCountOp>( 420 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 421 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 422 success, failure); 423 break; 424 case Predicates::ResultCountAtLeastQuestion: 425 case Predicates::ResultCountQuestion: 426 builder.create<pdl_interp::CheckResultCountOp>( 427 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 428 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 429 success, failure); 430 break; 431 case Predicates::EqualToQuestion: { 432 bool trueAnswer = isa<TrueAnswer>(answer); 433 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), 434 trueAnswer ? success : failure, 435 trueAnswer ? failure : success); 436 break; 437 } 438 case Predicates::ConstraintQuestion: { 439 auto *cstQuestion = cast<ConstraintQuestion>(question); 440 builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(), 441 args, success, failure); 442 break; 443 } 444 default: 445 llvm_unreachable("Generating unknown Predicate operation"); 446 } 447 } 448 449 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 450 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 451 llvm::MapVector<Qualifier *, Block *> &dests) { 452 std::vector<ValT> values; 453 std::vector<Block *> blocks; 454 values.reserve(dests.size()); 455 blocks.reserve(dests.size()); 456 for (const auto &it : dests) { 457 blocks.push_back(it.second); 458 values.push_back(cast<PredT>(it.first)->getValue()); 459 } 460 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 461 } 462 463 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, 464 Value val) { 465 Qualifier *question = switchNode->getQuestion(); 466 Region *region = currentBlock->getParent(); 467 Block *defaultDest = failureBlockStack.back(); 468 469 // If the switch question is not an exact answer, i.e. for the `at_least` 470 // cases, we generate a special block sequence. 471 Predicates::Kind kind = question->getKind(); 472 if (kind == Predicates::OperandCountAtLeastQuestion || 473 kind == Predicates::ResultCountAtLeastQuestion) { 474 // Order the children such that the cases are in reverse numerical order. 475 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( 476 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 477 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 478 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 479 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 480 }); 481 482 // Build the destination for each child using the next highest child as a 483 // a failure destination. This essentially creates the following control 484 // flow: 485 // 486 // if (operand_count < 1) 487 // goto failure 488 // if (child1.match()) 489 // ... 490 // 491 // if (operand_count < 2) 492 // goto failure 493 // if (child2.match()) 494 // ... 495 // 496 // failure: 497 // ... 498 // 499 failureBlockStack.push_back(defaultDest); 500 Location loc = val.getLoc(); 501 for (unsigned idx : sortedChildren) { 502 auto &child = switchNode->getChild(idx); 503 Block *childBlock = generateMatcher(*child.second, *region); 504 Block *predicateBlock = builder.createBlock(childBlock); 505 builder.setInsertionPointToEnd(predicateBlock); 506 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue(); 507 switch (kind) { 508 case Predicates::OperandCountAtLeastQuestion: 509 builder.create<pdl_interp::CheckOperandCountOp>( 510 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 511 break; 512 case Predicates::ResultCountAtLeastQuestion: 513 builder.create<pdl_interp::CheckResultCountOp>( 514 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 515 break; 516 default: 517 llvm_unreachable("Generating invalid AtLeast operation"); 518 } 519 failureBlockStack.back() = predicateBlock; 520 } 521 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 522 currentBlock->getOperations().splice(currentBlock->end(), 523 firstPredicateBlock->getOperations()); 524 firstPredicateBlock->erase(); 525 return; 526 } 527 528 // Otherwise, generate each of the children and generate an interpreter 529 // switch. 530 llvm::MapVector<Qualifier *, Block *> children; 531 for (auto &it : switchNode->getChildren()) 532 children.insert({it.first, generateMatcher(*it.second, *region)}); 533 builder.setInsertionPointToEnd(currentBlock); 534 535 switch (question->getKind()) { 536 case Predicates::OperandCountQuestion: 537 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 538 int32_t>(val, defaultDest, builder, children); 539 case Predicates::ResultCountQuestion: 540 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 541 int32_t>(val, defaultDest, builder, children); 542 case Predicates::OperationNameQuestion: 543 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 544 OperationNameAnswer>(val, defaultDest, builder, 545 children); 546 case Predicates::TypeQuestion: 547 if (val.getType().isa<pdl::RangeType>()) { 548 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 549 val, defaultDest, builder, children); 550 } 551 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 552 val, defaultDest, builder, children); 553 case Predicates::AttributeQuestion: 554 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 555 val, defaultDest, builder, children); 556 default: 557 llvm_unreachable("Generating unknown switch predicate."); 558 } 559 } 560 561 void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { 562 pdl::PatternOp pattern = successNode->getPattern(); 563 Value root = successNode->getRoot(); 564 565 // Generate a rewriter for the pattern this success node represents, and track 566 // any values used from the match region. 567 SmallVector<Position *, 8> usedMatchValues; 568 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 569 570 // Process any values used in the rewrite that are defined in the match. 571 std::vector<Value> mappedMatchValues; 572 mappedMatchValues.reserve(usedMatchValues.size()); 573 for (Position *position : usedMatchValues) 574 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 575 576 // Collect the set of operations generated by the rewriter. 577 SmallVector<StringRef, 4> generatedOps; 578 for (auto op : 579 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>()) 580 generatedOps.push_back(*op.getOpName()); 581 ArrayAttr generatedOpsAttr; 582 if (!generatedOps.empty()) 583 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 584 585 // Grab the root kind if present. 586 StringAttr rootKindAttr; 587 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>()) 588 if (Optional<StringRef> rootKind = rootOp.getOpName()) 589 rootKindAttr = builder.getStringAttr(*rootKind); 590 591 builder.setInsertionPointToEnd(currentBlock); 592 builder.create<pdl_interp::RecordMatchOp>( 593 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 594 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(), 595 failureBlockStack.back()); 596 } 597 598 SymbolRefAttr PatternLowering::generateRewriter( 599 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 600 builder.setInsertionPointToEnd(rewriterModule.getBody()); 601 auto rewriterFunc = builder.create<pdl_interp::FuncOp>( 602 pattern.getLoc(), "pdl_generated_rewriter", 603 builder.getFunctionType(llvm::None, llvm::None)); 604 rewriterSymbolTable.insert(rewriterFunc); 605 606 // Generate the rewriter function body. 607 builder.setInsertionPointToEnd(&rewriterFunc.front()); 608 609 // Map an input operand of the pattern to a generated interpreter value. 610 DenseMap<Value, Value> rewriteValues; 611 auto mapRewriteValue = [&](Value oldValue) { 612 Value &newValue = rewriteValues[oldValue]; 613 if (newValue) 614 return newValue; 615 616 // Prefer materializing constants directly when possible. 617 Operation *oldOp = oldValue.getDefiningOp(); 618 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 619 if (Attribute value = attrOp.getValueAttr()) { 620 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 621 attrOp.getLoc(), value); 622 } 623 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 624 if (TypeAttr type = typeOp.getConstantTypeAttr()) { 625 return newValue = builder.create<pdl_interp::CreateTypeOp>( 626 typeOp.getLoc(), type); 627 } 628 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 629 if (ArrayAttr type = typeOp.getConstantTypesAttr()) { 630 return newValue = builder.create<pdl_interp::CreateTypesOp>( 631 typeOp.getLoc(), typeOp.getType(), type); 632 } 633 } 634 635 // Otherwise, add this as an input to the rewriter. 636 Position *inputPos = valueToPosition.lookup(oldValue); 637 assert(inputPos && "expected value to be a pattern input"); 638 usedMatchValues.push_back(inputPos); 639 return newValue = rewriterFunc.front().addArgument(oldValue.getType(), 640 oldValue.getLoc()); 641 }; 642 643 // If this is a custom rewriter, simply dispatch to the registered rewrite 644 // method. 645 pdl::RewriteOp rewriter = pattern.getRewriter(); 646 if (StringAttr rewriteName = rewriter.getNameAttr()) { 647 SmallVector<Value> args; 648 if (rewriter.getRoot()) 649 args.push_back(mapRewriteValue(rewriter.getRoot())); 650 auto mappedArgs = 651 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue); 652 args.append(mappedArgs.begin(), mappedArgs.end()); 653 builder.create<pdl_interp::ApplyRewriteOp>( 654 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); 655 } else { 656 // Otherwise this is a dag rewriter defined using PDL operations. 657 for (Operation &rewriteOp : *rewriter.getBody()) { 658 llvm::TypeSwitch<Operation *>(&rewriteOp) 659 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 660 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, 661 pdl::TypeOp, pdl::TypesOp>([&](auto op) { 662 this->generateRewriter(op, rewriteValues, mapRewriteValue); 663 }); 664 } 665 } 666 667 // Update the signature of the rewrite function. 668 rewriterFunc.setType(builder.getFunctionType( 669 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 670 /*results=*/llvm::None)); 671 672 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 673 return SymbolRefAttr::get( 674 builder.getContext(), 675 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 676 SymbolRefAttr::get(rewriterFunc)); 677 } 678 679 void PatternLowering::generateRewriter( 680 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 681 function_ref<Value(Value)> mapRewriteValue) { 682 SmallVector<Value, 2> arguments; 683 for (Value argument : rewriteOp.getArgs()) 684 arguments.push_back(mapRewriteValue(argument)); 685 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 686 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(), 687 arguments); 688 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) 689 rewriteValues[std::get<0>(it)] = std::get<1>(it); 690 } 691 692 void PatternLowering::generateRewriter( 693 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 694 function_ref<Value(Value)> mapRewriteValue) { 695 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 696 attrOp.getLoc(), attrOp.getValueAttr()); 697 rewriteValues[attrOp] = newAttr; 698 } 699 700 void PatternLowering::generateRewriter( 701 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 702 function_ref<Value(Value)> mapRewriteValue) { 703 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 704 mapRewriteValue(eraseOp.getOpValue())); 705 } 706 707 void PatternLowering::generateRewriter( 708 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 709 function_ref<Value(Value)> mapRewriteValue) { 710 SmallVector<Value, 4> operands; 711 for (Value operand : operationOp.getOperandValues()) 712 operands.push_back(mapRewriteValue(operand)); 713 714 SmallVector<Value, 4> attributes; 715 for (Value attr : operationOp.getAttributeValues()) 716 attributes.push_back(mapRewriteValue(attr)); 717 718 bool hasInferredResultTypes = false; 719 SmallVector<Value, 2> types; 720 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types, 721 rewriteValues, hasInferredResultTypes); 722 723 // Create the new operation. 724 Location loc = operationOp.getLoc(); 725 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 726 loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands, 727 attributes, operationOp.getAttributeValueNames()); 728 rewriteValues[operationOp.getOp()] = createdOp; 729 730 // Generate accesses for any results that have their types constrained. 731 // Handle the case where there is a single range representing all of the 732 // result types. 733 OperandRange resultTys = operationOp.getTypeValues(); 734 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { 735 Value &type = rewriteValues[resultTys[0]]; 736 if (!type) { 737 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 738 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 739 } 740 return; 741 } 742 743 // Otherwise, populate the individual results. 744 bool seenVariableLength = false; 745 Type valueTy = builder.getType<pdl::ValueType>(); 746 Type valueRangeTy = pdl::RangeType::get(valueTy); 747 for (const auto &it : llvm::enumerate(resultTys)) { 748 Value &type = rewriteValues[it.value()]; 749 if (type) 750 continue; 751 bool isVariadic = it.value().getType().isa<pdl::RangeType>(); 752 seenVariableLength |= isVariadic; 753 754 // After a variable length result has been seen, we need to use result 755 // groups because the exact index of the result is not statically known. 756 Value resultVal; 757 if (seenVariableLength) 758 resultVal = builder.create<pdl_interp::GetResultsOp>( 759 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 760 else 761 resultVal = builder.create<pdl_interp::GetResultOp>( 762 loc, valueTy, createdOp, it.index()); 763 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 764 } 765 } 766 767 void PatternLowering::generateRewriter( 768 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 769 function_ref<Value(Value)> mapRewriteValue) { 770 SmallVector<Value, 4> replOperands; 771 772 // If the replacement was another operation, get its results. `pdl` allows 773 // for using an operation for simplicitly, but the interpreter isn't as 774 // user facing. 775 if (Value replOp = replaceOp.getReplOperation()) { 776 // Don't use replace if we know the replaced operation has no results. 777 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>(); 778 if (!opOp || !opOp.getTypeValues().empty()) { 779 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 780 replOp.getLoc(), mapRewriteValue(replOp))); 781 } 782 } else { 783 for (Value operand : replaceOp.getReplValues()) 784 replOperands.push_back(mapRewriteValue(operand)); 785 } 786 787 // If there are no replacement values, just create an erase instead. 788 if (replOperands.empty()) { 789 builder.create<pdl_interp::EraseOp>( 790 replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue())); 791 return; 792 } 793 794 builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(), 795 mapRewriteValue(replaceOp.getOpValue()), 796 replOperands); 797 } 798 799 void PatternLowering::generateRewriter( 800 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 801 function_ref<Value(Value)> mapRewriteValue) { 802 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 803 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 804 mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); 805 } 806 807 void PatternLowering::generateRewriter( 808 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 809 function_ref<Value(Value)> mapRewriteValue) { 810 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 811 resultOp.getLoc(), resultOp.getType(), 812 mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); 813 } 814 815 void PatternLowering::generateRewriter( 816 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 817 function_ref<Value(Value)> mapRewriteValue) { 818 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 819 // type. 820 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) { 821 rewriteValues[typeOp] = 822 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 823 } 824 } 825 826 void PatternLowering::generateRewriter( 827 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 828 function_ref<Value(Value)> mapRewriteValue) { 829 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 830 // type. 831 if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) { 832 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 833 typeOp.getLoc(), typeOp.getType(), typeAttr); 834 } 835 } 836 837 void PatternLowering::generateOperationResultTypeRewriter( 838 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, 839 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, 840 bool &hasInferredResultTypes) { 841 Block *rewriterBlock = op->getBlock(); 842 843 // Try to handle resolution for each of the result types individually. This is 844 // preferred over type inferrence because it will allow for us to use existing 845 // types directly, as opposed to trying to rebuild the type list. 846 OperandRange resultTypeValues = op.getTypeValues(); 847 auto tryResolveResultTypes = [&] { 848 types.reserve(resultTypeValues.size()); 849 for (const auto &it : llvm::enumerate(resultTypeValues)) { 850 Value resultType = it.value(); 851 852 // Check for an already translated value. 853 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 854 types.push_back(existingRewriteValue); 855 continue; 856 } 857 858 // Check for an input from the matcher. 859 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 860 types.push_back(mapRewriteValue(resultType)); 861 continue; 862 } 863 864 // Otherwise, we couldn't infer the result types. Bail out here to see if 865 // we can infer the types for this operation from another way. 866 types.clear(); 867 return failure(); 868 } 869 return success(); 870 }; 871 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes())) 872 return; 873 874 // Otherwise, check if the operation has type inference support itself. 875 if (op.hasTypeInference()) { 876 hasInferredResultTypes = true; 877 return; 878 } 879 880 // Look for an operation that was replaced by `op`. The result types will be 881 // inferred from the results that were replaced. 882 for (OpOperand &use : op.getOp().getUses()) { 883 // Check that the use corresponds to a ReplaceOp and that it is the 884 // replacement value, not the operation being replaced. 885 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 886 if (!replOpUser || use.getOperandNumber() == 0) 887 continue; 888 // Make sure the replaced operation was defined before this one. PDL 889 // rewrites only have single block regions, so if the op isn't in the 890 // rewriter block (i.e. the current block of the operation) we already know 891 // it dominates (i.e. it's in the matcher). 892 Value replOpVal = replOpUser.getOpValue(); 893 Operation *replacedOp = replOpVal.getDefiningOp(); 894 if (replacedOp->getBlock() == rewriterBlock && 895 !replacedOp->isBeforeInBlock(op)) 896 continue; 897 898 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 899 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 900 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 901 replacedOp->getLoc(), replacedOpResults)); 902 return; 903 } 904 905 // If the types could not be inferred from any context and there weren't any 906 // explicit result types, assume the user actually meant for the operation to 907 // have no results. 908 if (resultTypeValues.empty()) 909 return; 910 911 // The verifier asserts that the result types of each pdl.getOperation can be 912 // inferred. If we reach here, there is a bug either in the logic above or 913 // in the verifier for pdl.getOperation. 914 op->emitOpError() << "unable to infer result type for operation"; 915 llvm_unreachable("unable to infer result type for operation"); 916 } 917 918 //===----------------------------------------------------------------------===// 919 // Conversion Pass 920 //===----------------------------------------------------------------------===// 921 922 namespace { 923 struct PDLToPDLInterpPass 924 : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 925 void runOnOperation() final; 926 }; 927 } // namespace 928 929 /// Convert the given module containing PDL pattern operations into a PDL 930 /// Interpreter operations. 931 void PDLToPDLInterpPass::runOnOperation() { 932 ModuleOp module = getOperation(); 933 934 // Create the main matcher function This function contains all of the match 935 // related functionality from patterns in the module. 936 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 937 auto matcherFunc = builder.create<pdl_interp::FuncOp>( 938 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 939 builder.getFunctionType(builder.getType<pdl::OperationType>(), 940 /*results=*/llvm::None), 941 /*attrs=*/llvm::None); 942 943 // Create a nested module to hold the functions invoked for rewriting the IR 944 // after a successful match. 945 ModuleOp rewriterModule = builder.create<ModuleOp>( 946 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 947 948 // Generate the code for the patterns within the module. 949 PatternLowering generator(matcherFunc, rewriterModule); 950 generator.lower(module); 951 952 // After generation, delete all of the pattern operations. 953 for (pdl::PatternOp pattern : 954 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) 955 pattern.erase(); 956 } 957 958 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 959 return std::make_unique<PDLToPDLInterpPass>(); 960 } 961