1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_PATTERNMATCH_H 10 #define MLIR_IR_PATTERNMATCH_H 11 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "llvm/ADT/FunctionExtras.h" 15 #include "llvm/Support/TypeName.h" 16 #include <optional> 17 18 using llvm::SmallPtrSetImpl; 19 namespace mlir { 20 21 class PatternRewriter; 22 23 //===----------------------------------------------------------------------===// 24 // PatternBenefit class 25 //===----------------------------------------------------------------------===// 26 27 /// This class represents the benefit of a pattern match in a unitless scheme 28 /// that ranges from 0 (very little benefit) to 65K. The most common unit to 29 /// use here is the "number of operations matched" by the pattern. 30 /// 31 /// This also has a sentinel representation that can be used for patterns that 32 /// fail to match. 33 /// 34 class PatternBenefit { 35 enum { ImpossibleToMatchSentinel = 65535 }; 36 37 public: 38 PatternBenefit() = default; 39 PatternBenefit(unsigned benefit); 40 PatternBenefit(const PatternBenefit &) = default; 41 PatternBenefit &operator=(const PatternBenefit &) = default; 42 43 static PatternBenefit impossibleToMatch() { return PatternBenefit(); } 44 bool isImpossibleToMatch() const { return *this == impossibleToMatch(); } 45 46 /// If the corresponding pattern can match, return its benefit. If the 47 // corresponding pattern isImpossibleToMatch() then this aborts. 48 unsigned short getBenefit() const; 49 50 bool operator==(const PatternBenefit &rhs) const { 51 return representation == rhs.representation; 52 } 53 bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); } 54 bool operator<(const PatternBenefit &rhs) const { 55 return representation < rhs.representation; 56 } 57 bool operator>(const PatternBenefit &rhs) const { return rhs < *this; } 58 bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); } 59 bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); } 60 61 private: 62 unsigned short representation{ImpossibleToMatchSentinel}; 63 }; 64 65 //===----------------------------------------------------------------------===// 66 // Pattern 67 //===----------------------------------------------------------------------===// 68 69 /// This class contains all of the data related to a pattern, but does not 70 /// contain any methods or logic for the actual matching. This class is solely 71 /// used to interface with the metadata of a pattern, such as the benefit or 72 /// root operation. 73 class Pattern { 74 /// This enum represents the kind of value used to select the root operations 75 /// that match this pattern. 76 enum class RootKind { 77 /// The pattern root matches "any" operation. 78 Any, 79 /// The pattern root is matched using a concrete operation name. 80 OperationName, 81 /// The pattern root is matched using an interface ID. 82 InterfaceID, 83 /// The patter root is matched using a trait ID. 84 TraitID 85 }; 86 87 public: 88 /// Return a list of operations that may be generated when rewriting an 89 /// operation instance with this pattern. 90 ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; } 91 92 /// Return the root node that this pattern matches. Patterns that can match 93 /// multiple root types return std::nullopt. 94 std::optional<OperationName> getRootKind() const { 95 if (rootKind == RootKind::OperationName) 96 return OperationName::getFromOpaquePointer(rootValue); 97 return std::nullopt; 98 } 99 100 /// Return the interface ID used to match the root operation of this pattern. 101 /// If the pattern does not use an interface ID for deciding the root match, 102 /// this returns std::nullopt. 103 std::optional<TypeID> getRootInterfaceID() const { 104 if (rootKind == RootKind::InterfaceID) 105 return TypeID::getFromOpaquePointer(rootValue); 106 return std::nullopt; 107 } 108 109 /// Return the trait ID used to match the root operation of this pattern. 110 /// If the pattern does not use a trait ID for deciding the root match, this 111 /// returns std::nullopt. 112 std::optional<TypeID> getRootTraitID() const { 113 if (rootKind == RootKind::TraitID) 114 return TypeID::getFromOpaquePointer(rootValue); 115 return std::nullopt; 116 } 117 118 /// Return the benefit (the inverse of "cost") of matching this pattern. The 119 /// benefit of a Pattern is always static - rewrites that may have dynamic 120 /// benefit can be instantiated multiple times (different Pattern instances) 121 /// for each benefit that they may return, and be guarded by different match 122 /// condition predicates. 123 PatternBenefit getBenefit() const { return benefit; } 124 125 /// Returns true if this pattern is known to result in recursive application, 126 /// i.e. this pattern may generate IR that also matches this pattern, but is 127 /// known to bound the recursion. This signals to a rewrite driver that it is 128 /// safe to apply this pattern recursively to generated IR. 129 bool hasBoundedRewriteRecursion() const { 130 return contextAndHasBoundedRecursion.getInt(); 131 } 132 133 /// Return the MLIRContext used to create this pattern. 134 MLIRContext *getContext() const { 135 return contextAndHasBoundedRecursion.getPointer(); 136 } 137 138 /// Return a readable name for this pattern. This name should only be used for 139 /// debugging purposes, and may be empty. 140 StringRef getDebugName() const { return debugName; } 141 142 /// Set the human readable debug name used for this pattern. This name will 143 /// only be used for debugging purposes. 144 void setDebugName(StringRef name) { debugName = name; } 145 146 /// Return the set of debug labels attached to this pattern. 147 ArrayRef<StringRef> getDebugLabels() const { return debugLabels; } 148 149 /// Add the provided debug labels to this pattern. 150 void addDebugLabels(ArrayRef<StringRef> labels) { 151 debugLabels.append(labels.begin(), labels.end()); 152 } 153 void addDebugLabels(StringRef label) { debugLabels.push_back(label); } 154 155 protected: 156 /// This class acts as a special tag that makes the desire to match "any" 157 /// operation type explicit. This helps to avoid unnecessary usages of this 158 /// feature, and ensures that the user is making a conscious decision. 159 struct MatchAnyOpTypeTag {}; 160 /// This class acts as a special tag that makes the desire to match any 161 /// operation that implements a given interface explicit. This helps to avoid 162 /// unnecessary usages of this feature, and ensures that the user is making a 163 /// conscious decision. 164 struct MatchInterfaceOpTypeTag {}; 165 /// This class acts as a special tag that makes the desire to match any 166 /// operation that implements a given trait explicit. This helps to avoid 167 /// unnecessary usages of this feature, and ensures that the user is making a 168 /// conscious decision. 169 struct MatchTraitOpTypeTag {}; 170 171 /// Construct a pattern with a certain benefit that matches the operation 172 /// with the given root name. 173 Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, 174 ArrayRef<StringRef> generatedNames = {}); 175 /// Construct a pattern that may match any operation type. `generatedNames` 176 /// contains the names of operations that may be generated during a successful 177 /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" 178 /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should 179 /// always be supplied here. 180 Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context, 181 ArrayRef<StringRef> generatedNames = {}); 182 /// Construct a pattern that may match any operation that implements the 183 /// interface defined by the provided `interfaceID`. `generatedNames` contains 184 /// the names of operations that may be generated during a successful rewrite. 185 /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match 186 /// interface" behavior is what the user actually desired, 187 /// `MatchInterfaceOpTypeTag()` should always be supplied here. 188 Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID, 189 PatternBenefit benefit, MLIRContext *context, 190 ArrayRef<StringRef> generatedNames = {}); 191 /// Construct a pattern that may match any operation that implements the 192 /// trait defined by the provided `traitID`. `generatedNames` contains the 193 /// names of operations that may be generated during a successful rewrite. 194 /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait" 195 /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should 196 /// always be supplied here. 197 Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit, 198 MLIRContext *context, ArrayRef<StringRef> generatedNames = {}); 199 200 /// Set the flag detailing if this pattern has bounded rewrite recursion or 201 /// not. 202 void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) { 203 contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg); 204 } 205 206 private: 207 Pattern(const void *rootValue, RootKind rootKind, 208 ArrayRef<StringRef> generatedNames, PatternBenefit benefit, 209 MLIRContext *context); 210 211 /// The value used to match the root operation of the pattern. 212 const void *rootValue; 213 RootKind rootKind; 214 215 /// The expected benefit of matching this pattern. 216 const PatternBenefit benefit; 217 218 /// The context this pattern was created from, and a boolean flag indicating 219 /// whether this pattern has bounded recursion or not. 220 llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion; 221 222 /// A list of the potential operations that may be generated when rewriting 223 /// an op with this pattern. 224 SmallVector<OperationName, 2> generatedOps; 225 226 /// A readable name for this pattern. May be empty. 227 StringRef debugName; 228 229 /// The set of debug labels attached to this pattern. 230 SmallVector<StringRef, 0> debugLabels; 231 }; 232 233 //===----------------------------------------------------------------------===// 234 // RewritePattern 235 //===----------------------------------------------------------------------===// 236 237 /// RewritePattern is the common base class for all DAG to DAG replacements. 238 /// There are two possible usages of this class: 239 /// * Multi-step RewritePattern with "match" and "rewrite" 240 /// - By overloading the "match" and "rewrite" functions, the user can 241 /// separate the concerns of matching and rewriting. 242 /// * Single-step RewritePattern with "matchAndRewrite" 243 /// - By overloading the "matchAndRewrite" function, the user can perform 244 /// the rewrite in the same call as the match. 245 /// 246 class RewritePattern : public Pattern { 247 public: 248 virtual ~RewritePattern() = default; 249 250 /// Rewrite the IR rooted at the specified operation with the result of 251 /// this pattern, generating any new operations with the specified 252 /// builder. If an unexpected error is encountered (an internal 253 /// compiler error), it is emitted through the normal MLIR diagnostic 254 /// hooks and the IR is left in a valid state. 255 virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; 256 257 /// Attempt to match against code rooted at the specified operation, 258 /// which is the same operation code as getRootKind(). 259 virtual LogicalResult match(Operation *op) const; 260 261 /// Attempt to match against code rooted at the specified operation, 262 /// which is the same operation code as getRootKind(). If successful, this 263 /// function will automatically perform the rewrite. 264 virtual LogicalResult matchAndRewrite(Operation *op, 265 PatternRewriter &rewriter) const { 266 if (succeeded(match(op))) { 267 rewrite(op, rewriter); 268 return success(); 269 } 270 return failure(); 271 } 272 273 /// This method provides a convenient interface for creating and initializing 274 /// derived rewrite patterns of the given type `T`. 275 template <typename T, typename... Args> 276 static std::unique_ptr<T> create(Args &&...args) { 277 std::unique_ptr<T> pattern = 278 std::make_unique<T>(std::forward<Args>(args)...); 279 initializePattern<T>(*pattern); 280 281 // Set a default debug name if one wasn't provided. 282 if (pattern->getDebugName().empty()) 283 pattern->setDebugName(llvm::getTypeName<T>()); 284 return pattern; 285 } 286 287 protected: 288 /// Inherit the base constructors from `Pattern`. 289 using Pattern::Pattern; 290 291 private: 292 /// Trait to check if T provides a `initialize` method. 293 template <typename T, typename... Args> 294 using has_initialize = decltype(std::declval<T>().initialize()); 295 template <typename T> 296 using detect_has_initialize = llvm::is_detected<has_initialize, T>; 297 298 /// Initialize the derived pattern by calling its `initialize` method. 299 template <typename T> 300 static std::enable_if_t<detect_has_initialize<T>::value> 301 initializePattern(T &pattern) { 302 pattern.initialize(); 303 } 304 /// Empty derived pattern initializer for patterns that do not have an 305 /// initialize method. 306 template <typename T> 307 static std::enable_if_t<!detect_has_initialize<T>::value> 308 initializePattern(T &) {} 309 310 /// An anchor for the virtual table. 311 virtual void anchor(); 312 }; 313 314 namespace detail { 315 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that 316 /// allows for matching and rewriting against an instance of a derived operation 317 /// class or Interface. 318 template <typename SourceOp> 319 struct OpOrInterfaceRewritePatternBase : public RewritePattern { 320 using RewritePattern::RewritePattern; 321 322 /// Wrappers around the RewritePattern methods that pass the derived op type. 323 void rewrite(Operation *op, PatternRewriter &rewriter) const final { 324 rewrite(cast<SourceOp>(op), rewriter); 325 } 326 LogicalResult match(Operation *op) const final { 327 return match(cast<SourceOp>(op)); 328 } 329 LogicalResult matchAndRewrite(Operation *op, 330 PatternRewriter &rewriter) const final { 331 return matchAndRewrite(cast<SourceOp>(op), rewriter); 332 } 333 334 /// Rewrite and Match methods that operate on the SourceOp type. These must be 335 /// overridden by the derived pattern class. 336 virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const { 337 llvm_unreachable("must override rewrite or matchAndRewrite"); 338 } 339 virtual LogicalResult match(SourceOp op) const { 340 llvm_unreachable("must override match or matchAndRewrite"); 341 } 342 virtual LogicalResult matchAndRewrite(SourceOp op, 343 PatternRewriter &rewriter) const { 344 if (succeeded(match(op))) { 345 rewrite(op, rewriter); 346 return success(); 347 } 348 return failure(); 349 } 350 }; 351 } // namespace detail 352 353 /// OpRewritePattern is a wrapper around RewritePattern that allows for 354 /// matching and rewriting against an instance of a derived operation class as 355 /// opposed to a raw Operation. 356 template <typename SourceOp> 357 struct OpRewritePattern 358 : public detail::OpOrInterfaceRewritePatternBase<SourceOp> { 359 /// Patterns must specify the root operation name they match against, and can 360 /// also specify the benefit of the pattern matching and a list of generated 361 /// ops. 362 OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1, 363 ArrayRef<StringRef> generatedNames = {}) 364 : detail::OpOrInterfaceRewritePatternBase<SourceOp>( 365 SourceOp::getOperationName(), benefit, context, generatedNames) {} 366 }; 367 368 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for 369 /// matching and rewriting against an instance of an operation interface instead 370 /// of a raw Operation. 371 template <typename SourceOp> 372 struct OpInterfaceRewritePattern 373 : public detail::OpOrInterfaceRewritePatternBase<SourceOp> { 374 OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) 375 : detail::OpOrInterfaceRewritePatternBase<SourceOp>( 376 Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(), 377 benefit, context) {} 378 }; 379 380 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for 381 /// matching and rewriting against instances of an operation that possess a 382 /// given trait. 383 template <template <typename> class TraitType> 384 class OpTraitRewritePattern : public RewritePattern { 385 public: 386 OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) 387 : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(), 388 benefit, context) {} 389 }; 390 391 //===----------------------------------------------------------------------===// 392 // RewriterBase 393 //===----------------------------------------------------------------------===// 394 395 /// This class coordinates the application of a rewrite on a set of IR, 396 /// providing a way for clients to track mutations and create new operations. 397 /// This class serves as a common API for IR mutation between pattern rewrites 398 /// and non-pattern rewrites, and facilitates the development of shared 399 /// IR transformation utilities. 400 class RewriterBase : public OpBuilder { 401 public: 402 struct Listener : public OpBuilder::Listener { 403 Listener() 404 : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {} 405 406 /// Notify the listener that the specified block is about to be erased. 407 /// At this point, the block has zero uses. 408 virtual void notifyBlockErased(Block *block) {} 409 410 /// Notify the listener that the specified operation was modified in-place. 411 virtual void notifyOperationModified(Operation *op) {} 412 413 /// Notify the listener that all uses of the specified operation's results 414 /// are about to be replaced with the results of another operation. This is 415 /// called before the uses of the old operation have been changed. 416 /// 417 /// By default, this function calls the "operation replaced with values" 418 /// notification. 419 virtual void notifyOperationReplaced(Operation *op, 420 Operation *replacement) { 421 notifyOperationReplaced(op, replacement->getResults()); 422 } 423 424 /// Notify the listener that all uses of the specified operation's results 425 /// are about to be replaced with the a range of values, potentially 426 /// produced by other operations. This is called before the uses of the 427 /// operation have been changed. 428 virtual void notifyOperationReplaced(Operation *op, 429 ValueRange replacement) {} 430 431 /// Notify the listener that the specified operation is about to be erased. 432 /// At this point, the operation has zero uses. 433 /// 434 /// Note: This notification is not triggered when unlinking an operation. 435 virtual void notifyOperationErased(Operation *op) {} 436 437 /// Notify the listener that the specified pattern is about to be applied 438 /// at the specified root operation. 439 virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {} 440 441 /// Notify the listener that a pattern application finished with the 442 /// specified status. "success" indicates that the pattern was applied 443 /// successfully. "failure" indicates that the pattern could not be 444 /// applied. The pattern may have communicated the reason for the failure 445 /// with `notifyMatchFailure`. 446 virtual void notifyPatternEnd(const Pattern &pattern, 447 LogicalResult status) {} 448 449 /// Notify the listener that the pattern failed to match, and provide a 450 /// callback to populate a diagnostic with the reason why the failure 451 /// occurred. This method allows for derived listeners to optionally hook 452 /// into the reason why a rewrite failed, and display it to users. 453 virtual void 454 notifyMatchFailure(Location loc, 455 function_ref<void(Diagnostic &)> reasonCallback) {} 456 457 static bool classof(const OpBuilder::Listener *base); 458 }; 459 460 /// A listener that forwards all notifications to another listener. This 461 /// struct can be used as a base to create listener chains, so that multiple 462 /// listeners can be notified of IR changes. 463 struct ForwardingListener : public RewriterBase::Listener { 464 ForwardingListener(OpBuilder::Listener *listener) 465 : listener(listener), 466 rewriteListener( 467 dyn_cast_if_present<RewriterBase::Listener>(listener)) {} 468 469 void notifyOperationInserted(Operation *op, InsertPoint previous) override { 470 if (listener) 471 listener->notifyOperationInserted(op, previous); 472 } 473 void notifyBlockInserted(Block *block, Region *previous, 474 Region::iterator previousIt) override { 475 if (listener) 476 listener->notifyBlockInserted(block, previous, previousIt); 477 } 478 void notifyBlockErased(Block *block) override { 479 if (rewriteListener) 480 rewriteListener->notifyBlockErased(block); 481 } 482 void notifyOperationModified(Operation *op) override { 483 if (rewriteListener) 484 rewriteListener->notifyOperationModified(op); 485 } 486 void notifyOperationReplaced(Operation *op, Operation *newOp) override { 487 if (rewriteListener) 488 rewriteListener->notifyOperationReplaced(op, newOp); 489 } 490 void notifyOperationReplaced(Operation *op, 491 ValueRange replacement) override { 492 if (rewriteListener) 493 rewriteListener->notifyOperationReplaced(op, replacement); 494 } 495 void notifyOperationErased(Operation *op) override { 496 if (rewriteListener) 497 rewriteListener->notifyOperationErased(op); 498 } 499 void notifyPatternBegin(const Pattern &pattern, Operation *op) override { 500 if (rewriteListener) 501 rewriteListener->notifyPatternBegin(pattern, op); 502 } 503 void notifyPatternEnd(const Pattern &pattern, 504 LogicalResult status) override { 505 if (rewriteListener) 506 rewriteListener->notifyPatternEnd(pattern, status); 507 } 508 void notifyMatchFailure( 509 Location loc, 510 function_ref<void(Diagnostic &)> reasonCallback) override { 511 if (rewriteListener) 512 rewriteListener->notifyMatchFailure(loc, reasonCallback); 513 } 514 515 private: 516 OpBuilder::Listener *listener; 517 RewriterBase::Listener *rewriteListener; 518 }; 519 520 /// Move the blocks that belong to "region" before the given position in 521 /// another region "parent". The two regions must be different. The caller 522 /// is responsible for creating or updating the operation transferring flow 523 /// of control to the region and passing it the correct block arguments. 524 void inlineRegionBefore(Region ®ion, Region &parent, 525 Region::iterator before); 526 void inlineRegionBefore(Region ®ion, Block *before); 527 528 /// Replace the results of the given (original) operation with the specified 529 /// list of values (replacements). The result types of the given op and the 530 /// replacements must match. The original op is erased. 531 virtual void replaceOp(Operation *op, ValueRange newValues); 532 533 /// Replace the results of the given (original) operation with the specified 534 /// new op (replacement). The result types of the two ops must match. The 535 /// original op is erased. 536 virtual void replaceOp(Operation *op, Operation *newOp); 537 538 /// Replace the results of the given (original) op with a new op that is 539 /// created without verification (replacement). The result values of the two 540 /// ops must match. The original op is erased. 541 template <typename OpTy, typename... Args> 542 OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { 543 auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...); 544 replaceOp(op, newOp.getOperation()); 545 return newOp; 546 } 547 548 /// This method erases an operation that is known to have no uses. 549 virtual void eraseOp(Operation *op); 550 551 /// This method erases all operations in a block. 552 virtual void eraseBlock(Block *block); 553 554 /// Inline the operations of block 'source' into block 'dest' before the given 555 /// position. The source block will be deleted and must have no uses. 556 /// 'argValues' is used to replace the block arguments of 'source'. 557 /// 558 /// If the source block is inserted at the end of the dest block, the dest 559 /// block must have no successors. Similarly, if the source block is inserted 560 /// somewhere in the middle (or beginning) of the dest block, the source block 561 /// must have no successors. Otherwise, the resulting IR would have 562 /// unreachable operations. 563 virtual void inlineBlockBefore(Block *source, Block *dest, 564 Block::iterator before, 565 ValueRange argValues = std::nullopt); 566 567 /// Inline the operations of block 'source' before the operation 'op'. The 568 /// source block will be deleted and must have no uses. 'argValues' is used to 569 /// replace the block arguments of 'source' 570 /// 571 /// The source block must have no successors. Otherwise, the resulting IR 572 /// would have unreachable operations. 573 void inlineBlockBefore(Block *source, Operation *op, 574 ValueRange argValues = std::nullopt); 575 576 /// Inline the operations of block 'source' into the end of block 'dest'. The 577 /// source block will be deleted and must have no uses. 'argValues' is used to 578 /// replace the block arguments of 'source' 579 /// 580 /// The dest block must have no successors. Otherwise, the resulting IR would 581 /// have unreachable operation. 582 void mergeBlocks(Block *source, Block *dest, 583 ValueRange argValues = std::nullopt); 584 585 /// Split the operations starting at "before" (inclusive) out of the given 586 /// block into a new block, and return it. 587 Block *splitBlock(Block *block, Block::iterator before); 588 589 /// Unlink this operation from its current block and insert it right before 590 /// `existingOp` which may be in the same or another block in the same 591 /// function. 592 void moveOpBefore(Operation *op, Operation *existingOp); 593 594 /// Unlink this operation from its current block and insert it right before 595 /// `iterator` in the specified block. 596 void moveOpBefore(Operation *op, Block *block, Block::iterator iterator); 597 598 /// Unlink this operation from its current block and insert it right after 599 /// `existingOp` which may be in the same or another block in the same 600 /// function. 601 void moveOpAfter(Operation *op, Operation *existingOp); 602 603 /// Unlink this operation from its current block and insert it right after 604 /// `iterator` in the specified block. 605 void moveOpAfter(Operation *op, Block *block, Block::iterator iterator); 606 607 /// Unlink this block and insert it right before `existingBlock`. 608 void moveBlockBefore(Block *block, Block *anotherBlock); 609 610 /// Unlink this block and insert it right before the location that the given 611 /// iterator points to in the given region. 612 void moveBlockBefore(Block *block, Region *region, Region::iterator iterator); 613 614 /// This method is used to notify the rewriter that an in-place operation 615 /// modification is about to happen. A call to this function *must* be 616 /// followed by a call to either `finalizeOpModification` or 617 /// `cancelOpModification`. This is a minor efficiency win (it avoids creating 618 /// a new operation and removing the old one) but also often allows simpler 619 /// code in the client. 620 virtual void startOpModification(Operation *op) {} 621 622 /// This method is used to signal the end of an in-place modification of the 623 /// given operation. This can only be called on operations that were provided 624 /// to a call to `startOpModification`. 625 virtual void finalizeOpModification(Operation *op); 626 627 /// This method cancels a pending in-place modification. This can only be 628 /// called on operations that were provided to a call to 629 /// `startOpModification`. 630 virtual void cancelOpModification(Operation *op) {} 631 632 /// This method is a utility wrapper around an in-place modification of an 633 /// operation. It wraps calls to `startOpModification` and 634 /// `finalizeOpModification` around the given callable. 635 template <typename CallableT> 636 void modifyOpInPlace(Operation *root, CallableT &&callable) { 637 startOpModification(root); 638 callable(); 639 finalizeOpModification(root); 640 } 641 642 /// Find uses of `from` and replace them with `to`. Also notify the listener 643 /// about every in-place op modification (for every use that was replaced). 644 void replaceAllUsesWith(Value from, Value to) { 645 for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { 646 Operation *op = operand.getOwner(); 647 modifyOpInPlace(op, [&]() { operand.set(to); }); 648 } 649 } 650 void replaceAllUsesWith(Block *from, Block *to) { 651 for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) { 652 Operation *op = operand.getOwner(); 653 modifyOpInPlace(op, [&]() { operand.set(to); }); 654 } 655 } 656 void replaceAllUsesWith(ValueRange from, ValueRange to) { 657 assert(from.size() == to.size() && "incorrect number of replacements"); 658 for (auto it : llvm::zip(from, to)) 659 replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); 660 } 661 662 /// Find uses of `from` and replace them with `to`. Also notify the listener 663 /// about every in-place op modification (for every use that was replaced) 664 /// and that the `from` operation is about to be replaced. 665 /// 666 /// Note: This function cannot be called `replaceAllUsesWith` because the 667 /// overload resolution, when called with an op that can be implicitly 668 /// converted to a Value, would be ambiguous. 669 void replaceAllOpUsesWith(Operation *from, ValueRange to); 670 void replaceAllOpUsesWith(Operation *from, Operation *to); 671 672 /// Find uses of `from` and replace them with `to` if the `functor` returns 673 /// true. Also notify the listener about every in-place op modification (for 674 /// every use that was replaced). The optional `allUsesReplaced` flag is set 675 /// to "true" if all uses were replaced. 676 void replaceUsesWithIf(Value from, Value to, 677 function_ref<bool(OpOperand &)> functor, 678 bool *allUsesReplaced = nullptr); 679 void replaceUsesWithIf(ValueRange from, ValueRange to, 680 function_ref<bool(OpOperand &)> functor, 681 bool *allUsesReplaced = nullptr); 682 // Note: This function cannot be called `replaceOpUsesWithIf` because the 683 // overload resolution, when called with an op that can be implicitly 684 // converted to a Value, would be ambiguous. 685 void replaceOpUsesWithIf(Operation *from, ValueRange to, 686 function_ref<bool(OpOperand &)> functor, 687 bool *allUsesReplaced = nullptr) { 688 replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced); 689 } 690 691 /// Find uses of `from` within `block` and replace them with `to`. Also notify 692 /// the listener about every in-place op modification (for every use that was 693 /// replaced). The optional `allUsesReplaced` flag is set to "true" if all 694 /// uses were replaced. 695 void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, 696 Block *block, bool *allUsesReplaced = nullptr) { 697 replaceOpUsesWithIf( 698 op, newValues, 699 [block](OpOperand &use) { 700 return block->getParentOp()->isProperAncestor(use.getOwner()); 701 }, 702 allUsesReplaced); 703 } 704 705 /// Find uses of `from` and replace them with `to` except if the user is 706 /// `exceptedUser`. Also notify the listener about every in-place op 707 /// modification (for every use that was replaced). 708 void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) { 709 return replaceUsesWithIf(from, to, [&](OpOperand &use) { 710 Operation *user = use.getOwner(); 711 return user != exceptedUser; 712 }); 713 } 714 void replaceAllUsesExcept(Value from, Value to, 715 const SmallPtrSetImpl<Operation *> &preservedUsers); 716 717 /// Used to notify the listener that the IR failed to be rewritten because of 718 /// a match failure, and provide a callback to populate a diagnostic with the 719 /// reason why the failure occurred. This method allows for derived rewriters 720 /// to optionally hook into the reason why a rewrite failed, and display it to 721 /// users. 722 template <typename CallbackT> 723 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult> 724 notifyMatchFailure(Location loc, CallbackT &&reasonCallback) { 725 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) 726 rewriteListener->notifyMatchFailure( 727 loc, function_ref<void(Diagnostic &)>(reasonCallback)); 728 return failure(); 729 } 730 template <typename CallbackT> 731 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult> 732 notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { 733 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) 734 rewriteListener->notifyMatchFailure( 735 op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback)); 736 return failure(); 737 } 738 template <typename ArgT> 739 LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) { 740 return notifyMatchFailure(std::forward<ArgT>(arg), 741 [&](Diagnostic &diag) { diag << msg; }); 742 } 743 template <typename ArgT> 744 LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) { 745 return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg)); 746 } 747 748 protected: 749 /// Initialize the builder. 750 explicit RewriterBase(MLIRContext *ctx, 751 OpBuilder::Listener *listener = nullptr) 752 : OpBuilder(ctx, listener) {} 753 explicit RewriterBase(const OpBuilder &otherBuilder) 754 : OpBuilder(otherBuilder) {} 755 explicit RewriterBase(Operation *op, OpBuilder::Listener *listener = nullptr) 756 : OpBuilder(op, listener) {} 757 virtual ~RewriterBase(); 758 759 private: 760 void operator=(const RewriterBase &) = delete; 761 RewriterBase(const RewriterBase &) = delete; 762 }; 763 764 //===----------------------------------------------------------------------===// 765 // IRRewriter 766 //===----------------------------------------------------------------------===// 767 768 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite, 769 /// providing a way to keep track of the mutations made to the IR. This class 770 /// should only be used in situations where another `RewriterBase` instance, 771 /// such as a `PatternRewriter`, is not available. 772 class IRRewriter : public RewriterBase { 773 public: 774 explicit IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr) 775 : RewriterBase(ctx, listener) {} 776 explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {} 777 explicit IRRewriter(Operation *op, OpBuilder::Listener *listener = nullptr) 778 : RewriterBase(op, listener) {} 779 }; 780 781 //===----------------------------------------------------------------------===// 782 // PatternRewriter 783 //===----------------------------------------------------------------------===// 784 785 /// A special type of `RewriterBase` that coordinates the application of a 786 /// rewrite pattern on the current IR being matched, providing a way to keep 787 /// track of any mutations made. This class should be used to perform all 788 /// necessary IR mutations within a rewrite pattern, as the pattern driver may 789 /// be tracking various state that would be invalidated when a mutation takes 790 /// place. 791 class PatternRewriter : public RewriterBase { 792 public: 793 explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {} 794 using RewriterBase::RewriterBase; 795 796 /// A hook used to indicate if the pattern rewriter can recover from failure 797 /// during the rewrite stage of a pattern. For example, if the pattern 798 /// rewriter supports rollback, it may progress smoothly even if IR was 799 /// changed during the rewrite. 800 virtual bool canRecoverFromRewriteFailure() const { return false; } 801 }; 802 803 } // namespace mlir 804 805 // Optionally expose PDL pattern matching methods. 806 #include "PDLPatternMatch.h.inc" 807 808 namespace mlir { 809 810 //===----------------------------------------------------------------------===// 811 // RewritePatternSet 812 //===----------------------------------------------------------------------===// 813 814 class RewritePatternSet { 815 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>; 816 817 public: 818 RewritePatternSet(MLIRContext *context) : context(context) {} 819 820 /// Construct a RewritePatternSet populated with the given pattern. 821 RewritePatternSet(MLIRContext *context, 822 std::unique_ptr<RewritePattern> pattern) 823 : context(context) { 824 nativePatterns.emplace_back(std::move(pattern)); 825 } 826 RewritePatternSet(PDLPatternModule &&pattern) 827 : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {} 828 829 MLIRContext *getContext() const { return context; } 830 831 /// Return the native patterns held in this list. 832 NativePatternListT &getNativePatterns() { return nativePatterns; } 833 834 /// Return the PDL patterns held in this list. 835 PDLPatternModule &getPDLPatterns() { return pdlPatterns; } 836 837 /// Clear out all of the held patterns in this list. 838 void clear() { 839 nativePatterns.clear(); 840 pdlPatterns.clear(); 841 } 842 843 //===--------------------------------------------------------------------===// 844 // 'add' methods for adding patterns to the set. 845 //===--------------------------------------------------------------------===// 846 847 /// Add an instance of each of the pattern types 'Ts' to the pattern list with 848 /// the given arguments. Return a reference to `this` for chaining insertions. 849 /// Note: ConstructorArg is necessary here to separate the two variadic lists. 850 template <typename... Ts, typename ConstructorArg, 851 typename... ConstructorArgs, 852 typename = std::enable_if_t<sizeof...(Ts) != 0>> 853 RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) { 854 // The following expands a call to emplace_back for each of the pattern 855 // types 'Ts'. 856 (addImpl<Ts>(/*debugLabels=*/std::nullopt, 857 std::forward<ConstructorArg>(arg), 858 std::forward<ConstructorArgs>(args)...), 859 ...); 860 return *this; 861 } 862 /// An overload of the above `add` method that allows for attaching a set 863 /// of debug labels to the attached patterns. This is useful for labeling 864 /// groups of patterns that may be shared between multiple different 865 /// passes/users. 866 template <typename... Ts, typename ConstructorArg, 867 typename... ConstructorArgs, 868 typename = std::enable_if_t<sizeof...(Ts) != 0>> 869 RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels, 870 ConstructorArg &&arg, 871 ConstructorArgs &&...args) { 872 // The following expands a call to emplace_back for each of the pattern 873 // types 'Ts'. 874 (addImpl<Ts>(debugLabels, arg, args...), ...); 875 return *this; 876 } 877 878 /// Add an instance of each of the pattern types 'Ts'. Return a reference to 879 /// `this` for chaining insertions. 880 template <typename... Ts> 881 RewritePatternSet &add() { 882 (addImpl<Ts>(), ...); 883 return *this; 884 } 885 886 /// Add the given native pattern to the pattern list. Return a reference to 887 /// `this` for chaining insertions. 888 RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) { 889 nativePatterns.emplace_back(std::move(pattern)); 890 return *this; 891 } 892 893 /// Add the given PDL pattern to the pattern list. Return a reference to 894 /// `this` for chaining insertions. 895 RewritePatternSet &add(PDLPatternModule &&pattern) { 896 pdlPatterns.mergeIn(std::move(pattern)); 897 return *this; 898 } 899 900 // Add a matchAndRewrite style pattern represented as a C function pointer. 901 template <typename OpType> 902 RewritePatternSet & 903 add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), 904 PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) { 905 struct FnPattern final : public OpRewritePattern<OpType> { 906 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), 907 MLIRContext *context, PatternBenefit benefit, 908 ArrayRef<StringRef> generatedNames) 909 : OpRewritePattern<OpType>(context, benefit, generatedNames), 910 implFn(implFn) {} 911 912 LogicalResult matchAndRewrite(OpType op, 913 PatternRewriter &rewriter) const override { 914 return implFn(op, rewriter); 915 } 916 917 private: 918 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter); 919 }; 920 add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit, 921 generatedNames)); 922 return *this; 923 } 924 925 //===--------------------------------------------------------------------===// 926 // Pattern Insertion 927 //===--------------------------------------------------------------------===// 928 929 // TODO: These are soft deprecated in favor of the 'add' methods above. 930 931 /// Add an instance of each of the pattern types 'Ts' to the pattern list with 932 /// the given arguments. Return a reference to `this` for chaining insertions. 933 /// Note: ConstructorArg is necessary here to separate the two variadic lists. 934 template <typename... Ts, typename ConstructorArg, 935 typename... ConstructorArgs, 936 typename = std::enable_if_t<sizeof...(Ts) != 0>> 937 RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) { 938 // The following expands a call to emplace_back for each of the pattern 939 // types 'Ts'. 940 (addImpl<Ts>(/*debugLabels=*/std::nullopt, arg, args...), ...); 941 return *this; 942 } 943 944 /// Add an instance of each of the pattern types 'Ts'. Return a reference to 945 /// `this` for chaining insertions. 946 template <typename... Ts> 947 RewritePatternSet &insert() { 948 (addImpl<Ts>(), ...); 949 return *this; 950 } 951 952 /// Add the given native pattern to the pattern list. Return a reference to 953 /// `this` for chaining insertions. 954 RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) { 955 nativePatterns.emplace_back(std::move(pattern)); 956 return *this; 957 } 958 959 /// Add the given PDL pattern to the pattern list. Return a reference to 960 /// `this` for chaining insertions. 961 RewritePatternSet &insert(PDLPatternModule &&pattern) { 962 pdlPatterns.mergeIn(std::move(pattern)); 963 return *this; 964 } 965 966 // Add a matchAndRewrite style pattern represented as a C function pointer. 967 template <typename OpType> 968 RewritePatternSet & 969 insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) { 970 struct FnPattern final : public OpRewritePattern<OpType> { 971 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), 972 MLIRContext *context) 973 : OpRewritePattern<OpType>(context), implFn(implFn) { 974 this->setDebugName(llvm::getTypeName<FnPattern>()); 975 } 976 977 LogicalResult matchAndRewrite(OpType op, 978 PatternRewriter &rewriter) const override { 979 return implFn(op, rewriter); 980 } 981 982 private: 983 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter); 984 }; 985 add(std::make_unique<FnPattern>(std::move(implFn), getContext())); 986 return *this; 987 } 988 989 private: 990 /// Add an instance of the pattern type 'T'. Return a reference to `this` for 991 /// chaining insertions. 992 template <typename T, typename... Args> 993 std::enable_if_t<std::is_base_of<RewritePattern, T>::value> 994 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) { 995 std::unique_ptr<T> pattern = 996 RewritePattern::create<T>(std::forward<Args>(args)...); 997 pattern->addDebugLabels(debugLabels); 998 nativePatterns.emplace_back(std::move(pattern)); 999 } 1000 1001 template <typename T, typename... Args> 1002 std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value> 1003 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) { 1004 // TODO: Add the provided labels to the PDL pattern when PDL supports 1005 // labels. 1006 pdlPatterns.mergeIn(T(std::forward<Args>(args)...)); 1007 } 1008 1009 MLIRContext *const context; 1010 NativePatternListT nativePatterns; 1011 1012 // Patterns expressed with PDL. This will compile to a stub class when PDL is 1013 // not enabled. 1014 PDLPatternModule pdlPatterns; 1015 }; 1016 1017 } // namespace mlir 1018 1019 #endif // MLIR_IR_PATTERNMATCH_H 1020