1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace mlir { 24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP 25 #include "mlir/Conversion/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 using namespace mlir::pdl_to_pdl_interp; 30 31 //===----------------------------------------------------------------------===// 32 // PatternLowering 33 //===----------------------------------------------------------------------===// 34 35 namespace { 36 /// This class generators operations within the PDL Interpreter dialect from a 37 /// given module containing PDL pattern operations. 38 struct PatternLowering { 39 public: 40 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule, 41 DenseMap<Operation *, PDLPatternConfigSet *> *configMap); 42 43 /// Generate code for matching and rewriting based on the pattern operations 44 /// within the module. 45 void lower(ModuleOp module); 46 47 private: 48 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 49 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 50 51 /// Generate interpreter operations for the tree rooted at the given matcher 52 /// node, in the specified region. 53 Block *generateMatcher(MatcherNode &node, Region ®ion, 54 Block *block = nullptr); 55 56 /// Get or create an access to the provided positional value in the current 57 /// block. This operation may mutate the provided block pointer if nested 58 /// regions (i.e., pdl_interp.iterate) are required. 59 Value getValueAt(Block *¤tBlock, Position *pos); 60 61 /// Create the interpreter predicate operations. This operation may mutate the 62 /// provided current block pointer if nested regions (iterates) are required. 63 void generate(BoolNode *boolNode, Block *¤tBlock, Value val); 64 65 /// Create the interpreter switch / predicate operations, with several case 66 /// destinations. This operation never mutates the provided current block 67 /// pointer, because the switch operation does not need Values beyond `val`. 68 void generate(SwitchNode *switchNode, Block *currentBlock, Value val); 69 70 /// Create the interpreter operations to record a successful pattern match 71 /// using the contained root operation. This operation may mutate the current 72 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. 73 void generate(SuccessNode *successNode, Block *¤tBlock); 74 75 /// Generate a rewriter function for the given pattern operation, and returns 76 /// a reference to that function. 77 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 78 SmallVectorImpl<Position *> &usedMatchValues); 79 80 /// Generate the rewriter code for the given operation. 81 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 82 DenseMap<Value, Value> &rewriteValues, 83 function_ref<Value(Value)> mapRewriteValue); 84 void generateRewriter(pdl::AttributeOp attrOp, 85 DenseMap<Value, Value> &rewriteValues, 86 function_ref<Value(Value)> mapRewriteValue); 87 void generateRewriter(pdl::EraseOp eraseOp, 88 DenseMap<Value, Value> &rewriteValues, 89 function_ref<Value(Value)> mapRewriteValue); 90 void generateRewriter(pdl::OperationOp operationOp, 91 DenseMap<Value, Value> &rewriteValues, 92 function_ref<Value(Value)> mapRewriteValue); 93 void generateRewriter(pdl::RangeOp rangeOp, 94 DenseMap<Value, Value> &rewriteValues, 95 function_ref<Value(Value)> mapRewriteValue); 96 void generateRewriter(pdl::ReplaceOp replaceOp, 97 DenseMap<Value, Value> &rewriteValues, 98 function_ref<Value(Value)> mapRewriteValue); 99 void generateRewriter(pdl::ResultOp resultOp, 100 DenseMap<Value, Value> &rewriteValues, 101 function_ref<Value(Value)> mapRewriteValue); 102 void generateRewriter(pdl::ResultsOp resultOp, 103 DenseMap<Value, Value> &rewriteValues, 104 function_ref<Value(Value)> mapRewriteValue); 105 void generateRewriter(pdl::TypeOp typeOp, 106 DenseMap<Value, Value> &rewriteValues, 107 function_ref<Value(Value)> mapRewriteValue); 108 void generateRewriter(pdl::TypesOp typeOp, 109 DenseMap<Value, Value> &rewriteValues, 110 function_ref<Value(Value)> mapRewriteValue); 111 112 /// Generate the values used for resolving the result types of an operation 113 /// created within a dag rewriter region. If the result types of the operation 114 /// should be inferred, `hasInferredResultTypes` is set to true. 115 void generateOperationResultTypeRewriter( 116 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, 117 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, 118 bool &hasInferredResultTypes); 119 120 /// A builder to use when generating interpreter operations. 121 OpBuilder builder; 122 123 /// The matcher function used for all match related logic within PDL patterns. 124 pdl_interp::FuncOp matcherFunc; 125 126 /// The rewriter module containing the all rewrite related logic within PDL 127 /// patterns. 128 ModuleOp rewriterModule; 129 130 /// The symbol table of the rewriter module used for insertion. 131 SymbolTable rewriterSymbolTable; 132 133 /// A scoped map connecting a position with the corresponding interpreter 134 /// value. 135 ValueMap values; 136 137 /// A stack of blocks used as the failure destination for matcher nodes that 138 /// don't have an explicit failure path. 139 SmallVector<Block *, 8> failureBlockStack; 140 141 /// A mapping between values defined in a pattern match, and the corresponding 142 /// positional value. 143 DenseMap<Value, Position *> valueToPosition; 144 145 /// The set of operation values whose location will be used for newly 146 /// generated operations. 147 SetVector<Value> locOps; 148 149 /// A mapping between pattern operations and the corresponding configuration 150 /// set. 151 DenseMap<Operation *, PDLPatternConfigSet *> *configMap; 152 153 /// A mapping from a constraint question to the ApplyConstraintOp 154 /// that implements it. 155 DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap; 156 }; 157 } // namespace 158 159 PatternLowering::PatternLowering( 160 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule, 161 DenseMap<Operation *, PDLPatternConfigSet *> *configMap) 162 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 163 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule), 164 configMap(configMap) {} 165 166 void PatternLowering::lower(ModuleOp module) { 167 PredicateUniquer predicateUniquer; 168 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 169 170 // Define top-level scope for the arguments to the matcher function. 171 ValueMapScope topLevelValueScope(values); 172 173 // Insert the root operation, i.e. argument to the matcher, at the root 174 // position. 175 Block *matcherEntryBlock = &matcherFunc.front(); 176 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 177 178 // Generate a root matcher node from the provided PDL module. 179 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 180 module, predicateBuilder, valueToPosition); 181 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); 182 assert(failureBlockStack.empty() && "failed to empty the stack"); 183 184 // After generation, merged the first matched block into the entry. 185 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 186 firstMatcherBlock->getOperations()); 187 firstMatcherBlock->erase(); 188 } 189 190 Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion, 191 Block *block) { 192 // Push a new scope for the values used by this matcher. 193 if (!block) 194 block = ®ion.emplaceBlock(); 195 ValueMapScope scope(values); 196 197 // If this is the return node, simply insert the corresponding interpreter 198 // finalize. 199 if (isa<ExitNode>(node)) { 200 builder.setInsertionPointToEnd(block); 201 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 202 return block; 203 } 204 205 // Get the next block in the match sequence. 206 // This is intentionally executed first, before we get the value for the 207 // position associated with the node, so that we preserve an "there exist" 208 // semantics: if getting a value requires an upward traversal (going from a 209 // value to its consumers), we want to perform the check on all the consumers 210 // before we pass control to the failure node. 211 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 212 Block *failureBlock; 213 if (failureNode) { 214 failureBlock = generateMatcher(*failureNode, region); 215 failureBlockStack.push_back(failureBlock); 216 } else { 217 assert(!failureBlockStack.empty() && "expected valid failure block"); 218 failureBlock = failureBlockStack.back(); 219 } 220 221 // If this node contains a position, get the corresponding value for this 222 // block. 223 Block *currentBlock = block; 224 Position *position = node.getPosition(); 225 Value val = position ? getValueAt(currentBlock, position) : Value(); 226 227 // If this value corresponds to an operation, record that we are going to use 228 // its location as part of a fused location. 229 bool isOperationValue = val && isa<pdl::OperationType>(val.getType()); 230 if (isOperationValue) 231 locOps.insert(val); 232 233 // Dispatch to the correct method based on derived node type. 234 TypeSwitch<MatcherNode *>(&node) 235 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) { 236 this->generate(derivedNode, currentBlock, val); 237 }) 238 .Case([&](SuccessNode *successNode) { 239 generate(successNode, currentBlock); 240 }); 241 242 // Pop all the failure blocks that were inserted due to nesting of 243 // pdl_interp.iterate. 244 while (failureBlockStack.back() != failureBlock) { 245 failureBlockStack.pop_back(); 246 assert(!failureBlockStack.empty() && "unable to locate failure block"); 247 } 248 249 // Pop the new failure block. 250 if (failureNode) 251 failureBlockStack.pop_back(); 252 253 if (isOperationValue) 254 locOps.remove(val); 255 256 return block; 257 } 258 259 Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { 260 if (Value val = values.lookup(pos)) 261 return val; 262 263 // Get the value for the parent position. 264 Value parentVal; 265 if (Position *parent = pos->getParent()) 266 parentVal = getValueAt(currentBlock, parent); 267 268 // TODO: Use a location from the position. 269 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc(); 270 builder.setInsertionPointToEnd(currentBlock); 271 Value value; 272 switch (pos->getKind()) { 273 case Predicates::OperationPos: { 274 auto *operationPos = cast<OperationPosition>(pos); 275 if (operationPos->isOperandDefiningOp()) 276 // Standard (downward) traversal which directly follows the defining op. 277 value = builder.create<pdl_interp::GetDefiningOpOp>( 278 loc, builder.getType<pdl::OperationType>(), parentVal); 279 else 280 // A passthrough operation position. 281 value = parentVal; 282 break; 283 } 284 case Predicates::UsersPos: { 285 auto *usersPos = cast<UsersPosition>(pos); 286 287 // The first operation retrieves the representative value of a range. 288 // This applies only when the parent is a range of values and we were 289 // requested to use a representative value (e.g., upward traversal). 290 if (isa<pdl::RangeType>(parentVal.getType()) && 291 usersPos->useRepresentative()) 292 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); 293 else 294 value = parentVal; 295 296 // The second operation retrieves the users. 297 value = builder.create<pdl_interp::GetUsersOp>(loc, value); 298 break; 299 } 300 case Predicates::ForEachPos: { 301 assert(!failureBlockStack.empty() && "expected valid failure block"); 302 auto foreach = builder.create<pdl_interp::ForEachOp>( 303 loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); 304 value = foreach.getLoopVariable(); 305 306 // Create the continuation block. 307 Block *continueBlock = builder.createBlock(&foreach.getRegion()); 308 builder.create<pdl_interp::ContinueOp>(loc); 309 failureBlockStack.push_back(continueBlock); 310 311 currentBlock = &foreach.getRegion().front(); 312 break; 313 } 314 case Predicates::OperandPos: { 315 auto *operandPos = cast<OperandPosition>(pos); 316 value = builder.create<pdl_interp::GetOperandOp>( 317 loc, builder.getType<pdl::ValueType>(), parentVal, 318 operandPos->getOperandNumber()); 319 break; 320 } 321 case Predicates::OperandGroupPos: { 322 auto *operandPos = cast<OperandGroupPosition>(pos); 323 Type valueTy = builder.getType<pdl::ValueType>(); 324 value = builder.create<pdl_interp::GetOperandsOp>( 325 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 326 parentVal, operandPos->getOperandGroupNumber()); 327 break; 328 } 329 case Predicates::AttributePos: { 330 auto *attrPos = cast<AttributePosition>(pos); 331 value = builder.create<pdl_interp::GetAttributeOp>( 332 loc, builder.getType<pdl::AttributeType>(), parentVal, 333 attrPos->getName().strref()); 334 break; 335 } 336 case Predicates::TypePos: { 337 if (isa<pdl::AttributeType>(parentVal.getType())) 338 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 339 else 340 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 341 break; 342 } 343 case Predicates::ResultPos: { 344 auto *resPos = cast<ResultPosition>(pos); 345 value = builder.create<pdl_interp::GetResultOp>( 346 loc, builder.getType<pdl::ValueType>(), parentVal, 347 resPos->getResultNumber()); 348 break; 349 } 350 case Predicates::ResultGroupPos: { 351 auto *resPos = cast<ResultGroupPosition>(pos); 352 Type valueTy = builder.getType<pdl::ValueType>(); 353 value = builder.create<pdl_interp::GetResultsOp>( 354 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 355 parentVal, resPos->getResultGroupNumber()); 356 break; 357 } 358 case Predicates::AttributeLiteralPos: { 359 auto *attrPos = cast<AttributeLiteralPosition>(pos); 360 value = 361 builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue()); 362 break; 363 } 364 case Predicates::TypeLiteralPos: { 365 auto *typePos = cast<TypeLiteralPosition>(pos); 366 Attribute rawTypeAttr = typePos->getValue(); 367 if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr)) 368 value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr); 369 else 370 value = builder.create<pdl_interp::CreateTypesOp>( 371 loc, cast<ArrayAttr>(rawTypeAttr)); 372 break; 373 } 374 case Predicates::ConstraintResultPos: { 375 // Due to the order of traversal, the ApplyConstraintOp has already been 376 // created and we can find it in constraintOpMap. 377 auto *constrResPos = cast<ConstraintPosition>(pos); 378 auto i = constraintOpMap.find(constrResPos->getQuestion()); 379 assert(i != constraintOpMap.end()); 380 value = i->second->getResult(constrResPos->getIndex()); 381 break; 382 } 383 default: 384 llvm_unreachable("Generating unknown Position getter"); 385 break; 386 } 387 388 values.insert(pos, value); 389 return value; 390 } 391 392 void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, 393 Value val) { 394 Location loc = val.getLoc(); 395 Qualifier *question = boolNode->getQuestion(); 396 Qualifier *answer = boolNode->getAnswer(); 397 Region *region = currentBlock->getParent(); 398 399 // Execute the getValue queries first, so that we create success 400 // matcher in the correct (possibly nested) region. 401 SmallVector<Value> args; 402 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) { 403 args = {getValueAt(currentBlock, equalToQuestion->getValue())}; 404 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) { 405 for (Position *position : cstQuestion->getArgs()) 406 args.push_back(getValueAt(currentBlock, position)); 407 } 408 409 // Generate a new block as success successor and get the failure successor. 410 Block *success = ®ion->emplaceBlock(); 411 Block *failure = failureBlockStack.back(); 412 413 // Create the predicate. 414 builder.setInsertionPointToEnd(currentBlock); 415 Predicates::Kind kind = question->getKind(); 416 switch (kind) { 417 case Predicates::IsNotNullQuestion: 418 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); 419 break; 420 case Predicates::OperationNameQuestion: { 421 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 422 builder.create<pdl_interp::CheckOperationNameOp>( 423 loc, val, opNameAnswer->getValue().getStringRef(), success, failure); 424 break; 425 } 426 case Predicates::TypeQuestion: { 427 auto *ans = cast<TypeAnswer>(answer); 428 if (isa<pdl::RangeType>(val.getType())) 429 builder.create<pdl_interp::CheckTypesOp>( 430 loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure); 431 else 432 builder.create<pdl_interp::CheckTypeOp>( 433 loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure); 434 break; 435 } 436 case Predicates::AttributeQuestion: { 437 auto *ans = cast<AttributeAnswer>(answer); 438 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 439 success, failure); 440 break; 441 } 442 case Predicates::OperandCountAtLeastQuestion: 443 case Predicates::OperandCountQuestion: 444 builder.create<pdl_interp::CheckOperandCountOp>( 445 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 446 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 447 success, failure); 448 break; 449 case Predicates::ResultCountAtLeastQuestion: 450 case Predicates::ResultCountQuestion: 451 builder.create<pdl_interp::CheckResultCountOp>( 452 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 453 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 454 success, failure); 455 break; 456 case Predicates::EqualToQuestion: { 457 bool trueAnswer = isa<TrueAnswer>(answer); 458 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), 459 trueAnswer ? success : failure, 460 trueAnswer ? failure : success); 461 break; 462 } 463 case Predicates::ConstraintQuestion: { 464 auto *cstQuestion = cast<ConstraintQuestion>(question); 465 auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>( 466 loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, 467 cstQuestion->getIsNegated(), success, failure); 468 469 constraintOpMap.insert({cstQuestion, applyConstraintOp}); 470 break; 471 } 472 default: 473 llvm_unreachable("Generating unknown Predicate operation"); 474 } 475 476 // Generate the matcher in the current (potentially nested) region. 477 // This might use the results of the current predicate. 478 generateMatcher(*boolNode->getSuccessNode(), *region, success); 479 } 480 481 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 482 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 483 llvm::MapVector<Qualifier *, Block *> &dests) { 484 std::vector<ValT> values; 485 std::vector<Block *> blocks; 486 values.reserve(dests.size()); 487 blocks.reserve(dests.size()); 488 for (const auto &it : dests) { 489 blocks.push_back(it.second); 490 values.push_back(cast<PredT>(it.first)->getValue()); 491 } 492 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 493 } 494 495 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, 496 Value val) { 497 Qualifier *question = switchNode->getQuestion(); 498 Region *region = currentBlock->getParent(); 499 Block *defaultDest = failureBlockStack.back(); 500 501 // If the switch question is not an exact answer, i.e. for the `at_least` 502 // cases, we generate a special block sequence. 503 Predicates::Kind kind = question->getKind(); 504 if (kind == Predicates::OperandCountAtLeastQuestion || 505 kind == Predicates::ResultCountAtLeastQuestion) { 506 // Order the children such that the cases are in reverse numerical order. 507 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( 508 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 509 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 510 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 511 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 512 }); 513 514 // Build the destination for each child using the next highest child as a 515 // a failure destination. This essentially creates the following control 516 // flow: 517 // 518 // if (operand_count < 1) 519 // goto failure 520 // if (child1.match()) 521 // ... 522 // 523 // if (operand_count < 2) 524 // goto failure 525 // if (child2.match()) 526 // ... 527 // 528 // failure: 529 // ... 530 // 531 failureBlockStack.push_back(defaultDest); 532 Location loc = val.getLoc(); 533 for (unsigned idx : sortedChildren) { 534 auto &child = switchNode->getChild(idx); 535 Block *childBlock = generateMatcher(*child.second, *region); 536 Block *predicateBlock = builder.createBlock(childBlock); 537 builder.setInsertionPointToEnd(predicateBlock); 538 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue(); 539 switch (kind) { 540 case Predicates::OperandCountAtLeastQuestion: 541 builder.create<pdl_interp::CheckOperandCountOp>( 542 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 543 break; 544 case Predicates::ResultCountAtLeastQuestion: 545 builder.create<pdl_interp::CheckResultCountOp>( 546 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 547 break; 548 default: 549 llvm_unreachable("Generating invalid AtLeast operation"); 550 } 551 failureBlockStack.back() = predicateBlock; 552 } 553 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 554 currentBlock->getOperations().splice(currentBlock->end(), 555 firstPredicateBlock->getOperations()); 556 firstPredicateBlock->erase(); 557 return; 558 } 559 560 // Otherwise, generate each of the children and generate an interpreter 561 // switch. 562 llvm::MapVector<Qualifier *, Block *> children; 563 for (auto &it : switchNode->getChildren()) 564 children.insert({it.first, generateMatcher(*it.second, *region)}); 565 builder.setInsertionPointToEnd(currentBlock); 566 567 switch (question->getKind()) { 568 case Predicates::OperandCountQuestion: 569 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 570 int32_t>(val, defaultDest, builder, children); 571 case Predicates::ResultCountQuestion: 572 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 573 int32_t>(val, defaultDest, builder, children); 574 case Predicates::OperationNameQuestion: 575 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 576 OperationNameAnswer>(val, defaultDest, builder, 577 children); 578 case Predicates::TypeQuestion: 579 if (isa<pdl::RangeType>(val.getType())) { 580 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 581 val, defaultDest, builder, children); 582 } 583 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 584 val, defaultDest, builder, children); 585 case Predicates::AttributeQuestion: 586 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 587 val, defaultDest, builder, children); 588 default: 589 llvm_unreachable("Generating unknown switch predicate."); 590 } 591 } 592 593 void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { 594 pdl::PatternOp pattern = successNode->getPattern(); 595 Value root = successNode->getRoot(); 596 597 // Generate a rewriter for the pattern this success node represents, and track 598 // any values used from the match region. 599 SmallVector<Position *, 8> usedMatchValues; 600 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 601 602 // Process any values used in the rewrite that are defined in the match. 603 std::vector<Value> mappedMatchValues; 604 mappedMatchValues.reserve(usedMatchValues.size()); 605 for (Position *position : usedMatchValues) 606 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 607 608 // Collect the set of operations generated by the rewriter. 609 SmallVector<StringRef, 4> generatedOps; 610 for (auto op : 611 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>()) 612 generatedOps.push_back(*op.getOpName()); 613 ArrayAttr generatedOpsAttr; 614 if (!generatedOps.empty()) 615 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 616 617 // Grab the root kind if present. 618 StringAttr rootKindAttr; 619 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>()) 620 if (std::optional<StringRef> rootKind = rootOp.getOpName()) 621 rootKindAttr = builder.getStringAttr(*rootKind); 622 623 builder.setInsertionPointToEnd(currentBlock); 624 auto matchOp = builder.create<pdl_interp::RecordMatchOp>( 625 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 626 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(), 627 failureBlockStack.back()); 628 629 // Set the config of the lowered match to the parent pattern. 630 if (configMap) 631 configMap->try_emplace(matchOp, configMap->lookup(pattern)); 632 } 633 634 SymbolRefAttr PatternLowering::generateRewriter( 635 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 636 builder.setInsertionPointToEnd(rewriterModule.getBody()); 637 auto rewriterFunc = builder.create<pdl_interp::FuncOp>( 638 pattern.getLoc(), "pdl_generated_rewriter", 639 builder.getFunctionType(std::nullopt, std::nullopt)); 640 rewriterSymbolTable.insert(rewriterFunc); 641 642 // Generate the rewriter function body. 643 builder.setInsertionPointToEnd(&rewriterFunc.front()); 644 645 // Map an input operand of the pattern to a generated interpreter value. 646 DenseMap<Value, Value> rewriteValues; 647 auto mapRewriteValue = [&](Value oldValue) { 648 Value &newValue = rewriteValues[oldValue]; 649 if (newValue) 650 return newValue; 651 652 // Prefer materializing constants directly when possible. 653 Operation *oldOp = oldValue.getDefiningOp(); 654 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 655 if (Attribute value = attrOp.getValueAttr()) { 656 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 657 attrOp.getLoc(), value); 658 } 659 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 660 if (TypeAttr type = typeOp.getConstantTypeAttr()) { 661 return newValue = builder.create<pdl_interp::CreateTypeOp>( 662 typeOp.getLoc(), type); 663 } 664 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 665 if (ArrayAttr type = typeOp.getConstantTypesAttr()) { 666 return newValue = builder.create<pdl_interp::CreateTypesOp>( 667 typeOp.getLoc(), typeOp.getType(), type); 668 } 669 } 670 671 // Otherwise, add this as an input to the rewriter. 672 Position *inputPos = valueToPosition.lookup(oldValue); 673 assert(inputPos && "expected value to be a pattern input"); 674 usedMatchValues.push_back(inputPos); 675 return newValue = rewriterFunc.front().addArgument(oldValue.getType(), 676 oldValue.getLoc()); 677 }; 678 679 // If this is a custom rewriter, simply dispatch to the registered rewrite 680 // method. 681 pdl::RewriteOp rewriter = pattern.getRewriter(); 682 if (StringAttr rewriteName = rewriter.getNameAttr()) { 683 SmallVector<Value> args; 684 if (rewriter.getRoot()) 685 args.push_back(mapRewriteValue(rewriter.getRoot())); 686 auto mappedArgs = 687 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue); 688 args.append(mappedArgs.begin(), mappedArgs.end()); 689 builder.create<pdl_interp::ApplyRewriteOp>( 690 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); 691 } else { 692 // Otherwise this is a dag rewriter defined using PDL operations. 693 for (Operation &rewriteOp : *rewriter.getBody()) { 694 llvm::TypeSwitch<Operation *>(&rewriteOp) 695 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 696 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp, 697 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) { 698 this->generateRewriter(op, rewriteValues, mapRewriteValue); 699 }); 700 } 701 } 702 703 // Update the signature of the rewrite function. 704 rewriterFunc.setType(builder.getFunctionType( 705 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 706 /*results=*/std::nullopt)); 707 708 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 709 return SymbolRefAttr::get( 710 builder.getContext(), 711 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 712 SymbolRefAttr::get(rewriterFunc)); 713 } 714 715 void PatternLowering::generateRewriter( 716 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 717 function_ref<Value(Value)> mapRewriteValue) { 718 SmallVector<Value, 2> arguments; 719 for (Value argument : rewriteOp.getArgs()) 720 arguments.push_back(mapRewriteValue(argument)); 721 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 722 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(), 723 arguments); 724 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) 725 rewriteValues[std::get<0>(it)] = std::get<1>(it); 726 } 727 728 void PatternLowering::generateRewriter( 729 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 730 function_ref<Value(Value)> mapRewriteValue) { 731 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 732 attrOp.getLoc(), attrOp.getValueAttr()); 733 rewriteValues[attrOp] = newAttr; 734 } 735 736 void PatternLowering::generateRewriter( 737 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 738 function_ref<Value(Value)> mapRewriteValue) { 739 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 740 mapRewriteValue(eraseOp.getOpValue())); 741 } 742 743 void PatternLowering::generateRewriter( 744 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 745 function_ref<Value(Value)> mapRewriteValue) { 746 SmallVector<Value, 4> operands; 747 for (Value operand : operationOp.getOperandValues()) 748 operands.push_back(mapRewriteValue(operand)); 749 750 SmallVector<Value, 4> attributes; 751 for (Value attr : operationOp.getAttributeValues()) 752 attributes.push_back(mapRewriteValue(attr)); 753 754 bool hasInferredResultTypes = false; 755 SmallVector<Value, 2> types; 756 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types, 757 rewriteValues, hasInferredResultTypes); 758 759 // Create the new operation. 760 Location loc = operationOp.getLoc(); 761 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 762 loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands, 763 attributes, operationOp.getAttributeValueNames()); 764 rewriteValues[operationOp.getOp()] = createdOp; 765 766 // Generate accesses for any results that have their types constrained. 767 // Handle the case where there is a single range representing all of the 768 // result types. 769 OperandRange resultTys = operationOp.getTypeValues(); 770 if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) { 771 Value &type = rewriteValues[resultTys[0]]; 772 if (!type) { 773 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 774 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 775 } 776 return; 777 } 778 779 // Otherwise, populate the individual results. 780 bool seenVariableLength = false; 781 Type valueTy = builder.getType<pdl::ValueType>(); 782 Type valueRangeTy = pdl::RangeType::get(valueTy); 783 for (const auto &it : llvm::enumerate(resultTys)) { 784 Value &type = rewriteValues[it.value()]; 785 if (type) 786 continue; 787 bool isVariadic = isa<pdl::RangeType>(it.value().getType()); 788 seenVariableLength |= isVariadic; 789 790 // After a variable length result has been seen, we need to use result 791 // groups because the exact index of the result is not statically known. 792 Value resultVal; 793 if (seenVariableLength) 794 resultVal = builder.create<pdl_interp::GetResultsOp>( 795 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 796 else 797 resultVal = builder.create<pdl_interp::GetResultOp>( 798 loc, valueTy, createdOp, it.index()); 799 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 800 } 801 } 802 803 void PatternLowering::generateRewriter( 804 pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues, 805 function_ref<Value(Value)> mapRewriteValue) { 806 SmallVector<Value, 4> replOperands; 807 for (Value operand : rangeOp.getArguments()) 808 replOperands.push_back(mapRewriteValue(operand)); 809 rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>( 810 rangeOp.getLoc(), rangeOp.getType(), replOperands); 811 } 812 813 void PatternLowering::generateRewriter( 814 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 815 function_ref<Value(Value)> mapRewriteValue) { 816 SmallVector<Value, 4> replOperands; 817 818 // If the replacement was another operation, get its results. `pdl` allows 819 // for using an operation for simplicitly, but the interpreter isn't as 820 // user facing. 821 if (Value replOp = replaceOp.getReplOperation()) { 822 // Don't use replace if we know the replaced operation has no results. 823 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>(); 824 if (!opOp || !opOp.getTypeValues().empty()) { 825 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 826 replOp.getLoc(), mapRewriteValue(replOp))); 827 } 828 } else { 829 for (Value operand : replaceOp.getReplValues()) 830 replOperands.push_back(mapRewriteValue(operand)); 831 } 832 833 // If there are no replacement values, just create an erase instead. 834 if (replOperands.empty()) { 835 builder.create<pdl_interp::EraseOp>( 836 replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue())); 837 return; 838 } 839 840 builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(), 841 mapRewriteValue(replaceOp.getOpValue()), 842 replOperands); 843 } 844 845 void PatternLowering::generateRewriter( 846 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 847 function_ref<Value(Value)> mapRewriteValue) { 848 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 849 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 850 mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); 851 } 852 853 void PatternLowering::generateRewriter( 854 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 855 function_ref<Value(Value)> mapRewriteValue) { 856 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 857 resultOp.getLoc(), resultOp.getType(), 858 mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); 859 } 860 861 void PatternLowering::generateRewriter( 862 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 863 function_ref<Value(Value)> mapRewriteValue) { 864 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 865 // type. 866 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) { 867 rewriteValues[typeOp] = 868 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 869 } 870 } 871 872 void PatternLowering::generateRewriter( 873 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 874 function_ref<Value(Value)> mapRewriteValue) { 875 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 876 // type. 877 if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) { 878 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 879 typeOp.getLoc(), typeOp.getType(), typeAttr); 880 } 881 } 882 883 void PatternLowering::generateOperationResultTypeRewriter( 884 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, 885 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, 886 bool &hasInferredResultTypes) { 887 Block *rewriterBlock = op->getBlock(); 888 889 // Try to handle resolution for each of the result types individually. This is 890 // preferred over type inferrence because it will allow for us to use existing 891 // types directly, as opposed to trying to rebuild the type list. 892 OperandRange resultTypeValues = op.getTypeValues(); 893 auto tryResolveResultTypes = [&] { 894 types.reserve(resultTypeValues.size()); 895 for (const auto &it : llvm::enumerate(resultTypeValues)) { 896 Value resultType = it.value(); 897 898 // Check for an already translated value. 899 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 900 types.push_back(existingRewriteValue); 901 continue; 902 } 903 904 // Check for an input from the matcher. 905 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 906 types.push_back(mapRewriteValue(resultType)); 907 continue; 908 } 909 910 // Otherwise, we couldn't infer the result types. Bail out here to see if 911 // we can infer the types for this operation from another way. 912 types.clear(); 913 return failure(); 914 } 915 return success(); 916 }; 917 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes())) 918 return; 919 920 // Otherwise, check if the operation has type inference support itself. 921 if (op.hasTypeInference()) { 922 hasInferredResultTypes = true; 923 return; 924 } 925 926 // Look for an operation that was replaced by `op`. The result types will be 927 // inferred from the results that were replaced. 928 for (OpOperand &use : op.getOp().getUses()) { 929 // Check that the use corresponds to a ReplaceOp and that it is the 930 // replacement value, not the operation being replaced. 931 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 932 if (!replOpUser || use.getOperandNumber() == 0) 933 continue; 934 // Make sure the replaced operation was defined before this one. PDL 935 // rewrites only have single block regions, so if the op isn't in the 936 // rewriter block (i.e. the current block of the operation) we already know 937 // it dominates (i.e. it's in the matcher). 938 Value replOpVal = replOpUser.getOpValue(); 939 Operation *replacedOp = replOpVal.getDefiningOp(); 940 if (replacedOp->getBlock() == rewriterBlock && 941 !replacedOp->isBeforeInBlock(op)) 942 continue; 943 944 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 945 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 946 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 947 replacedOp->getLoc(), replacedOpResults)); 948 return; 949 } 950 951 // If the types could not be inferred from any context and there weren't any 952 // explicit result types, assume the user actually meant for the operation to 953 // have no results. 954 if (resultTypeValues.empty()) 955 return; 956 957 // The verifier asserts that the result types of each pdl.getOperation can be 958 // inferred. If we reach here, there is a bug either in the logic above or 959 // in the verifier for pdl.getOperation. 960 op->emitOpError() << "unable to infer result type for operation"; 961 llvm_unreachable("unable to infer result type for operation"); 962 } 963 964 //===----------------------------------------------------------------------===// 965 // Conversion Pass 966 //===----------------------------------------------------------------------===// 967 968 namespace { 969 struct PDLToPDLInterpPass 970 : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 971 PDLToPDLInterpPass() = default; 972 PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default; 973 PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap) 974 : configMap(&configMap) {} 975 void runOnOperation() final; 976 977 /// A map containing the configuration for each pattern. 978 DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr; 979 }; 980 } // namespace 981 982 /// Convert the given module containing PDL pattern operations into a PDL 983 /// Interpreter operations. 984 void PDLToPDLInterpPass::runOnOperation() { 985 ModuleOp module = getOperation(); 986 987 // Create the main matcher function This function contains all of the match 988 // related functionality from patterns in the module. 989 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 990 auto matcherFunc = builder.create<pdl_interp::FuncOp>( 991 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 992 builder.getFunctionType(builder.getType<pdl::OperationType>(), 993 /*results=*/std::nullopt), 994 /*attrs=*/std::nullopt); 995 996 // Create a nested module to hold the functions invoked for rewriting the IR 997 // after a successful match. 998 ModuleOp rewriterModule = builder.create<ModuleOp>( 999 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 1000 1001 // Generate the code for the patterns within the module. 1002 PatternLowering generator(matcherFunc, rewriterModule, configMap); 1003 generator.lower(module); 1004 1005 // After generation, delete all of the pattern operations. 1006 for (pdl::PatternOp pattern : 1007 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) { 1008 // Drop the now dead config mappings. 1009 if (configMap) 1010 configMap->erase(pattern); 1011 1012 pattern.erase(); 1013 } 1014 } 1015 1016 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 1017 return std::make_unique<PDLToPDLInterpPass>(); 1018 } 1019 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass( 1020 DenseMap<Operation *, PDLPatternConfigSet *> &configMap) { 1021 return std::make_unique<PDLToPDLInterpPass>(configMap); 1022 } 1023