1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 #include "../PassDetail.h" 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::pdl_to_pdl_interp; 25 26 //===----------------------------------------------------------------------===// 27 // PatternLowering 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 /// This class generators operations within the PDL Interpreter dialect from a 32 /// given module containing PDL pattern operations. 33 struct PatternLowering { 34 public: 35 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule); 36 37 /// Generate code for matching and rewriting based on the pattern operations 38 /// within the module. 39 void lower(ModuleOp module); 40 41 private: 42 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 43 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 44 45 /// Generate interpreter operations for the tree rooted at the given matcher 46 /// node, in the specified region. 47 Block *generateMatcher(MatcherNode &node, Region ®ion); 48 49 /// Get or create an access to the provided positional value in the current 50 /// block. This operation may mutate the provided block pointer if nested 51 /// regions (i.e., pdl_interp.iterate) are required. 52 Value getValueAt(Block *¤tBlock, Position *pos); 53 54 /// Create the interpreter predicate operations. This operation may mutate the 55 /// provided current block pointer if nested regions (iterates) are required. 56 void generate(BoolNode *boolNode, Block *¤tBlock, Value val); 57 58 /// Create the interpreter switch / predicate operations, with several case 59 /// destinations. This operation never mutates the provided current block 60 /// pointer, because the switch operation does not need Values beyond `val`. 61 void generate(SwitchNode *switchNode, Block *currentBlock, Value val); 62 63 /// Create the interpreter operations to record a successful pattern match 64 /// using the contained root operation. This operation may mutate the current 65 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. 66 void generate(SuccessNode *successNode, Block *¤tBlock); 67 68 /// Generate a rewriter function for the given pattern operation, and returns 69 /// a reference to that function. 70 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 71 SmallVectorImpl<Position *> &usedMatchValues); 72 73 /// Generate the rewriter code for the given operation. 74 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 75 DenseMap<Value, Value> &rewriteValues, 76 function_ref<Value(Value)> mapRewriteValue); 77 void generateRewriter(pdl::AttributeOp attrOp, 78 DenseMap<Value, Value> &rewriteValues, 79 function_ref<Value(Value)> mapRewriteValue); 80 void generateRewriter(pdl::EraseOp eraseOp, 81 DenseMap<Value, Value> &rewriteValues, 82 function_ref<Value(Value)> mapRewriteValue); 83 void generateRewriter(pdl::OperationOp operationOp, 84 DenseMap<Value, Value> &rewriteValues, 85 function_ref<Value(Value)> mapRewriteValue); 86 void generateRewriter(pdl::ReplaceOp replaceOp, 87 DenseMap<Value, Value> &rewriteValues, 88 function_ref<Value(Value)> mapRewriteValue); 89 void generateRewriter(pdl::ResultOp resultOp, 90 DenseMap<Value, Value> &rewriteValues, 91 function_ref<Value(Value)> mapRewriteValue); 92 void generateRewriter(pdl::ResultsOp resultOp, 93 DenseMap<Value, Value> &rewriteValues, 94 function_ref<Value(Value)> mapRewriteValue); 95 void generateRewriter(pdl::TypeOp typeOp, 96 DenseMap<Value, Value> &rewriteValues, 97 function_ref<Value(Value)> mapRewriteValue); 98 void generateRewriter(pdl::TypesOp typeOp, 99 DenseMap<Value, Value> &rewriteValues, 100 function_ref<Value(Value)> mapRewriteValue); 101 102 /// Generate the values used for resolving the result types of an operation 103 /// created within a dag rewriter region. If the result types of the operation 104 /// should be inferred, `hasInferredResultTypes` is set to true. 105 void generateOperationResultTypeRewriter( 106 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, 107 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, 108 bool &hasInferredResultTypes); 109 110 /// A builder to use when generating interpreter operations. 111 OpBuilder builder; 112 113 /// The matcher function used for all match related logic within PDL patterns. 114 pdl_interp::FuncOp matcherFunc; 115 116 /// The rewriter module containing the all rewrite related logic within PDL 117 /// patterns. 118 ModuleOp rewriterModule; 119 120 /// The symbol table of the rewriter module used for insertion. 121 SymbolTable rewriterSymbolTable; 122 123 /// A scoped map connecting a position with the corresponding interpreter 124 /// value. 125 ValueMap values; 126 127 /// A stack of blocks used as the failure destination for matcher nodes that 128 /// don't have an explicit failure path. 129 SmallVector<Block *, 8> failureBlockStack; 130 131 /// A mapping between values defined in a pattern match, and the corresponding 132 /// positional value. 133 DenseMap<Value, Position *> valueToPosition; 134 135 /// The set of operation values whose whose location will be used for newly 136 /// generated operations. 137 SetVector<Value> locOps; 138 }; 139 } // namespace 140 141 PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc, 142 ModuleOp rewriterModule) 143 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 144 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} 145 146 void PatternLowering::lower(ModuleOp module) { 147 PredicateUniquer predicateUniquer; 148 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 149 150 // Define top-level scope for the arguments to the matcher function. 151 ValueMapScope topLevelValueScope(values); 152 153 // Insert the root operation, i.e. argument to the matcher, at the root 154 // position. 155 Block *matcherEntryBlock = &matcherFunc.front(); 156 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 157 158 // Generate a root matcher node from the provided PDL module. 159 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 160 module, predicateBuilder, valueToPosition); 161 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); 162 assert(failureBlockStack.empty() && "failed to empty the stack"); 163 164 // After generation, merged the first matched block into the entry. 165 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 166 firstMatcherBlock->getOperations()); 167 firstMatcherBlock->erase(); 168 } 169 170 Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { 171 // Push a new scope for the values used by this matcher. 172 Block *block = ®ion.emplaceBlock(); 173 ValueMapScope scope(values); 174 175 // If this is the return node, simply insert the corresponding interpreter 176 // finalize. 177 if (isa<ExitNode>(node)) { 178 builder.setInsertionPointToEnd(block); 179 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 180 return block; 181 } 182 183 // Get the next block in the match sequence. 184 // This is intentionally executed first, before we get the value for the 185 // position associated with the node, so that we preserve an "there exist" 186 // semantics: if getting a value requires an upward traversal (going from a 187 // value to its consumers), we want to perform the check on all the consumers 188 // before we pass control to the failure node. 189 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 190 Block *failureBlock; 191 if (failureNode) { 192 failureBlock = generateMatcher(*failureNode, region); 193 failureBlockStack.push_back(failureBlock); 194 } else { 195 assert(!failureBlockStack.empty() && "expected valid failure block"); 196 failureBlock = failureBlockStack.back(); 197 } 198 199 // If this node contains a position, get the corresponding value for this 200 // block. 201 Block *currentBlock = block; 202 Position *position = node.getPosition(); 203 Value val = position ? getValueAt(currentBlock, position) : Value(); 204 205 // If this value corresponds to an operation, record that we are going to use 206 // its location as part of a fused location. 207 bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); 208 if (isOperationValue) 209 locOps.insert(val); 210 211 // Dispatch to the correct method based on derived node type. 212 TypeSwitch<MatcherNode *>(&node) 213 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) { 214 this->generate(derivedNode, currentBlock, val); 215 }) 216 .Case([&](SuccessNode *successNode) { 217 generate(successNode, currentBlock); 218 }); 219 220 // Pop all the failure blocks that were inserted due to nesting of 221 // pdl_interp.iterate. 222 while (failureBlockStack.back() != failureBlock) { 223 failureBlockStack.pop_back(); 224 assert(!failureBlockStack.empty() && "unable to locate failure block"); 225 } 226 227 // Pop the new failure block. 228 if (failureNode) 229 failureBlockStack.pop_back(); 230 231 if (isOperationValue) 232 locOps.remove(val); 233 234 return block; 235 } 236 237 Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { 238 if (Value val = values.lookup(pos)) 239 return val; 240 241 // Get the value for the parent position. 242 Value parentVal; 243 if (Position *parent = pos->getParent()) 244 parentVal = getValueAt(currentBlock, parent); 245 246 // TODO: Use a location from the position. 247 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc(); 248 builder.setInsertionPointToEnd(currentBlock); 249 Value value; 250 switch (pos->getKind()) { 251 case Predicates::OperationPos: { 252 auto *operationPos = cast<OperationPosition>(pos); 253 if (operationPos->isOperandDefiningOp()) 254 // Standard (downward) traversal which directly follows the defining op. 255 value = builder.create<pdl_interp::GetDefiningOpOp>( 256 loc, builder.getType<pdl::OperationType>(), parentVal); 257 else 258 // A passthrough operation position. 259 value = parentVal; 260 break; 261 } 262 case Predicates::UsersPos: { 263 auto *usersPos = cast<UsersPosition>(pos); 264 265 // The first operation retrieves the representative value of a range. 266 // This applies only when the parent is a range of values and we were 267 // requested to use a representative value (e.g., upward traversal). 268 if (parentVal.getType().isa<pdl::RangeType>() && 269 usersPos->useRepresentative()) 270 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); 271 else 272 value = parentVal; 273 274 // The second operation retrieves the users. 275 value = builder.create<pdl_interp::GetUsersOp>(loc, value); 276 break; 277 } 278 case Predicates::ForEachPos: { 279 assert(!failureBlockStack.empty() && "expected valid failure block"); 280 auto foreach = builder.create<pdl_interp::ForEachOp>( 281 loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); 282 value = foreach.getLoopVariable(); 283 284 // Create the continuation block. 285 Block *continueBlock = builder.createBlock(&foreach.getRegion()); 286 builder.create<pdl_interp::ContinueOp>(loc); 287 failureBlockStack.push_back(continueBlock); 288 289 currentBlock = &foreach.getRegion().front(); 290 break; 291 } 292 case Predicates::OperandPos: { 293 auto *operandPos = cast<OperandPosition>(pos); 294 value = builder.create<pdl_interp::GetOperandOp>( 295 loc, builder.getType<pdl::ValueType>(), parentVal, 296 operandPos->getOperandNumber()); 297 break; 298 } 299 case Predicates::OperandGroupPos: { 300 auto *operandPos = cast<OperandGroupPosition>(pos); 301 Type valueTy = builder.getType<pdl::ValueType>(); 302 value = builder.create<pdl_interp::GetOperandsOp>( 303 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 304 parentVal, operandPos->getOperandGroupNumber()); 305 break; 306 } 307 case Predicates::AttributePos: { 308 auto *attrPos = cast<AttributePosition>(pos); 309 value = builder.create<pdl_interp::GetAttributeOp>( 310 loc, builder.getType<pdl::AttributeType>(), parentVal, 311 attrPos->getName().strref()); 312 break; 313 } 314 case Predicates::TypePos: { 315 if (parentVal.getType().isa<pdl::AttributeType>()) 316 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 317 else 318 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 319 break; 320 } 321 case Predicates::ResultPos: { 322 auto *resPos = cast<ResultPosition>(pos); 323 value = builder.create<pdl_interp::GetResultOp>( 324 loc, builder.getType<pdl::ValueType>(), parentVal, 325 resPos->getResultNumber()); 326 break; 327 } 328 case Predicates::ResultGroupPos: { 329 auto *resPos = cast<ResultGroupPosition>(pos); 330 Type valueTy = builder.getType<pdl::ValueType>(); 331 value = builder.create<pdl_interp::GetResultsOp>( 332 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 333 parentVal, resPos->getResultGroupNumber()); 334 break; 335 } 336 case Predicates::AttributeLiteralPos: { 337 auto *attrPos = cast<AttributeLiteralPosition>(pos); 338 value = 339 builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue()); 340 break; 341 } 342 case Predicates::TypeLiteralPos: { 343 auto *typePos = cast<TypeLiteralPosition>(pos); 344 Attribute rawTypeAttr = typePos->getValue(); 345 if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>()) 346 value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr); 347 else 348 value = builder.create<pdl_interp::CreateTypesOp>( 349 loc, rawTypeAttr.cast<ArrayAttr>()); 350 break; 351 } 352 default: 353 llvm_unreachable("Generating unknown Position getter"); 354 break; 355 } 356 357 values.insert(pos, value); 358 return value; 359 } 360 361 void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, 362 Value val) { 363 Location loc = val.getLoc(); 364 Qualifier *question = boolNode->getQuestion(); 365 Qualifier *answer = boolNode->getAnswer(); 366 Region *region = currentBlock->getParent(); 367 368 // Execute the getValue queries first, so that we create success 369 // matcher in the correct (possibly nested) region. 370 SmallVector<Value> args; 371 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) { 372 args = {getValueAt(currentBlock, equalToQuestion->getValue())}; 373 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) { 374 for (Position *position : cstQuestion->getArgs()) 375 args.push_back(getValueAt(currentBlock, position)); 376 } 377 378 // Generate the matcher in the current (potentially nested) region 379 // and get the failure successor. 380 Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); 381 Block *failure = failureBlockStack.back(); 382 383 // Finally, create the predicate. 384 builder.setInsertionPointToEnd(currentBlock); 385 Predicates::Kind kind = question->getKind(); 386 switch (kind) { 387 case Predicates::IsNotNullQuestion: 388 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); 389 break; 390 case Predicates::OperationNameQuestion: { 391 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 392 builder.create<pdl_interp::CheckOperationNameOp>( 393 loc, val, opNameAnswer->getValue().getStringRef(), success, failure); 394 break; 395 } 396 case Predicates::TypeQuestion: { 397 auto *ans = cast<TypeAnswer>(answer); 398 if (val.getType().isa<pdl::RangeType>()) 399 builder.create<pdl_interp::CheckTypesOp>( 400 loc, val, ans->getValue().cast<ArrayAttr>(), success, failure); 401 else 402 builder.create<pdl_interp::CheckTypeOp>( 403 loc, val, ans->getValue().cast<TypeAttr>(), success, failure); 404 break; 405 } 406 case Predicates::AttributeQuestion: { 407 auto *ans = cast<AttributeAnswer>(answer); 408 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 409 success, failure); 410 break; 411 } 412 case Predicates::OperandCountAtLeastQuestion: 413 case Predicates::OperandCountQuestion: 414 builder.create<pdl_interp::CheckOperandCountOp>( 415 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 416 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 417 success, failure); 418 break; 419 case Predicates::ResultCountAtLeastQuestion: 420 case Predicates::ResultCountQuestion: 421 builder.create<pdl_interp::CheckResultCountOp>( 422 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 423 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 424 success, failure); 425 break; 426 case Predicates::EqualToQuestion: { 427 bool trueAnswer = isa<TrueAnswer>(answer); 428 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), 429 trueAnswer ? success : failure, 430 trueAnswer ? failure : success); 431 break; 432 } 433 case Predicates::ConstraintQuestion: { 434 auto *cstQuestion = cast<ConstraintQuestion>(question); 435 builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(), 436 args, success, failure); 437 break; 438 } 439 default: 440 llvm_unreachable("Generating unknown Predicate operation"); 441 } 442 } 443 444 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 445 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 446 llvm::MapVector<Qualifier *, Block *> &dests) { 447 std::vector<ValT> values; 448 std::vector<Block *> blocks; 449 values.reserve(dests.size()); 450 blocks.reserve(dests.size()); 451 for (const auto &it : dests) { 452 blocks.push_back(it.second); 453 values.push_back(cast<PredT>(it.first)->getValue()); 454 } 455 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 456 } 457 458 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, 459 Value val) { 460 Qualifier *question = switchNode->getQuestion(); 461 Region *region = currentBlock->getParent(); 462 Block *defaultDest = failureBlockStack.back(); 463 464 // If the switch question is not an exact answer, i.e. for the `at_least` 465 // cases, we generate a special block sequence. 466 Predicates::Kind kind = question->getKind(); 467 if (kind == Predicates::OperandCountAtLeastQuestion || 468 kind == Predicates::ResultCountAtLeastQuestion) { 469 // Order the children such that the cases are in reverse numerical order. 470 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( 471 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 472 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 473 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 474 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 475 }); 476 477 // Build the destination for each child using the next highest child as a 478 // a failure destination. This essentially creates the following control 479 // flow: 480 // 481 // if (operand_count < 1) 482 // goto failure 483 // if (child1.match()) 484 // ... 485 // 486 // if (operand_count < 2) 487 // goto failure 488 // if (child2.match()) 489 // ... 490 // 491 // failure: 492 // ... 493 // 494 failureBlockStack.push_back(defaultDest); 495 Location loc = val.getLoc(); 496 for (unsigned idx : sortedChildren) { 497 auto &child = switchNode->getChild(idx); 498 Block *childBlock = generateMatcher(*child.second, *region); 499 Block *predicateBlock = builder.createBlock(childBlock); 500 builder.setInsertionPointToEnd(predicateBlock); 501 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue(); 502 switch (kind) { 503 case Predicates::OperandCountAtLeastQuestion: 504 builder.create<pdl_interp::CheckOperandCountOp>( 505 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 506 break; 507 case Predicates::ResultCountAtLeastQuestion: 508 builder.create<pdl_interp::CheckResultCountOp>( 509 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 510 break; 511 default: 512 llvm_unreachable("Generating invalid AtLeast operation"); 513 } 514 failureBlockStack.back() = predicateBlock; 515 } 516 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 517 currentBlock->getOperations().splice(currentBlock->end(), 518 firstPredicateBlock->getOperations()); 519 firstPredicateBlock->erase(); 520 return; 521 } 522 523 // Otherwise, generate each of the children and generate an interpreter 524 // switch. 525 llvm::MapVector<Qualifier *, Block *> children; 526 for (auto &it : switchNode->getChildren()) 527 children.insert({it.first, generateMatcher(*it.second, *region)}); 528 builder.setInsertionPointToEnd(currentBlock); 529 530 switch (question->getKind()) { 531 case Predicates::OperandCountQuestion: 532 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 533 int32_t>(val, defaultDest, builder, children); 534 case Predicates::ResultCountQuestion: 535 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 536 int32_t>(val, defaultDest, builder, children); 537 case Predicates::OperationNameQuestion: 538 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 539 OperationNameAnswer>(val, defaultDest, builder, 540 children); 541 case Predicates::TypeQuestion: 542 if (val.getType().isa<pdl::RangeType>()) { 543 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 544 val, defaultDest, builder, children); 545 } 546 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 547 val, defaultDest, builder, children); 548 case Predicates::AttributeQuestion: 549 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 550 val, defaultDest, builder, children); 551 default: 552 llvm_unreachable("Generating unknown switch predicate."); 553 } 554 } 555 556 void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { 557 pdl::PatternOp pattern = successNode->getPattern(); 558 Value root = successNode->getRoot(); 559 560 // Generate a rewriter for the pattern this success node represents, and track 561 // any values used from the match region. 562 SmallVector<Position *, 8> usedMatchValues; 563 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 564 565 // Process any values used in the rewrite that are defined in the match. 566 std::vector<Value> mappedMatchValues; 567 mappedMatchValues.reserve(usedMatchValues.size()); 568 for (Position *position : usedMatchValues) 569 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 570 571 // Collect the set of operations generated by the rewriter. 572 SmallVector<StringRef, 4> generatedOps; 573 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>()) 574 generatedOps.push_back(*op.name()); 575 ArrayAttr generatedOpsAttr; 576 if (!generatedOps.empty()) 577 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 578 579 // Grab the root kind if present. 580 StringAttr rootKindAttr; 581 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>()) 582 if (Optional<StringRef> rootKind = rootOp.name()) 583 rootKindAttr = builder.getStringAttr(*rootKind); 584 585 builder.setInsertionPointToEnd(currentBlock); 586 builder.create<pdl_interp::RecordMatchOp>( 587 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 588 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), 589 failureBlockStack.back()); 590 } 591 592 SymbolRefAttr PatternLowering::generateRewriter( 593 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 594 builder.setInsertionPointToEnd(rewriterModule.getBody()); 595 auto rewriterFunc = builder.create<pdl_interp::FuncOp>( 596 pattern.getLoc(), "pdl_generated_rewriter", 597 builder.getFunctionType(llvm::None, llvm::None)); 598 rewriterSymbolTable.insert(rewriterFunc); 599 600 // Generate the rewriter function body. 601 builder.setInsertionPointToEnd(&rewriterFunc.front()); 602 603 // Map an input operand of the pattern to a generated interpreter value. 604 DenseMap<Value, Value> rewriteValues; 605 auto mapRewriteValue = [&](Value oldValue) { 606 Value &newValue = rewriteValues[oldValue]; 607 if (newValue) 608 return newValue; 609 610 // Prefer materializing constants directly when possible. 611 Operation *oldOp = oldValue.getDefiningOp(); 612 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 613 if (Attribute value = attrOp.valueAttr()) { 614 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 615 attrOp.getLoc(), value); 616 } 617 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 618 if (TypeAttr type = typeOp.typeAttr()) { 619 return newValue = builder.create<pdl_interp::CreateTypeOp>( 620 typeOp.getLoc(), type); 621 } 622 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 623 if (ArrayAttr type = typeOp.typesAttr()) { 624 return newValue = builder.create<pdl_interp::CreateTypesOp>( 625 typeOp.getLoc(), typeOp.getType(), type); 626 } 627 } 628 629 // Otherwise, add this as an input to the rewriter. 630 Position *inputPos = valueToPosition.lookup(oldValue); 631 assert(inputPos && "expected value to be a pattern input"); 632 usedMatchValues.push_back(inputPos); 633 return newValue = rewriterFunc.front().addArgument(oldValue.getType(), 634 oldValue.getLoc()); 635 }; 636 637 // If this is a custom rewriter, simply dispatch to the registered rewrite 638 // method. 639 pdl::RewriteOp rewriter = pattern.getRewriter(); 640 if (StringAttr rewriteName = rewriter.nameAttr()) { 641 SmallVector<Value> args; 642 if (rewriter.root()) 643 args.push_back(mapRewriteValue(rewriter.root())); 644 auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); 645 args.append(mappedArgs.begin(), mappedArgs.end()); 646 builder.create<pdl_interp::ApplyRewriteOp>( 647 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); 648 } else { 649 // Otherwise this is a dag rewriter defined using PDL operations. 650 for (Operation &rewriteOp : *rewriter.getBody()) { 651 llvm::TypeSwitch<Operation *>(&rewriteOp) 652 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 653 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, 654 pdl::TypeOp, pdl::TypesOp>([&](auto op) { 655 this->generateRewriter(op, rewriteValues, mapRewriteValue); 656 }); 657 } 658 } 659 660 // Update the signature of the rewrite function. 661 rewriterFunc.setType(builder.getFunctionType( 662 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 663 /*results=*/llvm::None)); 664 665 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 666 return SymbolRefAttr::get( 667 builder.getContext(), 668 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 669 SymbolRefAttr::get(rewriterFunc)); 670 } 671 672 void PatternLowering::generateRewriter( 673 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 674 function_ref<Value(Value)> mapRewriteValue) { 675 SmallVector<Value, 2> arguments; 676 for (Value argument : rewriteOp.args()) 677 arguments.push_back(mapRewriteValue(argument)); 678 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 679 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), 680 arguments); 681 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) 682 rewriteValues[std::get<0>(it)] = std::get<1>(it); 683 } 684 685 void PatternLowering::generateRewriter( 686 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 687 function_ref<Value(Value)> mapRewriteValue) { 688 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 689 attrOp.getLoc(), attrOp.valueAttr()); 690 rewriteValues[attrOp] = newAttr; 691 } 692 693 void PatternLowering::generateRewriter( 694 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 695 function_ref<Value(Value)> mapRewriteValue) { 696 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 697 mapRewriteValue(eraseOp.operation())); 698 } 699 700 void PatternLowering::generateRewriter( 701 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 702 function_ref<Value(Value)> mapRewriteValue) { 703 SmallVector<Value, 4> operands; 704 for (Value operand : operationOp.operands()) 705 operands.push_back(mapRewriteValue(operand)); 706 707 SmallVector<Value, 4> attributes; 708 for (Value attr : operationOp.attributes()) 709 attributes.push_back(mapRewriteValue(attr)); 710 711 bool hasInferredResultTypes = false; 712 SmallVector<Value, 2> types; 713 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types, 714 rewriteValues, hasInferredResultTypes); 715 716 // Create the new operation. 717 Location loc = operationOp.getLoc(); 718 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 719 loc, *operationOp.name(), types, hasInferredResultTypes, operands, 720 attributes, operationOp.attributeNames()); 721 rewriteValues[operationOp.op()] = createdOp; 722 723 // Generate accesses for any results that have their types constrained. 724 // Handle the case where there is a single range representing all of the 725 // result types. 726 OperandRange resultTys = operationOp.types(); 727 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { 728 Value &type = rewriteValues[resultTys[0]]; 729 if (!type) { 730 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 731 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 732 } 733 return; 734 } 735 736 // Otherwise, populate the individual results. 737 bool seenVariableLength = false; 738 Type valueTy = builder.getType<pdl::ValueType>(); 739 Type valueRangeTy = pdl::RangeType::get(valueTy); 740 for (const auto &it : llvm::enumerate(resultTys)) { 741 Value &type = rewriteValues[it.value()]; 742 if (type) 743 continue; 744 bool isVariadic = it.value().getType().isa<pdl::RangeType>(); 745 seenVariableLength |= isVariadic; 746 747 // After a variable length result has been seen, we need to use result 748 // groups because the exact index of the result is not statically known. 749 Value resultVal; 750 if (seenVariableLength) 751 resultVal = builder.create<pdl_interp::GetResultsOp>( 752 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 753 else 754 resultVal = builder.create<pdl_interp::GetResultOp>( 755 loc, valueTy, createdOp, it.index()); 756 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 757 } 758 } 759 760 void PatternLowering::generateRewriter( 761 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 762 function_ref<Value(Value)> mapRewriteValue) { 763 SmallVector<Value, 4> replOperands; 764 765 // If the replacement was another operation, get its results. `pdl` allows 766 // for using an operation for simplicitly, but the interpreter isn't as 767 // user facing. 768 if (Value replOp = replaceOp.replOperation()) { 769 // Don't use replace if we know the replaced operation has no results. 770 auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>(); 771 if (!opOp || !opOp.types().empty()) { 772 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 773 replOp.getLoc(), mapRewriteValue(replOp))); 774 } 775 } else { 776 for (Value operand : replaceOp.replValues()) 777 replOperands.push_back(mapRewriteValue(operand)); 778 } 779 780 // If there are no replacement values, just create an erase instead. 781 if (replOperands.empty()) { 782 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(), 783 mapRewriteValue(replaceOp.operation())); 784 return; 785 } 786 787 builder.create<pdl_interp::ReplaceOp>( 788 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); 789 } 790 791 void PatternLowering::generateRewriter( 792 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 793 function_ref<Value(Value)> mapRewriteValue) { 794 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 795 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 796 mapRewriteValue(resultOp.parent()), resultOp.index()); 797 } 798 799 void PatternLowering::generateRewriter( 800 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 801 function_ref<Value(Value)> mapRewriteValue) { 802 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 803 resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()), 804 resultOp.index()); 805 } 806 807 void PatternLowering::generateRewriter( 808 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 809 function_ref<Value(Value)> mapRewriteValue) { 810 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 811 // type. 812 if (TypeAttr typeAttr = typeOp.typeAttr()) { 813 rewriteValues[typeOp] = 814 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 815 } 816 } 817 818 void PatternLowering::generateRewriter( 819 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 820 function_ref<Value(Value)> mapRewriteValue) { 821 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 822 // type. 823 if (ArrayAttr typeAttr = typeOp.typesAttr()) { 824 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 825 typeOp.getLoc(), typeOp.getType(), typeAttr); 826 } 827 } 828 829 void PatternLowering::generateOperationResultTypeRewriter( 830 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, 831 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, 832 bool &hasInferredResultTypes) { 833 // Look for an operation that was replaced by `op`. The result types will be 834 // inferred from the results that were replaced. 835 Block *rewriterBlock = op->getBlock(); 836 for (OpOperand &use : op.op().getUses()) { 837 // Check that the use corresponds to a ReplaceOp and that it is the 838 // replacement value, not the operation being replaced. 839 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 840 if (!replOpUser || use.getOperandNumber() == 0) 841 continue; 842 // Make sure the replaced operation was defined before this one. 843 Value replOpVal = replOpUser.operation(); 844 Operation *replacedOp = replOpVal.getDefiningOp(); 845 if (replacedOp->getBlock() == rewriterBlock && 846 !replacedOp->isBeforeInBlock(op)) 847 continue; 848 849 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 850 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 851 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 852 replacedOp->getLoc(), replacedOpResults)); 853 return; 854 } 855 856 // Try to handle resolution for each of the result types individually. This is 857 // preferred over type inferrence because it will allow for us to use existing 858 // types directly, as opposed to trying to rebuild the type list. 859 OperandRange resultTypeValues = op.types(); 860 auto tryResolveResultTypes = [&] { 861 types.reserve(resultTypeValues.size()); 862 for (const auto &it : llvm::enumerate(resultTypeValues)) { 863 Value resultType = it.value(); 864 865 // Check for an already translated value. 866 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 867 types.push_back(existingRewriteValue); 868 continue; 869 } 870 871 // Check for an input from the matcher. 872 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 873 types.push_back(mapRewriteValue(resultType)); 874 continue; 875 } 876 877 // Otherwise, we couldn't infer the result types. Bail out here to see if 878 // we can infer the types for this operation from another way. 879 types.clear(); 880 return failure(); 881 } 882 return success(); 883 }; 884 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes())) 885 return; 886 887 // Otherwise, check if the operation has type inference support itself. 888 if (op.hasTypeInference()) { 889 hasInferredResultTypes = true; 890 return; 891 } 892 893 // If the types could not be inferred from any context and there weren't any 894 // explicit result types, assume the user actually meant for the operation to 895 // have no results. 896 if (resultTypeValues.empty()) 897 return; 898 899 // The verifier asserts that the result types of each pdl.operation can be 900 // inferred. If we reach here, there is a bug either in the logic above or 901 // in the verifier for pdl.operation. 902 op->emitOpError() << "unable to infer result type for operation"; 903 llvm_unreachable("unable to infer result type for operation"); 904 } 905 906 //===----------------------------------------------------------------------===// 907 // Conversion Pass 908 //===----------------------------------------------------------------------===// 909 910 namespace { 911 struct PDLToPDLInterpPass 912 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 913 void runOnOperation() final; 914 }; 915 } // namespace 916 917 /// Convert the given module containing PDL pattern operations into a PDL 918 /// Interpreter operations. 919 void PDLToPDLInterpPass::runOnOperation() { 920 ModuleOp module = getOperation(); 921 922 // Create the main matcher function This function contains all of the match 923 // related functionality from patterns in the module. 924 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 925 auto matcherFunc = builder.create<pdl_interp::FuncOp>( 926 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 927 builder.getFunctionType(builder.getType<pdl::OperationType>(), 928 /*results=*/llvm::None), 929 /*attrs=*/llvm::None); 930 931 // Create a nested module to hold the functions invoked for rewriting the IR 932 // after a successful match. 933 ModuleOp rewriterModule = builder.create<ModuleOp>( 934 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 935 936 // Generate the code for the patterns within the module. 937 PatternLowering generator(matcherFunc, rewriterModule); 938 generator.lower(module); 939 940 // After generation, delete all of the pattern operations. 941 for (pdl::PatternOp pattern : 942 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) 943 pattern.erase(); 944 } 945 946 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 947 return std::make_unique<PDLToPDLInterpPass>(); 948 } 949