1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace mlir { 24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP 25 #include "mlir/Conversion/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 using namespace mlir::pdl_to_pdl_interp; 30 31 //===----------------------------------------------------------------------===// 32 // PatternLowering 33 //===----------------------------------------------------------------------===// 34 35 namespace { 36 /// This class generators operations within the PDL Interpreter dialect from a 37 /// given module containing PDL pattern operations. 38 struct PatternLowering { 39 public: 40 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule, 41 DenseMap<Operation *, PDLPatternConfigSet *> *configMap); 42 43 /// Generate code for matching and rewriting based on the pattern operations 44 /// within the module. 45 void lower(ModuleOp module); 46 47 private: 48 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 49 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 50 51 /// Generate interpreter operations for the tree rooted at the given matcher 52 /// node, in the specified region. 53 Block *generateMatcher(MatcherNode &node, Region ®ion); 54 55 /// Get or create an access to the provided positional value in the current 56 /// block. This operation may mutate the provided block pointer if nested 57 /// regions (i.e., pdl_interp.iterate) are required. 58 Value getValueAt(Block *¤tBlock, Position *pos); 59 60 /// Create the interpreter predicate operations. This operation may mutate the 61 /// provided current block pointer if nested regions (iterates) are required. 62 void generate(BoolNode *boolNode, Block *¤tBlock, Value val); 63 64 /// Create the interpreter switch / predicate operations, with several case 65 /// destinations. This operation never mutates the provided current block 66 /// pointer, because the switch operation does not need Values beyond `val`. 67 void generate(SwitchNode *switchNode, Block *currentBlock, Value val); 68 69 /// Create the interpreter operations to record a successful pattern match 70 /// using the contained root operation. This operation may mutate the current 71 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. 72 void generate(SuccessNode *successNode, Block *¤tBlock); 73 74 /// Generate a rewriter function for the given pattern operation, and returns 75 /// a reference to that function. 76 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 77 SmallVectorImpl<Position *> &usedMatchValues); 78 79 /// Generate the rewriter code for the given operation. 80 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 81 DenseMap<Value, Value> &rewriteValues, 82 function_ref<Value(Value)> mapRewriteValue); 83 void generateRewriter(pdl::AttributeOp attrOp, 84 DenseMap<Value, Value> &rewriteValues, 85 function_ref<Value(Value)> mapRewriteValue); 86 void generateRewriter(pdl::EraseOp eraseOp, 87 DenseMap<Value, Value> &rewriteValues, 88 function_ref<Value(Value)> mapRewriteValue); 89 void generateRewriter(pdl::OperationOp operationOp, 90 DenseMap<Value, Value> &rewriteValues, 91 function_ref<Value(Value)> mapRewriteValue); 92 void generateRewriter(pdl::RangeOp rangeOp, 93 DenseMap<Value, Value> &rewriteValues, 94 function_ref<Value(Value)> mapRewriteValue); 95 void generateRewriter(pdl::ReplaceOp replaceOp, 96 DenseMap<Value, Value> &rewriteValues, 97 function_ref<Value(Value)> mapRewriteValue); 98 void generateRewriter(pdl::ResultOp resultOp, 99 DenseMap<Value, Value> &rewriteValues, 100 function_ref<Value(Value)> mapRewriteValue); 101 void generateRewriter(pdl::ResultsOp resultOp, 102 DenseMap<Value, Value> &rewriteValues, 103 function_ref<Value(Value)> mapRewriteValue); 104 void generateRewriter(pdl::TypeOp typeOp, 105 DenseMap<Value, Value> &rewriteValues, 106 function_ref<Value(Value)> mapRewriteValue); 107 void generateRewriter(pdl::TypesOp typeOp, 108 DenseMap<Value, Value> &rewriteValues, 109 function_ref<Value(Value)> mapRewriteValue); 110 111 /// Generate the values used for resolving the result types of an operation 112 /// created within a dag rewriter region. If the result types of the operation 113 /// should be inferred, `hasInferredResultTypes` is set to true. 114 void generateOperationResultTypeRewriter( 115 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, 116 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, 117 bool &hasInferredResultTypes); 118 119 /// A builder to use when generating interpreter operations. 120 OpBuilder builder; 121 122 /// The matcher function used for all match related logic within PDL patterns. 123 pdl_interp::FuncOp matcherFunc; 124 125 /// The rewriter module containing the all rewrite related logic within PDL 126 /// patterns. 127 ModuleOp rewriterModule; 128 129 /// The symbol table of the rewriter module used for insertion. 130 SymbolTable rewriterSymbolTable; 131 132 /// A scoped map connecting a position with the corresponding interpreter 133 /// value. 134 ValueMap values; 135 136 /// A stack of blocks used as the failure destination for matcher nodes that 137 /// don't have an explicit failure path. 138 SmallVector<Block *, 8> failureBlockStack; 139 140 /// A mapping between values defined in a pattern match, and the corresponding 141 /// positional value. 142 DenseMap<Value, Position *> valueToPosition; 143 144 /// The set of operation values whose location will be used for newly 145 /// generated operations. 146 SetVector<Value> locOps; 147 148 /// A mapping between pattern operations and the corresponding configuration 149 /// set. 150 DenseMap<Operation *, PDLPatternConfigSet *> *configMap; 151 }; 152 } // namespace 153 154 PatternLowering::PatternLowering( 155 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule, 156 DenseMap<Operation *, PDLPatternConfigSet *> *configMap) 157 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 158 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule), 159 configMap(configMap) {} 160 161 void PatternLowering::lower(ModuleOp module) { 162 PredicateUniquer predicateUniquer; 163 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 164 165 // Define top-level scope for the arguments to the matcher function. 166 ValueMapScope topLevelValueScope(values); 167 168 // Insert the root operation, i.e. argument to the matcher, at the root 169 // position. 170 Block *matcherEntryBlock = &matcherFunc.front(); 171 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 172 173 // Generate a root matcher node from the provided PDL module. 174 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 175 module, predicateBuilder, valueToPosition); 176 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); 177 assert(failureBlockStack.empty() && "failed to empty the stack"); 178 179 // After generation, merged the first matched block into the entry. 180 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 181 firstMatcherBlock->getOperations()); 182 firstMatcherBlock->erase(); 183 } 184 185 Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { 186 // Push a new scope for the values used by this matcher. 187 Block *block = ®ion.emplaceBlock(); 188 ValueMapScope scope(values); 189 190 // If this is the return node, simply insert the corresponding interpreter 191 // finalize. 192 if (isa<ExitNode>(node)) { 193 builder.setInsertionPointToEnd(block); 194 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 195 return block; 196 } 197 198 // Get the next block in the match sequence. 199 // This is intentionally executed first, before we get the value for the 200 // position associated with the node, so that we preserve an "there exist" 201 // semantics: if getting a value requires an upward traversal (going from a 202 // value to its consumers), we want to perform the check on all the consumers 203 // before we pass control to the failure node. 204 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 205 Block *failureBlock; 206 if (failureNode) { 207 failureBlock = generateMatcher(*failureNode, region); 208 failureBlockStack.push_back(failureBlock); 209 } else { 210 assert(!failureBlockStack.empty() && "expected valid failure block"); 211 failureBlock = failureBlockStack.back(); 212 } 213 214 // If this node contains a position, get the corresponding value for this 215 // block. 216 Block *currentBlock = block; 217 Position *position = node.getPosition(); 218 Value val = position ? getValueAt(currentBlock, position) : Value(); 219 220 // If this value corresponds to an operation, record that we are going to use 221 // its location as part of a fused location. 222 bool isOperationValue = val && isa<pdl::OperationType>(val.getType()); 223 if (isOperationValue) 224 locOps.insert(val); 225 226 // Dispatch to the correct method based on derived node type. 227 TypeSwitch<MatcherNode *>(&node) 228 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) { 229 this->generate(derivedNode, currentBlock, val); 230 }) 231 .Case([&](SuccessNode *successNode) { 232 generate(successNode, currentBlock); 233 }); 234 235 // Pop all the failure blocks that were inserted due to nesting of 236 // pdl_interp.iterate. 237 while (failureBlockStack.back() != failureBlock) { 238 failureBlockStack.pop_back(); 239 assert(!failureBlockStack.empty() && "unable to locate failure block"); 240 } 241 242 // Pop the new failure block. 243 if (failureNode) 244 failureBlockStack.pop_back(); 245 246 if (isOperationValue) 247 locOps.remove(val); 248 249 return block; 250 } 251 252 Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { 253 if (Value val = values.lookup(pos)) 254 return val; 255 256 // Get the value for the parent position. 257 Value parentVal; 258 if (Position *parent = pos->getParent()) 259 parentVal = getValueAt(currentBlock, parent); 260 261 // TODO: Use a location from the position. 262 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc(); 263 builder.setInsertionPointToEnd(currentBlock); 264 Value value; 265 switch (pos->getKind()) { 266 case Predicates::OperationPos: { 267 auto *operationPos = cast<OperationPosition>(pos); 268 if (operationPos->isOperandDefiningOp()) 269 // Standard (downward) traversal which directly follows the defining op. 270 value = builder.create<pdl_interp::GetDefiningOpOp>( 271 loc, builder.getType<pdl::OperationType>(), parentVal); 272 else 273 // A passthrough operation position. 274 value = parentVal; 275 break; 276 } 277 case Predicates::UsersPos: { 278 auto *usersPos = cast<UsersPosition>(pos); 279 280 // The first operation retrieves the representative value of a range. 281 // This applies only when the parent is a range of values and we were 282 // requested to use a representative value (e.g., upward traversal). 283 if (isa<pdl::RangeType>(parentVal.getType()) && 284 usersPos->useRepresentative()) 285 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); 286 else 287 value = parentVal; 288 289 // The second operation retrieves the users. 290 value = builder.create<pdl_interp::GetUsersOp>(loc, value); 291 break; 292 } 293 case Predicates::ForEachPos: { 294 assert(!failureBlockStack.empty() && "expected valid failure block"); 295 auto foreach = builder.create<pdl_interp::ForEachOp>( 296 loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); 297 value = foreach.getLoopVariable(); 298 299 // Create the continuation block. 300 Block *continueBlock = builder.createBlock(&foreach.getRegion()); 301 builder.create<pdl_interp::ContinueOp>(loc); 302 failureBlockStack.push_back(continueBlock); 303 304 currentBlock = &foreach.getRegion().front(); 305 break; 306 } 307 case Predicates::OperandPos: { 308 auto *operandPos = cast<OperandPosition>(pos); 309 value = builder.create<pdl_interp::GetOperandOp>( 310 loc, builder.getType<pdl::ValueType>(), parentVal, 311 operandPos->getOperandNumber()); 312 break; 313 } 314 case Predicates::OperandGroupPos: { 315 auto *operandPos = cast<OperandGroupPosition>(pos); 316 Type valueTy = builder.getType<pdl::ValueType>(); 317 value = builder.create<pdl_interp::GetOperandsOp>( 318 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 319 parentVal, operandPos->getOperandGroupNumber()); 320 break; 321 } 322 case Predicates::AttributePos: { 323 auto *attrPos = cast<AttributePosition>(pos); 324 value = builder.create<pdl_interp::GetAttributeOp>( 325 loc, builder.getType<pdl::AttributeType>(), parentVal, 326 attrPos->getName().strref()); 327 break; 328 } 329 case Predicates::TypePos: { 330 if (isa<pdl::AttributeType>(parentVal.getType())) 331 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 332 else 333 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 334 break; 335 } 336 case Predicates::ResultPos: { 337 auto *resPos = cast<ResultPosition>(pos); 338 value = builder.create<pdl_interp::GetResultOp>( 339 loc, builder.getType<pdl::ValueType>(), parentVal, 340 resPos->getResultNumber()); 341 break; 342 } 343 case Predicates::ResultGroupPos: { 344 auto *resPos = cast<ResultGroupPosition>(pos); 345 Type valueTy = builder.getType<pdl::ValueType>(); 346 value = builder.create<pdl_interp::GetResultsOp>( 347 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 348 parentVal, resPos->getResultGroupNumber()); 349 break; 350 } 351 case Predicates::AttributeLiteralPos: { 352 auto *attrPos = cast<AttributeLiteralPosition>(pos); 353 value = 354 builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue()); 355 break; 356 } 357 case Predicates::TypeLiteralPos: { 358 auto *typePos = cast<TypeLiteralPosition>(pos); 359 Attribute rawTypeAttr = typePos->getValue(); 360 if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr)) 361 value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr); 362 else 363 value = builder.create<pdl_interp::CreateTypesOp>( 364 loc, cast<ArrayAttr>(rawTypeAttr)); 365 break; 366 } 367 default: 368 llvm_unreachable("Generating unknown Position getter"); 369 break; 370 } 371 372 values.insert(pos, value); 373 return value; 374 } 375 376 void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, 377 Value val) { 378 Location loc = val.getLoc(); 379 Qualifier *question = boolNode->getQuestion(); 380 Qualifier *answer = boolNode->getAnswer(); 381 Region *region = currentBlock->getParent(); 382 383 // Execute the getValue queries first, so that we create success 384 // matcher in the correct (possibly nested) region. 385 SmallVector<Value> args; 386 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) { 387 args = {getValueAt(currentBlock, equalToQuestion->getValue())}; 388 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) { 389 for (Position *position : cstQuestion->getArgs()) 390 args.push_back(getValueAt(currentBlock, position)); 391 } 392 393 // Generate the matcher in the current (potentially nested) region 394 // and get the failure successor. 395 Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); 396 Block *failure = failureBlockStack.back(); 397 398 // Finally, create the predicate. 399 builder.setInsertionPointToEnd(currentBlock); 400 Predicates::Kind kind = question->getKind(); 401 switch (kind) { 402 case Predicates::IsNotNullQuestion: 403 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); 404 break; 405 case Predicates::OperationNameQuestion: { 406 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 407 builder.create<pdl_interp::CheckOperationNameOp>( 408 loc, val, opNameAnswer->getValue().getStringRef(), success, failure); 409 break; 410 } 411 case Predicates::TypeQuestion: { 412 auto *ans = cast<TypeAnswer>(answer); 413 if (isa<pdl::RangeType>(val.getType())) 414 builder.create<pdl_interp::CheckTypesOp>( 415 loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure); 416 else 417 builder.create<pdl_interp::CheckTypeOp>( 418 loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure); 419 break; 420 } 421 case Predicates::AttributeQuestion: { 422 auto *ans = cast<AttributeAnswer>(answer); 423 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 424 success, failure); 425 break; 426 } 427 case Predicates::OperandCountAtLeastQuestion: 428 case Predicates::OperandCountQuestion: 429 builder.create<pdl_interp::CheckOperandCountOp>( 430 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 431 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 432 success, failure); 433 break; 434 case Predicates::ResultCountAtLeastQuestion: 435 case Predicates::ResultCountQuestion: 436 builder.create<pdl_interp::CheckResultCountOp>( 437 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 438 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 439 success, failure); 440 break; 441 case Predicates::EqualToQuestion: { 442 bool trueAnswer = isa<TrueAnswer>(answer); 443 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), 444 trueAnswer ? success : failure, 445 trueAnswer ? failure : success); 446 break; 447 } 448 case Predicates::ConstraintQuestion: { 449 auto *cstQuestion = cast<ConstraintQuestion>(question); 450 builder.create<pdl_interp::ApplyConstraintOp>( 451 loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success, 452 failure); 453 break; 454 } 455 default: 456 llvm_unreachable("Generating unknown Predicate operation"); 457 } 458 } 459 460 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 461 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 462 llvm::MapVector<Qualifier *, Block *> &dests) { 463 std::vector<ValT> values; 464 std::vector<Block *> blocks; 465 values.reserve(dests.size()); 466 blocks.reserve(dests.size()); 467 for (const auto &it : dests) { 468 blocks.push_back(it.second); 469 values.push_back(cast<PredT>(it.first)->getValue()); 470 } 471 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 472 } 473 474 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, 475 Value val) { 476 Qualifier *question = switchNode->getQuestion(); 477 Region *region = currentBlock->getParent(); 478 Block *defaultDest = failureBlockStack.back(); 479 480 // If the switch question is not an exact answer, i.e. for the `at_least` 481 // cases, we generate a special block sequence. 482 Predicates::Kind kind = question->getKind(); 483 if (kind == Predicates::OperandCountAtLeastQuestion || 484 kind == Predicates::ResultCountAtLeastQuestion) { 485 // Order the children such that the cases are in reverse numerical order. 486 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( 487 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 488 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 489 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 490 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 491 }); 492 493 // Build the destination for each child using the next highest child as a 494 // a failure destination. This essentially creates the following control 495 // flow: 496 // 497 // if (operand_count < 1) 498 // goto failure 499 // if (child1.match()) 500 // ... 501 // 502 // if (operand_count < 2) 503 // goto failure 504 // if (child2.match()) 505 // ... 506 // 507 // failure: 508 // ... 509 // 510 failureBlockStack.push_back(defaultDest); 511 Location loc = val.getLoc(); 512 for (unsigned idx : sortedChildren) { 513 auto &child = switchNode->getChild(idx); 514 Block *childBlock = generateMatcher(*child.second, *region); 515 Block *predicateBlock = builder.createBlock(childBlock); 516 builder.setInsertionPointToEnd(predicateBlock); 517 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue(); 518 switch (kind) { 519 case Predicates::OperandCountAtLeastQuestion: 520 builder.create<pdl_interp::CheckOperandCountOp>( 521 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 522 break; 523 case Predicates::ResultCountAtLeastQuestion: 524 builder.create<pdl_interp::CheckResultCountOp>( 525 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 526 break; 527 default: 528 llvm_unreachable("Generating invalid AtLeast operation"); 529 } 530 failureBlockStack.back() = predicateBlock; 531 } 532 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 533 currentBlock->getOperations().splice(currentBlock->end(), 534 firstPredicateBlock->getOperations()); 535 firstPredicateBlock->erase(); 536 return; 537 } 538 539 // Otherwise, generate each of the children and generate an interpreter 540 // switch. 541 llvm::MapVector<Qualifier *, Block *> children; 542 for (auto &it : switchNode->getChildren()) 543 children.insert({it.first, generateMatcher(*it.second, *region)}); 544 builder.setInsertionPointToEnd(currentBlock); 545 546 switch (question->getKind()) { 547 case Predicates::OperandCountQuestion: 548 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 549 int32_t>(val, defaultDest, builder, children); 550 case Predicates::ResultCountQuestion: 551 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 552 int32_t>(val, defaultDest, builder, children); 553 case Predicates::OperationNameQuestion: 554 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 555 OperationNameAnswer>(val, defaultDest, builder, 556 children); 557 case Predicates::TypeQuestion: 558 if (isa<pdl::RangeType>(val.getType())) { 559 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 560 val, defaultDest, builder, children); 561 } 562 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 563 val, defaultDest, builder, children); 564 case Predicates::AttributeQuestion: 565 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 566 val, defaultDest, builder, children); 567 default: 568 llvm_unreachable("Generating unknown switch predicate."); 569 } 570 } 571 572 void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { 573 pdl::PatternOp pattern = successNode->getPattern(); 574 Value root = successNode->getRoot(); 575 576 // Generate a rewriter for the pattern this success node represents, and track 577 // any values used from the match region. 578 SmallVector<Position *, 8> usedMatchValues; 579 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 580 581 // Process any values used in the rewrite that are defined in the match. 582 std::vector<Value> mappedMatchValues; 583 mappedMatchValues.reserve(usedMatchValues.size()); 584 for (Position *position : usedMatchValues) 585 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 586 587 // Collect the set of operations generated by the rewriter. 588 SmallVector<StringRef, 4> generatedOps; 589 for (auto op : 590 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>()) 591 generatedOps.push_back(*op.getOpName()); 592 ArrayAttr generatedOpsAttr; 593 if (!generatedOps.empty()) 594 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 595 596 // Grab the root kind if present. 597 StringAttr rootKindAttr; 598 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>()) 599 if (std::optional<StringRef> rootKind = rootOp.getOpName()) 600 rootKindAttr = builder.getStringAttr(*rootKind); 601 602 builder.setInsertionPointToEnd(currentBlock); 603 auto matchOp = builder.create<pdl_interp::RecordMatchOp>( 604 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 605 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(), 606 failureBlockStack.back()); 607 608 // Set the config of the lowered match to the parent pattern. 609 if (configMap) 610 configMap->try_emplace(matchOp, configMap->lookup(pattern)); 611 } 612 613 SymbolRefAttr PatternLowering::generateRewriter( 614 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 615 builder.setInsertionPointToEnd(rewriterModule.getBody()); 616 auto rewriterFunc = builder.create<pdl_interp::FuncOp>( 617 pattern.getLoc(), "pdl_generated_rewriter", 618 builder.getFunctionType(std::nullopt, std::nullopt)); 619 rewriterSymbolTable.insert(rewriterFunc); 620 621 // Generate the rewriter function body. 622 builder.setInsertionPointToEnd(&rewriterFunc.front()); 623 624 // Map an input operand of the pattern to a generated interpreter value. 625 DenseMap<Value, Value> rewriteValues; 626 auto mapRewriteValue = [&](Value oldValue) { 627 Value &newValue = rewriteValues[oldValue]; 628 if (newValue) 629 return newValue; 630 631 // Prefer materializing constants directly when possible. 632 Operation *oldOp = oldValue.getDefiningOp(); 633 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 634 if (Attribute value = attrOp.getValueAttr()) { 635 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 636 attrOp.getLoc(), value); 637 } 638 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 639 if (TypeAttr type = typeOp.getConstantTypeAttr()) { 640 return newValue = builder.create<pdl_interp::CreateTypeOp>( 641 typeOp.getLoc(), type); 642 } 643 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 644 if (ArrayAttr type = typeOp.getConstantTypesAttr()) { 645 return newValue = builder.create<pdl_interp::CreateTypesOp>( 646 typeOp.getLoc(), typeOp.getType(), type); 647 } 648 } 649 650 // Otherwise, add this as an input to the rewriter. 651 Position *inputPos = valueToPosition.lookup(oldValue); 652 assert(inputPos && "expected value to be a pattern input"); 653 usedMatchValues.push_back(inputPos); 654 return newValue = rewriterFunc.front().addArgument(oldValue.getType(), 655 oldValue.getLoc()); 656 }; 657 658 // If this is a custom rewriter, simply dispatch to the registered rewrite 659 // method. 660 pdl::RewriteOp rewriter = pattern.getRewriter(); 661 if (StringAttr rewriteName = rewriter.getNameAttr()) { 662 SmallVector<Value> args; 663 if (rewriter.getRoot()) 664 args.push_back(mapRewriteValue(rewriter.getRoot())); 665 auto mappedArgs = 666 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue); 667 args.append(mappedArgs.begin(), mappedArgs.end()); 668 builder.create<pdl_interp::ApplyRewriteOp>( 669 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); 670 } else { 671 // Otherwise this is a dag rewriter defined using PDL operations. 672 for (Operation &rewriteOp : *rewriter.getBody()) { 673 llvm::TypeSwitch<Operation *>(&rewriteOp) 674 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 675 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp, 676 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) { 677 this->generateRewriter(op, rewriteValues, mapRewriteValue); 678 }); 679 } 680 } 681 682 // Update the signature of the rewrite function. 683 rewriterFunc.setType(builder.getFunctionType( 684 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 685 /*results=*/std::nullopt)); 686 687 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 688 return SymbolRefAttr::get( 689 builder.getContext(), 690 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 691 SymbolRefAttr::get(rewriterFunc)); 692 } 693 694 void PatternLowering::generateRewriter( 695 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 696 function_ref<Value(Value)> mapRewriteValue) { 697 SmallVector<Value, 2> arguments; 698 for (Value argument : rewriteOp.getArgs()) 699 arguments.push_back(mapRewriteValue(argument)); 700 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 701 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(), 702 arguments); 703 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) 704 rewriteValues[std::get<0>(it)] = std::get<1>(it); 705 } 706 707 void PatternLowering::generateRewriter( 708 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 709 function_ref<Value(Value)> mapRewriteValue) { 710 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 711 attrOp.getLoc(), attrOp.getValueAttr()); 712 rewriteValues[attrOp] = newAttr; 713 } 714 715 void PatternLowering::generateRewriter( 716 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 717 function_ref<Value(Value)> mapRewriteValue) { 718 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 719 mapRewriteValue(eraseOp.getOpValue())); 720 } 721 722 void PatternLowering::generateRewriter( 723 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 724 function_ref<Value(Value)> mapRewriteValue) { 725 SmallVector<Value, 4> operands; 726 for (Value operand : operationOp.getOperandValues()) 727 operands.push_back(mapRewriteValue(operand)); 728 729 SmallVector<Value, 4> attributes; 730 for (Value attr : operationOp.getAttributeValues()) 731 attributes.push_back(mapRewriteValue(attr)); 732 733 bool hasInferredResultTypes = false; 734 SmallVector<Value, 2> types; 735 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types, 736 rewriteValues, hasInferredResultTypes); 737 738 // Create the new operation. 739 Location loc = operationOp.getLoc(); 740 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 741 loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands, 742 attributes, operationOp.getAttributeValueNames()); 743 rewriteValues[operationOp.getOp()] = createdOp; 744 745 // Generate accesses for any results that have their types constrained. 746 // Handle the case where there is a single range representing all of the 747 // result types. 748 OperandRange resultTys = operationOp.getTypeValues(); 749 if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) { 750 Value &type = rewriteValues[resultTys[0]]; 751 if (!type) { 752 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 753 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 754 } 755 return; 756 } 757 758 // Otherwise, populate the individual results. 759 bool seenVariableLength = false; 760 Type valueTy = builder.getType<pdl::ValueType>(); 761 Type valueRangeTy = pdl::RangeType::get(valueTy); 762 for (const auto &it : llvm::enumerate(resultTys)) { 763 Value &type = rewriteValues[it.value()]; 764 if (type) 765 continue; 766 bool isVariadic = isa<pdl::RangeType>(it.value().getType()); 767 seenVariableLength |= isVariadic; 768 769 // After a variable length result has been seen, we need to use result 770 // groups because the exact index of the result is not statically known. 771 Value resultVal; 772 if (seenVariableLength) 773 resultVal = builder.create<pdl_interp::GetResultsOp>( 774 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 775 else 776 resultVal = builder.create<pdl_interp::GetResultOp>( 777 loc, valueTy, createdOp, it.index()); 778 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 779 } 780 } 781 782 void PatternLowering::generateRewriter( 783 pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues, 784 function_ref<Value(Value)> mapRewriteValue) { 785 SmallVector<Value, 4> replOperands; 786 for (Value operand : rangeOp.getArguments()) 787 replOperands.push_back(mapRewriteValue(operand)); 788 rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>( 789 rangeOp.getLoc(), rangeOp.getType(), replOperands); 790 } 791 792 void PatternLowering::generateRewriter( 793 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 794 function_ref<Value(Value)> mapRewriteValue) { 795 SmallVector<Value, 4> replOperands; 796 797 // If the replacement was another operation, get its results. `pdl` allows 798 // for using an operation for simplicitly, but the interpreter isn't as 799 // user facing. 800 if (Value replOp = replaceOp.getReplOperation()) { 801 // Don't use replace if we know the replaced operation has no results. 802 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>(); 803 if (!opOp || !opOp.getTypeValues().empty()) { 804 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 805 replOp.getLoc(), mapRewriteValue(replOp))); 806 } 807 } else { 808 for (Value operand : replaceOp.getReplValues()) 809 replOperands.push_back(mapRewriteValue(operand)); 810 } 811 812 // If there are no replacement values, just create an erase instead. 813 if (replOperands.empty()) { 814 builder.create<pdl_interp::EraseOp>( 815 replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue())); 816 return; 817 } 818 819 builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(), 820 mapRewriteValue(replaceOp.getOpValue()), 821 replOperands); 822 } 823 824 void PatternLowering::generateRewriter( 825 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 826 function_ref<Value(Value)> mapRewriteValue) { 827 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 828 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 829 mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); 830 } 831 832 void PatternLowering::generateRewriter( 833 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 834 function_ref<Value(Value)> mapRewriteValue) { 835 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 836 resultOp.getLoc(), resultOp.getType(), 837 mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); 838 } 839 840 void PatternLowering::generateRewriter( 841 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 842 function_ref<Value(Value)> mapRewriteValue) { 843 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 844 // type. 845 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) { 846 rewriteValues[typeOp] = 847 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 848 } 849 } 850 851 void PatternLowering::generateRewriter( 852 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 853 function_ref<Value(Value)> mapRewriteValue) { 854 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 855 // type. 856 if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) { 857 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 858 typeOp.getLoc(), typeOp.getType(), typeAttr); 859 } 860 } 861 862 void PatternLowering::generateOperationResultTypeRewriter( 863 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, 864 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, 865 bool &hasInferredResultTypes) { 866 Block *rewriterBlock = op->getBlock(); 867 868 // Try to handle resolution for each of the result types individually. This is 869 // preferred over type inferrence because it will allow for us to use existing 870 // types directly, as opposed to trying to rebuild the type list. 871 OperandRange resultTypeValues = op.getTypeValues(); 872 auto tryResolveResultTypes = [&] { 873 types.reserve(resultTypeValues.size()); 874 for (const auto &it : llvm::enumerate(resultTypeValues)) { 875 Value resultType = it.value(); 876 877 // Check for an already translated value. 878 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 879 types.push_back(existingRewriteValue); 880 continue; 881 } 882 883 // Check for an input from the matcher. 884 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 885 types.push_back(mapRewriteValue(resultType)); 886 continue; 887 } 888 889 // Otherwise, we couldn't infer the result types. Bail out here to see if 890 // we can infer the types for this operation from another way. 891 types.clear(); 892 return failure(); 893 } 894 return success(); 895 }; 896 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes())) 897 return; 898 899 // Otherwise, check if the operation has type inference support itself. 900 if (op.hasTypeInference()) { 901 hasInferredResultTypes = true; 902 return; 903 } 904 905 // Look for an operation that was replaced by `op`. The result types will be 906 // inferred from the results that were replaced. 907 for (OpOperand &use : op.getOp().getUses()) { 908 // Check that the use corresponds to a ReplaceOp and that it is the 909 // replacement value, not the operation being replaced. 910 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 911 if (!replOpUser || use.getOperandNumber() == 0) 912 continue; 913 // Make sure the replaced operation was defined before this one. PDL 914 // rewrites only have single block regions, so if the op isn't in the 915 // rewriter block (i.e. the current block of the operation) we already know 916 // it dominates (i.e. it's in the matcher). 917 Value replOpVal = replOpUser.getOpValue(); 918 Operation *replacedOp = replOpVal.getDefiningOp(); 919 if (replacedOp->getBlock() == rewriterBlock && 920 !replacedOp->isBeforeInBlock(op)) 921 continue; 922 923 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 924 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 925 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 926 replacedOp->getLoc(), replacedOpResults)); 927 return; 928 } 929 930 // If the types could not be inferred from any context and there weren't any 931 // explicit result types, assume the user actually meant for the operation to 932 // have no results. 933 if (resultTypeValues.empty()) 934 return; 935 936 // The verifier asserts that the result types of each pdl.getOperation can be 937 // inferred. If we reach here, there is a bug either in the logic above or 938 // in the verifier for pdl.getOperation. 939 op->emitOpError() << "unable to infer result type for operation"; 940 llvm_unreachable("unable to infer result type for operation"); 941 } 942 943 //===----------------------------------------------------------------------===// 944 // Conversion Pass 945 //===----------------------------------------------------------------------===// 946 947 namespace { 948 struct PDLToPDLInterpPass 949 : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 950 PDLToPDLInterpPass() = default; 951 PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default; 952 PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap) 953 : configMap(&configMap) {} 954 void runOnOperation() final; 955 956 /// A map containing the configuration for each pattern. 957 DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr; 958 }; 959 } // namespace 960 961 /// Convert the given module containing PDL pattern operations into a PDL 962 /// Interpreter operations. 963 void PDLToPDLInterpPass::runOnOperation() { 964 ModuleOp module = getOperation(); 965 966 // Create the main matcher function This function contains all of the match 967 // related functionality from patterns in the module. 968 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 969 auto matcherFunc = builder.create<pdl_interp::FuncOp>( 970 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 971 builder.getFunctionType(builder.getType<pdl::OperationType>(), 972 /*results=*/std::nullopt), 973 /*attrs=*/std::nullopt); 974 975 // Create a nested module to hold the functions invoked for rewriting the IR 976 // after a successful match. 977 ModuleOp rewriterModule = builder.create<ModuleOp>( 978 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 979 980 // Generate the code for the patterns within the module. 981 PatternLowering generator(matcherFunc, rewriterModule, configMap); 982 generator.lower(module); 983 984 // After generation, delete all of the pattern operations. 985 for (pdl::PatternOp pattern : 986 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) { 987 // Drop the now dead config mappings. 988 if (configMap) 989 configMap->erase(pattern); 990 991 pattern.erase(); 992 } 993 } 994 995 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 996 return std::make_unique<PDLToPDLInterpPass>(); 997 } 998 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass( 999 DenseMap<Operation *, PDLPatternConfigSet *> &configMap) { 1000 return std::make_unique<PDLToPDLInterpPass>(configMap); 1001 } 1002