1 //===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H 10 #define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H 11 12 #include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" 13 #include "mlir/Dialect/Transform/Utils/RaggedArray.h" 14 #include "mlir/IR/OpDefinition.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/Interfaces/SideEffectInterfaces.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 19 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc" 20 21 namespace mlir { 22 namespace transform { 23 24 class TransformOpInterface; 25 class TransformResults; 26 class TransformRewriter; 27 class TransformState; 28 29 using Param = Attribute; 30 using MappedValue = llvm::PointerUnion<Operation *, Param, Value>; 31 32 namespace detail { 33 /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait 34 /// to either the list of operations associated with its operand or the root of 35 /// the payload IR, depending on what is available in the context. 36 LogicalResult 37 mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, 38 Operation *op, Region ®ion); 39 40 /// Verification hook for PossibleTopLevelTransformOpTrait. 41 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); 42 43 /// Populates `effects` with side effects implied by 44 /// PossibleTopLevelTransformOpTrait for the given operation. The operation may 45 /// have an optional `root` operand, indicating it is not in fact top-level. It 46 /// is also expected to have a single-block body. 
47 void getPotentialTopLevelEffects( 48 Operation *operation, Value root, Block &body, 49 SmallVectorImpl<MemoryEffects::EffectInstance> &effects); 50 51 /// Verification hook for TransformOpInterface. 52 LogicalResult verifyTransformOpInterface(Operation *op); 53 54 /// Appends the entities associated with the given transform values in `state` 55 /// to the pre-existing list of mappings. The array of mappings must have as 56 /// many elements as values. If `flatten` is set, multiple values may be 57 /// associated with each transform value, and this always succeeds. Otherwise, 58 /// checks that each value has exactly one mapping associated and return failure 59 /// otherwise. 60 LogicalResult appendValueMappings( 61 MutableArrayRef<SmallVector<transform::MappedValue>> mappings, 62 ValueRange values, const transform::TransformState &state, 63 bool flatten = true); 64 65 /// Populates `mappings` with mapped values associated with the given transform 66 /// IR values in the given `state`. 67 void prepareValueMappings( 68 SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings, 69 ValueRange values, const transform::TransformState &state); 70 71 /// Populates `results` with payload associations that match exactly those of 72 /// the operands to `block`'s terminator. 73 void forwardTerminatorOperands(Block *block, transform::TransformState &state, 74 transform::TransformResults &results); 75 76 /// Make a dummy transform state for testing purposes. This MUST NOT be used 77 /// outside of test cases. 78 TransformState makeTransformStateForTesting(Region *region, 79 Operation *payloadRoot); 80 81 /// Returns all operands that are handles and being consumed by the given op. 
82 SmallVector<OpOperand *> 83 getConsumedHandleOpOperands(transform::TransformOpInterface transformOp); 84 } // namespace detail 85 } // namespace transform 86 } // namespace mlir 87 88 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc" 89 90 namespace mlir { 91 namespace transform { 92 93 /// Options controlling the application of transform operations by the 94 /// TransformState. 95 class TransformOptions { 96 public: 97 TransformOptions() = default; 98 TransformOptions(const TransformOptions &) = default; 99 TransformOptions &operator=(const TransformOptions &) = default; 100 101 /// Requests computationally expensive checks of the transform and payload IR 102 /// well-formedness to be performed before each transformation. In particular, 103 /// these ensure that the handles still point to valid operations when used. 104 TransformOptions &enableExpensiveChecks(bool enable = true) { 105 expensiveChecksEnabled = enable; 106 return *this; 107 } 108 109 // Ensures that only a single top-level transform op is present in the IR. 110 TransformOptions &enableEnforceSingleToplevelTransformOp(bool enable = true) { 111 enforceSingleToplevelTransformOp = enable; 112 return *this; 113 } 114 115 /// Returns true if the expensive checks are requested. 116 bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; } 117 118 // Returns true if enforcing a single top-level transform op is requested. 119 bool getEnforceSingleToplevelTransformOp() const { 120 return enforceSingleToplevelTransformOp; 121 } 122 123 private: 124 bool expensiveChecksEnabled = true; 125 bool enforceSingleToplevelTransformOp = true; 126 }; 127 128 /// Entry point to the Transform dialect infrastructure. Applies the 129 /// transformation specified by `transform` to payload IR contained in 130 /// `payloadRoot`. The `transform` operation may contain other operations that 131 /// will be executed following the internal logic of the operation. 
It must 132 /// have the `PossibleTopLevelTransformOp` trait and not have any operands. 133 /// This function internally keeps track of the transformation state. 134 LogicalResult applyTransforms( 135 Operation *payloadRoot, TransformOpInterface transform, 136 const RaggedArray<MappedValue> &extraMapping = {}, 137 const TransformOptions &options = TransformOptions(), 138 bool enforceToplevelTransformOp = true, 139 function_ref<void(TransformState &)> stateInitializer = nullptr, 140 function_ref<LogicalResult(TransformState &)> stateExporter = nullptr); 141 142 /// The state maintained across applications of various ops implementing the 143 /// TransformOpInterface. The operations implementing this interface and the 144 /// surrounding structure are referred to as transform IR. The operations to 145 /// which transformations apply are referred to as payload IR. Transform IR 146 /// operates on values that can be associated either with a list of payload IR 147 /// operations (such values are referred to as handles) or with a list of 148 /// parameters represented as attributes. The state thus contains the mapping 149 /// between values defined in the transform IR ops and either payload IR ops or 150 /// parameters. For payload ops, the mapping is many-to-many and the reverse 151 /// mapping is also stored. The "expensive-checks" option can be passed to the 152 /// constructor at transformation execution time that transform IR values used 153 /// as operands by a transform IR operation are not associated with dangling 154 /// pointers to payload IR operations that are known to have been erased by 155 /// previous transformation through the same or a different transform IR value. 156 /// 157 /// A reference to this class is passed as an argument to "apply" methods of the 158 /// transform op interface. 
Thus the "apply" method can call either 159 /// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations 160 /// or `state.getParams( getSomeOperand() )` to obtain the list of parameters 161 /// associated with its operand. The method is expected to populate the 162 /// `TransformResults` class instance in order to update the mapping. The 163 /// `applyTransform` method takes care of propagating the state of 164 /// `TransformResults` into the instance of this class. 165 /// 166 /// When applying transform IR operations with regions, the client is expected 167 /// to create a `RegionScope` RAII object to create a new "stack frame" for 168 /// values defined inside the region. The mappings from and to these values will 169 /// be automatically dropped when the object goes out of scope, typically at the 170 /// end of the `apply` function of the parent operation. If a region contains 171 /// blocks with arguments, the client can map those arguments to payload IR ops 172 /// using `mapBlockArguments`. 173 class TransformState { 174 public: 175 using Param = transform::Param; 176 177 private: 178 /// Mapping between a Value in the transform IR and the corresponding set of 179 /// operations in the payload IR. 180 using TransformOpMapping = DenseMap<Value, SmallVector<Operation *, 2>>; 181 182 /// Mapping between a payload IR operation and the transform IR values it is 183 /// associated with. 184 using TransformOpReverseMapping = 185 DenseMap<Operation *, SmallVector<Value, 2>>; 186 187 /// Mapping between a Value in the transform IR and the corresponding list of 188 /// parameters. 189 using ParamMapping = DenseMap<Value, SmallVector<Param>>; 190 191 /// Mapping between a Value in the transform IR and the corrsponding list of 192 /// values in the payload IR. Also works for reverse mappings. 
193 using ValueMapping = DenseMap<Value, SmallVector<Value>>; 194 195 /// Mapping between a Value in the transform IR and an error message that 196 /// should be emitted when the value is used. 197 using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>; 198 199 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 200 /// Debug only: A timestamp is associated with each transform IR value, so 201 /// that invalid iterator usage can be detected more reliably. 202 using TransformIRTimestampMapping = DenseMap<Value, int64_t>; 203 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 204 205 /// The bidirectional mappings between transform IR values and payload IR 206 /// operations, and the mapping between transform IR values and parameters. 207 struct Mappings { 208 TransformOpMapping direct; 209 TransformOpReverseMapping reverse; 210 ParamMapping params; 211 ValueMapping values; 212 ValueMapping reverseValues; 213 214 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 215 TransformIRTimestampMapping timestamps; 216 void incrementTimestamp(Value value) { ++timestamps[value]; } 217 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 218 }; 219 220 friend LogicalResult 221 applyTransforms(Operation *, TransformOpInterface, 222 const RaggedArray<MappedValue> &, const TransformOptions &, 223 bool, function_ref<void(TransformState &)>, 224 function_ref<LogicalResult(TransformState &)>); 225 226 friend TransformState 227 detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot); 228 229 public: 230 const TransformOptions &getOptions() const { return options; } 231 232 /// Returns the op at which the transformation state is rooted. This is 233 /// typically helpful for transformations that apply globally. 234 Operation *getTopLevel() const; 235 236 /// Returns the number of extra mappings for the top-level operation. 237 size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); } 238 239 /// Returns the position-th extra mapping for the top-level operation. 
240 ArrayRef<MappedValue> getTopLevelMapping(size_t position) const { 241 return topLevelMappedValues[position]; 242 } 243 244 /// Returns an iterator that enumerates all ops that the given transform IR 245 /// value corresponds to. Ops may be erased while iterating; erased ops are 246 /// not enumerated. This function is helpful for transformations that apply to 247 /// a particular handle. 248 auto getPayloadOps(Value value) const { 249 ArrayRef<Operation *> view = getPayloadOpsView(value); 250 251 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 252 // Memorize the current timestamp and make sure that it has not changed 253 // when incrementing or dereferencing the iterator returned by this 254 // function. The timestamp is incremented when the "direct" mapping is 255 // resized; this would invalidate the iterator returned by this function. 256 int64_t currentTimestamp = getMapping(value).timestamps.lookup(value); 257 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 258 259 // When ops are replaced/erased, they are replaced with nullptr (until 260 // the data structure is compacted). Do not enumerate these ops. 261 return llvm::make_filter_range(view, [=](Operation *op) { 262 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 263 [[maybe_unused]] bool sameTimestamp = 264 currentTimestamp == this->getMapping(value).timestamps.lookup(value); 265 assert(sameTimestamp && "iterator was invalidated during iteration"); 266 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 267 return op != nullptr; 268 }); 269 } 270 271 /// Returns the list of parameters that the given transform IR value 272 /// corresponds to. 273 ArrayRef<Attribute> getParams(Value value) const; 274 275 /// Returns an iterator that enumerates all payload IR values that the given 276 /// transform IR value corresponds to. 
277 auto getPayloadValues(Value handleValue) const { 278 ArrayRef<Value> view = getPayloadValuesView(handleValue); 279 280 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 281 // Memorize the current timestamp and make sure that it has not changed 282 // when incrementing or dereferencing the iterator returned by this 283 // function. The timestamp is incremented when the "values" mapping is 284 // resized; this would invalidate the iterator returned by this function. 285 int64_t currentTimestamp = 286 getMapping(handleValue).timestamps.lookup(handleValue); 287 return llvm::make_filter_range(view, [=](Value v) { 288 [[maybe_unused]] bool sameTimestamp = 289 currentTimestamp == 290 this->getMapping(handleValue).timestamps.lookup(handleValue); 291 assert(sameTimestamp && "iterator was invalidated during iteration"); 292 return true; 293 }); 294 #else 295 return llvm::make_range(view.begin(), view.end()); 296 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 297 } 298 299 /// Populates `handles` with all handles pointing to the given Payload IR op. 300 /// Returns success if such handles exist, failure otherwise. 301 /// If `includeOutOfScope` is set to "true", handles that are defined in 302 /// regions beyond the most recent isolated from above region are included. 303 LogicalResult getHandlesForPayloadOp(Operation *op, 304 SmallVectorImpl<Value> &handles, 305 bool includeOutOfScope = false) const; 306 307 /// Populates `handles` with all handles pointing to the given payload IR 308 /// value. Returns success if such handles exist, failure otherwise. 309 /// If `includeOutOfScope` is set to "true", handles that are defined in 310 /// regions beyond the most recent isolated from above region are included. 311 LogicalResult getHandlesForPayloadValue(Value payloadValue, 312 SmallVectorImpl<Value> &handles, 313 bool includeOutOfScope = false) const; 314 315 /// Applies the transformation specified by the given transform op and updates 316 /// the state accordingly. 
317 DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform); 318 319 /// Records the mapping between a block argument in the transform IR and a 320 /// list of operations in the payload IR. The arguments must be defined in 321 /// blocks of the currently processed transform IR region, typically after a 322 /// region scope is defined. 323 /// 324 /// Returns failure if the payload does not satisfy the conditions associated 325 /// with the type of the handle value. 326 LogicalResult mapBlockArguments(BlockArgument argument, 327 ArrayRef<Operation *> operations) { 328 assert(argument.getParentRegion() == regionStack.back()->region && 329 "mapping block arguments from a region other than the active one"); 330 return setPayloadOps(argument, operations); 331 } 332 LogicalResult mapBlockArgument(BlockArgument argument, 333 ArrayRef<MappedValue> values); 334 LogicalResult mapBlockArguments(Block::BlockArgListType arguments, 335 ArrayRef<SmallVector<MappedValue>> mapping); 336 337 // Forward declarations to support limited visibility. 338 class RegionScope; 339 340 /// Creates a new region scope for the given region. The region is expected to 341 /// be nested in the currently processed region. 342 // Implementation note: this method is inline but implemented outside of the 343 // class body to comply with visibility and full-declaration requirements. 344 inline RegionScope make_region_scope(Region ®ion); 345 346 /// A RAII object maintaining a "stack frame" for a transform IR region. When 347 /// applying a transform IR operation that contains a region, the caller is 348 /// expected to create a RegionScope before applying the ops contained in the 349 /// region. This ensures that the mappings between values defined in the 350 /// transform IR region and payload IR operations are cleared when the region 351 /// processing ends; such values cannot be accessed outside the region. 
352 class RegionScope { 353 public: 354 /// Forgets the mapping from or to values defined in the associated 355 /// transform IR region, and restores the mapping that existed before 356 /// entering this scope. 357 ~RegionScope(); 358 359 private: 360 /// Creates a new scope for mappings between values defined in the given 361 /// transform IR region and payload IR objects. 362 RegionScope(TransformState &state, Region ®ion) 363 : state(state), region(®ion) { 364 auto res = state.mappings.insert( 365 std::make_pair(®ion, std::make_unique<Mappings>())); 366 assert(res.second && "the region scope is already present"); 367 (void)res; 368 state.regionStack.push_back(this); 369 } 370 371 /// Back-reference to the transform state. 372 TransformState &state; 373 374 /// The region this scope is associated with. 375 Region *region; 376 377 /// The transform op within this region that is currently being applied. 378 TransformOpInterface currentTransform; 379 380 friend class transform::TransformState; 381 }; 382 friend class RegionScope; 383 384 /// Base class for TransformState extensions that allow TransformState to 385 /// contain user-specified information in the state object. Clients are 386 /// expected to derive this class, add the desired fields, and make the 387 /// derived class compatible with the MLIR TypeID mechanism: 388 /// 389 /// ```mlir 390 /// class MyExtension final : public TransformState::Extension { 391 /// public: 392 /// MyExtension(TranfsormState &state, int myData) 393 /// : Extension(state) {...} 394 /// private: 395 /// int mySupplementaryData; 396 /// }; 397 /// ``` 398 /// 399 /// Instances of this and derived classes are not expected to be created by 400 /// the user, instead they are directly constructed within a TransformState. A 401 /// TransformState can only contain one extension with the given TypeID. 402 /// Extensions can be obtained from a TransformState instance, and can be 403 /// removed when they are no longer required. 
404 /// 405 /// ```mlir 406 /// transformState.addExtension<MyExtension>(/*myData=*/42); 407 /// MyExtension *ext = transformState.getExtension<MyExtension>(); 408 /// ext->doSomething(); 409 /// ``` 410 class Extension { 411 // Allow TransformState to allocate Extensions. 412 friend class TransformState; 413 414 public: 415 /// Base virtual destructor. 416 // Out-of-line definition ensures symbols are emitted in a single object 417 // file. 418 virtual ~Extension(); 419 420 protected: 421 /// Constructs an extension of the given TransformState object. 422 Extension(TransformState &state) : state(state) {} 423 424 /// Provides read-only access to the parent TransformState object. 425 const TransformState &getTransformState() const { return state; } 426 427 /// Replaces the given payload op with another op. If the replacement op is 428 /// null, removes the association of the payload op with its handle. Returns 429 /// failure if the op is not associated with any handle. 430 /// 431 /// Note: This function does not update value handles. None of the original 432 /// op's results are allowed to be mapped to any value handle. 433 LogicalResult replacePayloadOp(Operation *op, Operation *replacement); 434 435 /// Replaces the given payload value with another value. If the replacement 436 /// value is null, removes the association of the payload value with its 437 /// handle. Returns failure if the value is not associated with any handle. 438 LogicalResult replacePayloadValue(Value value, Value replacement); 439 440 private: 441 /// Back-reference to the state that is being extended. 442 TransformState &state; 443 }; 444 445 /// Adds a new Extension of the type specified as template parameter, 446 /// constructing it with the arguments provided. The extension is owned by the 447 /// TransformState. It is expected that the state does not already have an 448 /// extension of the same type. 
Extension constructors are expected to take 449 /// a reference to TransformState as first argument, automatically supplied 450 /// by this call. 451 template <typename Ty, typename... Args> 452 Ty &addExtension(Args &&...args) { 453 static_assert( 454 std::is_base_of<Extension, Ty>::value, 455 "only an class derived from TransformState::Extension is allowed here"); 456 auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...); 457 auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr)); 458 assert(result.second && "extension already added"); 459 return *static_cast<Ty *>(result.first->second.get()); 460 } 461 462 /// Returns the extension of the specified type. 463 template <typename Ty> 464 Ty *getExtension() { 465 static_assert( 466 std::is_base_of<Extension, Ty>::value, 467 "only an class derived from TransformState::Extension is allowed here"); 468 auto iter = extensions.find(TypeID::get<Ty>()); 469 if (iter == extensions.end()) 470 return nullptr; 471 return static_cast<Ty *>(iter->second.get()); 472 } 473 474 /// Removes the extension of the specified type. 475 template <typename Ty> 476 void removeExtension() { 477 static_assert( 478 std::is_base_of<Extension, Ty>::value, 479 "only an class derived from TransformState::Extension is allowed here"); 480 extensions.erase(TypeID::get<Ty>()); 481 } 482 483 private: 484 /// Identifier for storing top-level value in the `operations` mapping. 485 static constexpr Value kTopLevelValue = Value(); 486 487 /// Creates a state for transform ops living in the given region. The second 488 /// argument points to the root operation in the payload IR being transformed, 489 /// which may or may not contain the region with transform ops. Additional 490 /// options can be provided through the trailing configuration object. 
491 TransformState(Region *region, Operation *payloadRoot, 492 const RaggedArray<MappedValue> &extraMappings = {}, 493 const TransformOptions &options = TransformOptions()); 494 495 /// Returns the mappings frame for the region in which the value is defined. 496 /// If `allowOutOfScope` is set to "false", asserts that the value is in 497 /// scope, based on the current stack of frames. 498 const Mappings &getMapping(Value value, bool allowOutOfScope = false) const { 499 return const_cast<TransformState *>(this)->getMapping(value, 500 allowOutOfScope); 501 } 502 Mappings &getMapping(Value value, bool allowOutOfScope = false) { 503 Region *region = value.getParentRegion(); 504 auto it = mappings.find(region); 505 assert(it != mappings.end() && 506 "trying to find a mapping for a value from an unmapped region"); 507 #ifndef NDEBUG 508 if (!allowOutOfScope) { 509 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) { 510 if (r == region) 511 break; 512 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) 513 llvm_unreachable("trying to get mapping beyond region that is " 514 "isolated from above"); 515 } 516 } 517 #endif // NDEBUG 518 return *it->second; 519 } 520 521 /// Returns the mappings frame for the region in which the operation resides. 522 /// If `allowOutOfScope` is set to "false", asserts that the operation is in 523 /// scope, based on the current stack of frames. 
524 const Mappings &getMapping(Operation *operation, 525 bool allowOutOfScope = false) const { 526 return const_cast<TransformState *>(this)->getMapping(operation, 527 allowOutOfScope); 528 } 529 Mappings &getMapping(Operation *operation, bool allowOutOfScope = false) { 530 Region *region = operation->getParentRegion(); 531 auto it = mappings.find(region); 532 assert(it != mappings.end() && 533 "trying to find a mapping for an operation from an unmapped region"); 534 #ifndef NDEBUG 535 if (!allowOutOfScope) { 536 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) { 537 if (r == region) 538 break; 539 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) 540 llvm_unreachable("trying to get mapping beyond region that is " 541 "isolated from above"); 542 } 543 } 544 #endif // NDEBUG 545 return *it->second; 546 } 547 548 /// Updates the state to include the associations between op results and the 549 /// provided result of applying a transform op. 550 LogicalResult updateStateFromResults(const TransformResults &results, 551 ResultRange opResults); 552 553 /// Returns a list of all ops that the given transform IR value corresponds 554 /// to. In case an op was erased, the returned list contains nullptr. This 555 /// function is helpful for transformations that apply to a particular handle. 556 ArrayRef<Operation *> getPayloadOpsView(Value value) const; 557 558 /// Returns a list of payload IR values that the given transform IR value 559 /// corresponds to. 560 ArrayRef<Value> getPayloadValuesView(Value handleValue) const; 561 562 /// Sets the payload IR ops associated with the given transform IR value 563 /// (handle). A payload op may be associated multiple handles as long as 564 /// at most one of them gets consumed by further transformations. 
565 /// For example, a hypothetical "find function by name" may be called twice in 566 /// a row to produce two handles pointing to the same function: 567 /// 568 /// %0 = transform.find_func_by_name { name = "myfunc" } 569 /// %1 = transform.find_func_by_name { name = "myfunc" } 570 /// 571 /// which is valid by itself. However, calling a hypothetical "rewrite and 572 /// rename function" transform on both handles: 573 /// 574 /// transform.rewrite_and_rename %0 { new_name = "func" } 575 /// transform.rewrite_and_rename %1 { new_name = "func" } 576 /// 577 /// is invalid given the transformation "consumes" the handle as expressed 578 /// by side effects. Practically, a transformation consuming a handle means 579 /// that the associated payload operation may no longer exist. 580 /// 581 /// Similarly, operation handles may be invalidate and should not be used 582 /// after a transform that consumed a value handle pointing to a payload value 583 /// defined by the operation as either block argument or op result. For 584 /// example, in the following sequence, the last transform operation rewrites 585 /// the callee to not return a specified result: 586 /// 587 /// %0 = transform.find_call "myfunc" 588 /// %1 = transform.find_results_of_calling "myfunc" 589 /// transform.drop_call_result_from_signature %1[0] 590 /// 591 /// which requires the call operations to be recreated. Therefore, the handle 592 /// %0 becomes associated with a dangling pointer and should not be used. 593 /// 594 /// Returns failure if the payload does not satisfy the conditions associated 595 /// with the type of the handle value. The value is expected to have a type 596 /// implementing TransformHandleTypeInterface. 597 LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets); 598 599 /// Sets the payload IR values association with the given transform IR value 600 /// (handle). 
/// A payload value may be associated with multiple handles as long
/// as at most one of them is consumed by further transformations. For
/// example, a hypothetical "get results of calls to function with the given
/// name" transform may be performed twice in a row producing handles pointing
/// to the same values:
///
///   %0 = transform.find_results_of_calling "myfunc"
///   %1 = transform.find_results_of_calling "myfunc"
///
/// which is valid by itself. However, calling a hypothetical "erase value
/// producer" transform on both handles:
///
///   transform.erase_value_produce %0
///   transform.erase_value_produce %1
///
/// is invalid provided the transformation "consumes" the handle as expressed
/// by side effects (which themselves reflect the semantics of the transform
/// erasing the producer and making the handle dangling). Practically, a
/// transformation consuming a handle means the associated payload value may
/// no longer exist.
///
/// Similarly, value handles are invalidated and should not be used after a
/// transform that consumed an operation handle pointing to the payload IR
/// operation defining the values associated with the value handle, as either
/// block arguments or op results, or any ancestor operation. For example,
///
///   %0 = transform.find_call "myfunc"
///   %1 = transform.find_results_of_calling "myfunc"
///   transform.rewrite_and_rename %0 { new_name = "func" }
///
/// makes %1 unusable after the last transformation if it consumes %0. When an
/// operation handle is consumed, it usually indicates that the operation was
/// destroyed or heavily modified, meaning that the values it defines may no
/// longer exist.
///
/// Returns failure if the payload values do not satisfy the conditions
/// associated with the type of the handle value.
/// The value is expected to have a type implementing
  /// TransformValueHandleTypeInterface.
  LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);

  /// Sets the parameters associated with the given transform IR value. Returns
  /// failure if the parameters do not satisfy the conditions associated with
  /// the type of the value. The value is expected to have a type implementing
  /// TransformParamTypeInterface.
  LogicalResult setParams(Value value, ArrayRef<Param> params);

  /// Forgets the payload IR ops associated with the given transform IR value,
  /// as well as any association between value handles and the results of said
  /// payload IR ops.
  ///
  /// If `allowOutOfScope` is set to "false", asserts that the handle is in
  /// scope, based on the current stack of frames.
  void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
                     bool allowOutOfScope = false);

  // NOTE(review): this declaration is undocumented in the original. Judging by
  // its name and parameters it presumably drops the association between
  // `valueHandle` and payload values tied to `payloadOperations`; confirm
  // against the implementation before relying on this description.
  void forgetValueMapping(Value valueHandle,
                          ArrayRef<Operation *> payloadOperations);

  /// Replaces the given payload op with another op. If the replacement op is
  /// null, removes the association of the payload op with its handle. Returns
  /// failure if the op is not associated with any handle.
  ///
  /// Note: This function does not update value handles. None of the original
  /// op's results are allowed to be mapped to any value handle.
  LogicalResult replacePayloadOp(Operation *op, Operation *replacement);

  /// Replaces the given payload value with another value. If the replacement
  /// value is null, removes the association of the payload value with its
  /// handle. Returns failure if the value is not associated with any handle.
  LogicalResult replacePayloadValue(Value value, Value replacement);

  /// Records handle invalidation reporters into `newlyInvalidated`.
  /// Specifically,
  ///  - `consumingHandle` is the op operand that consumes the handle,
  ///  - `potentialAncestors` is a list of ancestors of the payload operation
  ///     that the consumed handle is associated with, including itself,
  ///  - `throughValue` is the payload value the handle to which is consumed,
  ///     when it is the case, null when the operation handle is consumed
  ///     directly.
  /// Iterates over all known operation and value handles and records reporters
  /// for any potential future use of `consumingHandle` or any other handle
  /// that is invalidated by its consumption, i.e., any handle pointing to any
  /// payload IR entity (operation or value) associated with the same payload
  /// IR entity as the consumed handle, or any nested payload IR entity. If
  /// `potentialAncestors` is empty, records the reporter anyway. Does not
  /// override existing reporters. This must remain a const method so it doesn't
  /// inadvertently mutate `invalidatedHandles` too early.
  void recordOpHandleInvalidation(OpOperand &consumingHandle,
                                  ArrayRef<Operation *> potentialAncestors,
                                  Value throughValue,
                                  InvalidatedHandleMap &newlyInvalidated) const;

  /// Records handle invalidation reporters into `newlyInvalidated`.
  /// Specifically,
  ///  - `consumingHandle` is the op operand that consumes the handle,
  ///  - `potentialAncestors` is a list of ancestors of the payload operation
  ///     that the consumed handle is associated with, including itself,
  ///  - `payloadOp` is the operation itself,
  ///  - `otherHandle` is another handle that may be associated with the
  ///     affected payload operations,
  ///  - `throughValue` is the payload value the handle to which is consumed,
  ///     when it is the case, null when the operation handle is consumed
  ///     directly.
  /// Looks at the payload operations associated with `otherHandle` and if any
  /// of these operations has an ancestor (or is itself) listed in
  /// `potentialAncestors`, records the error message describing the use of the
  /// invalidated handle. Does nothing if `otherHandle` already has a reporter
  /// associated with it. This must remain a const method so it doesn't
  /// inadvertently mutate `invalidatedHandles` too early.
  void recordOpHandleInvalidationOne(
      OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
      Operation *payloadOp, Value otherHandle, Value throughValue,
      InvalidatedHandleMap &newlyInvalidated) const;

  /// Records handle invalidation reporters into `newlyInvalidated`.
  /// Specifically,
  ///  - `opHandle` is the op operand that consumes the handle;
  ///  - `potentialAncestors` is a list of ancestors of the payload operation
  ///     that the consumed handle is associated with, including itself;
  ///  - `payloadValue` is the value defined by the operation associated with
  ///     the consuming handle as either op result or block argument;
  ///  - `valueHandle` is another handle that may be associated with the
  ///     payload value.
  /// Looks at the payload values associated with `valueHandle` and if any of
  /// these values is defined, as op result or block argument, by an operation
  /// whose ancestor (or the operation itself) is listed in
  /// `potentialAncestors`, records the error message describing the use of the
  /// invalidated handle. Does nothing if `valueHandle` already has a reporter
  /// associated with it. This must remain a const method so it doesn't
  /// inadvertently mutate `invalidatedHandles` too early.
  void recordValueHandleInvalidationByOpHandleOne(
      OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
      Value payloadValue, Value valueHandle,
      InvalidatedHandleMap &newlyInvalidated) const;

  /// Records handle invalidation reporters into `newlyInvalidated`.
  /// Specifically,
  ///  - `valueHandle` is the op operand that consumes the handle.
  /// NOTE(review): the original comment also listed a `throughValue` bullet,
  /// but this declaration has no such parameter; the bullet looks stale.
  /// Iterates over all known operation and value handles and records reporters
  /// for any potential future use of `valueHandle` or any other handle that is
  /// invalidated by its consumption, i.e., any handle pointing to any payload
  /// IR entity (operation or value) associated with the same payload IR entity
  /// as the consumed handle, or any nested payload IR entity. Does not override
  /// existing reporters. This must remain a const method so it doesn't
  /// inadvertently mutate `invalidatedHandles` too early.
  void
  recordValueHandleInvalidation(OpOperand &valueHandle,
                                InvalidatedHandleMap &newlyInvalidated) const;

  /// Checks that the operation does not use invalidated handles as operands.
  /// Reports errors and returns failure if it does. Otherwise, invalidates the
  /// handles consumed by the operation as well as any handles pointing to
  /// payload IR operations nested in the operations associated with the
  /// consumed handles.
  LogicalResult
  checkAndRecordHandleInvalidation(TransformOpInterface transform);

  /// Implementation of the checkAndRecordHandleInvalidation. Newly invalidated
  /// handles are written to `newlyInvalidated` rather than to the member map.
  /// This must remain a const method so it doesn't inadvertently mutate
  /// `invalidatedHandles` too early.
  LogicalResult checkAndRecordHandleInvalidationImpl(
      transform::TransformOpInterface transform,
      transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const;

  /// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
  void compactOpHandles();

  /// A stack of mappings between transform IR values and payload IR ops,
  /// aggregated by the region in which the transform IR values are defined.
  /// We use a pointer to the Mappings struct so that reallocations inside
  /// MapVector don't invalidate iterators when we apply nested transform ops
  /// while also iterating over the mappings.
  llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;

  /// Op handles may be temporarily mapped to nullptr to avoid invalidating
  /// payload op iterators. This set contains all op handles with nullptrs.
  /// These handles are "compacted" (i.e., nullptrs removed) at the end of each
  /// transform.
  DenseSet<Value> opHandlesToCompact;

  /// Extensions attached to the TransformState, identified by the TypeID of
  /// their type. Only one extension of any given type is allowed.
  DenseMap<TypeID, std::unique_ptr<Extension>> extensions;

  /// The top-level operation that contains all payload IR, typically a module.
  Operation *topLevel;

  /// Extra mapped values (payload operations, values or parameters) to be
  /// associated with additional entry block arguments of the top-level
  /// transform operation.
  RaggedArray<MappedValue> topLevelMappedValues;

  /// Additional options controlling the transformation state behavior.
  TransformOptions options;

  /// The mapping from invalidated handles to the error-reporting functions that
  /// describe when the handles were invalidated. Calling such a function emits
  /// a user-visible diagnostic with an additional note pointing to the given
  /// location.
801 InvalidatedHandleMap invalidatedHandles; 802 803 /// A stack of nested regions that are being processed in the transform IR. 804 /// Each region must be an ancestor of the following regions in this list. 805 /// These are also the keys for "mappings". 806 SmallVector<RegionScope *> regionStack; 807 808 /// The top-level region scope. The first (bottom) element of `regionStack` 809 /// is the top-level region scope object. 810 std::unique_ptr<RegionScope> topLevelRegionScope; 811 }; 812 813 /// Local mapping between values defined by a specific op implementing the 814 /// TransformOpInterface and the payload IR ops they correspond to. 815 class TransformResults { 816 friend class TransformState; 817 818 public: 819 /// Indicates that the result of the transform IR op at the given position 820 /// corresponds to the given list of payload IR ops. Each result must be set 821 /// by the transformation exactly once in case of transformation succeeding. 822 /// The value must have a type implementing TransformHandleTypeInterface. 823 template <typename Range> 824 void set(OpResult value, Range &&ops) { 825 int64_t position = value.getResultNumber(); 826 assert(position < static_cast<int64_t>(operations.size()) && 827 "setting results for a non-existent handle"); 828 assert(operations[position].data() == nullptr && "results already set"); 829 assert(params[position].data() == nullptr && 830 "another kind of results already set"); 831 assert(values[position].data() == nullptr && 832 "another kind of results already set"); 833 operations.replace(position, std::forward<Range>(ops)); 834 } 835 836 /// Indicates that the result of the transform IR op at the given position 837 /// corresponds to the given list of payload IR ops. Each result must be set 838 /// by the transformation exactly once in case of transformation succeeding. 839 /// The value must have a type implementing TransformHandleTypeInterface. 
840 void set(OpResult value, std::initializer_list<Operation *> ops) { 841 set(value, ArrayRef<Operation *>(ops)); 842 } 843 844 /// Indicates that the result of the transform IR op at the given position 845 /// corresponds to the given list of parameters. Each result must be set by 846 /// the transformation exactly once in case of transformation succeeding. The 847 /// value must have a type implementing TransformParamTypeInterface. 848 void setParams(OpResult value, ArrayRef<TransformState::Param> params); 849 850 /// Indicates that the result of the transform IR op at the given position 851 /// corresponds to the given range of payload IR values. Each result must be 852 /// set by the transformation exactly once in case of transformation 853 /// succeeding. The value must have a type implementing 854 /// TransformValueHandleTypeInterface. 855 template <typename Range> 856 void setValues(OpResult handle, Range &&values) { 857 int64_t position = handle.getResultNumber(); 858 assert(position < static_cast<int64_t>(this->values.size()) && 859 "setting values for a non-existent handle"); 860 assert(this->values[position].data() == nullptr && "values already set"); 861 assert(operations[position].data() == nullptr && 862 "another kind of results already set"); 863 assert(params[position].data() == nullptr && 864 "another kind of results already set"); 865 this->values.replace(position, std::forward<Range>(values)); 866 } 867 868 /// Indicates that the result of the transform IR op at the given position 869 /// corresponds to the given range of payload IR values. Each result must be 870 /// set by the transformation exactly once in case of transformation 871 /// succeeding. The value must have a type implementing 872 /// TransformValueHandleTypeInterface. 
873 void setValues(OpResult handle, std::initializer_list<Value> values) { 874 setValues(handle, ArrayRef<Value>(values)); 875 } 876 877 /// Indicates that the result of the transform IR op at the given position 878 /// corresponds to the given range of mapped values. All mapped values are 879 /// expected to be compatible with the type of the result, e.g., if the result 880 /// is an operation handle, all mapped values are expected to be payload 881 /// operations. 882 void setMappedValues(OpResult handle, ArrayRef<MappedValue> values); 883 884 /// Sets the currently unset results to empty lists of the kind expected by 885 /// the corresponding results of the given `transform` op. 886 void setRemainingToEmpty(TransformOpInterface transform); 887 888 private: 889 /// Creates an instance of TransformResults that expects mappings for 890 /// `numSegments` values, which may be associated with payload operations or 891 /// parameters. 892 explicit TransformResults(unsigned numSegments); 893 894 /// Gets the list of operations associated with the result identified by its 895 /// number in the list of operation results. The result must have been set to 896 /// be associated with payload IR operations. 897 ArrayRef<Operation *> get(unsigned resultNumber) const; 898 899 /// Gets the list of parameters associated with the result identified by its 900 /// number in the list of operation results. The result must have been set to 901 /// be associated with parameters. 902 ArrayRef<TransformState::Param> getParams(unsigned resultNumber) const; 903 904 /// Gets the list of payload IR values associated with the result identified 905 /// by its number in the list of operation results. The result must have been 906 /// set to be associated with payload IR values. 
907 ArrayRef<Value> getValues(unsigned resultNumber) const; 908 909 /// Returns `true` if the result identified by its number in the list of 910 /// operation results is associated with a list of parameters, `false` 911 /// otherwise. 912 bool isParam(unsigned resultNumber) const; 913 914 /// Returns `true` if the result identified by its number in the list of 915 /// operation results is associated with a list of payload IR value, `false` 916 /// otherwise. 917 bool isValue(unsigned resultNumber) const; 918 919 /// Returns `true` if the result identified by its number in the list of 920 /// operation results is associated with something. 921 bool isSet(unsigned resultNumber) const; 922 923 /// Pointers to payload IR ops that are associated with results of a transform 924 /// IR op. 925 RaggedArray<Operation *> operations; 926 927 /// Parameters that are associated with results of the transform IR op. 928 RaggedArray<Param> params; 929 930 /// Payload IR values that are associated with results of a transform IR op. 931 RaggedArray<Value> values; 932 }; 933 934 /// Creates a RAII object the lifetime of which corresponds to the new mapping 935 /// for transform IR values defined in the given region. Values defined in 936 /// surrounding regions remain accessible. 937 TransformState::RegionScope TransformState::make_region_scope(Region ®ion) { 938 return RegionScope(*this, region); 939 } 940 941 /// A configuration object for customizing a `TrackingListener`. 942 struct TrackingListenerConfig { 943 using SkipHandleFn = std::function<bool(Value)>; 944 945 /// An optional function that returns "true" for handles that do not have to 946 /// be updated. These are typically dead or consumed handles. 947 SkipHandleFn skipHandleFn = nullptr; 948 949 /// If set to "true", the name of a replacement op must match the name of the 950 /// original op. 
If set to "false", the names of the payload ops tracked in a 951 /// handle may change as the tracking listener updates the transform state. 952 bool requireMatchingReplacementOpName = true; 953 954 /// If set to "true", cast ops (that implement the CastOpInterface) are 955 /// skipped and the replacement op search continues with the operands of the 956 /// cast op. 957 bool skipCastOps = true; 958 }; 959 960 /// A listener that updates a TransformState based on IR modifications. This 961 /// listener can be used during a greedy pattern rewrite to keep the transform 962 /// state up-to-date. 963 class TrackingListener : public RewriterBase::Listener, 964 public TransformState::Extension { 965 public: 966 /// Create a new TrackingListener for usage in the specified transform op. 967 /// Optionally, a function can be specified to identify handles that should 968 /// do not have to be updated. 969 TrackingListener(TransformState &state, TransformOpInterface op, 970 TrackingListenerConfig config = TrackingListenerConfig()); 971 972 protected: 973 /// Return a replacement payload op for the given op, which is going to be 974 /// replaced with the given values. By default, if all values are defined by 975 /// the same op, which also has the same type as the given op, that defining 976 /// op is used as a replacement. 977 /// 978 /// A "failure" return value indicates that no replacement operation could be 979 /// found. A "nullptr" return value indicates that no replacement op is needed 980 /// (e.g., handle is dead or was consumed) and that the payload op should 981 /// be dropped from the mapping. 982 /// 983 /// Example: A tracked "linalg.generic" with two results is replaced with two 984 /// values defined by (another) "linalg.generic". It is reasonable to assume 985 /// that the replacement "linalg.generic" represents the same "computation". 986 /// Therefore, the payload op mapping is updated to the defining op of the 987 /// replacement values. 
988 /// 989 /// Counter Example: A "linalg.generic" is replaced with values defined by an 990 /// "scf.for". Without further investigation, the relationship between the 991 /// "linalg.generic" and the "scf.for" is unclear. They may not represent the 992 /// same computation; e.g., there may be tiled "linalg.generic" inside the 993 /// loop body that represents the original computation. Therefore, the 994 /// TrackingListener is conservative by default: it drops the mapping and 995 /// triggers the "payload replacement not found" notification. This default 996 /// behavior can be customized in `TrackingListenerConfig`. 997 /// 998 /// If no replacement op could be found according to the rules mentioned 999 /// above, this function tries to skip over cast-like ops that implement 1000 /// `CastOpInterface`. 1001 /// 1002 /// Example: A tracked "linalg.generic" is replaced with "linalg.generic", 1003 /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is 1004 /// reasonable to assume that the wrapped "linalg.generic" represents the same 1005 /// computation as the original "linalg.generic". The mapping is updated 1006 /// accordingly. 1007 /// 1008 /// Certain ops (typically also metadata-only ops) are not considered casts, 1009 /// but should be skipped nonetheless. Such ops should implement 1010 /// `FindPayloadReplacementOpInterface` to specify with which operands the 1011 /// lookup should continue. 1012 /// 1013 /// Example: A tracked "linalg.generic" is replaced with "linalg.generic", 1014 /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but 1015 /// not cast. (Implementing `CastOpInterface` would be incorrect and cause 1016 /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface` 1017 /// implementation, the replacement op lookup continues with the wrapped 1018 /// "linalg.generic" and the mapping is updated accordingly. 
1019 /// 1020 /// Derived classes may override `findReplacementOp` to specify custom 1021 /// replacement rules. 1022 virtual DiagnosedSilenceableFailure 1023 findReplacementOp(Operation *&result, Operation *op, 1024 ValueRange newValues) const; 1025 1026 /// Notify the listener that the pattern failed to match the given operation, 1027 /// and provide a callback to populate a diagnostic with the reason why the 1028 /// failure occurred. 1029 void 1030 notifyMatchFailure(Location loc, 1031 function_ref<void(Diagnostic &)> reasonCallback) override; 1032 1033 /// This function is called when a tracked payload op is dropped because no 1034 /// replacement op was found. Derived classes can implement this function for 1035 /// custom error handling. 1036 virtual void 1037 notifyPayloadReplacementNotFound(Operation *op, ValueRange values, 1038 DiagnosedSilenceableFailure &&diag) {} 1039 1040 /// Return the single op that defines all given values (if any). 1041 static Operation *getCommonDefiningOp(ValueRange values); 1042 1043 /// Return the transform op in which this TrackingListener is used. 1044 TransformOpInterface getTransformOp() const { return transformOp; } 1045 1046 private: 1047 friend class TransformRewriter; 1048 1049 void notifyOperationErased(Operation *op) override; 1050 1051 void notifyOperationReplaced(Operation *op, ValueRange newValues) override; 1052 using Listener::notifyOperationReplaced; 1053 1054 /// The transform op in which this TrackingListener is used. 1055 TransformOpInterface transformOp; 1056 1057 /// The handles that are consumed by the transform op. 1058 DenseSet<Value> consumedHandles; 1059 1060 /// Tracking listener configuration. 1061 TrackingListenerConfig config; 1062 }; 1063 1064 /// A specialized listener that keeps track of cases in which no replacement 1065 /// payload could be found. The error state of this listener must be checked 1066 /// before the end of its lifetime. 
1067 class ErrorCheckingTrackingListener : public TrackingListener { 1068 public: 1069 using transform::TrackingListener::TrackingListener; 1070 1071 ~ErrorCheckingTrackingListener() override; 1072 1073 /// Check and return the current error state of this listener. Afterwards, 1074 /// resets the error state to "success". 1075 DiagnosedSilenceableFailure checkAndResetError(); 1076 1077 /// Return "true" if this tracking listener had a failure. 1078 bool failed() const; 1079 1080 protected: 1081 void 1082 notifyPayloadReplacementNotFound(Operation *op, ValueRange values, 1083 DiagnosedSilenceableFailure &&diag) override; 1084 1085 private: 1086 /// The error state of this listener. "Success" indicates that no error 1087 /// happened so far. 1088 DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success(); 1089 1090 /// The number of errors that have been encountered. 1091 int64_t errorCounter = 0; 1092 }; 1093 1094 /// This is a special rewriter to be used in transform op implementations, 1095 /// providing additional helper functions to update the transform state, etc. 1096 // TODO: Helper functions will be added in a subsequent change. 1097 class TransformRewriter : public RewriterBase { 1098 protected: 1099 friend class TransformState; 1100 1101 /// Create a new TransformRewriter. 1102 explicit TransformRewriter(MLIRContext *ctx, 1103 ErrorCheckingTrackingListener *listener); 1104 1105 public: 1106 /// Return "true" if the tracking listener had failures. 1107 bool hasTrackingFailures() const; 1108 1109 /// Silence all tracking failures that have been encountered so far. 1110 void silenceTrackingFailure(); 1111 1112 /// Notify the transform dialect interpreter that the given op has been 1113 /// replaced with another op and that the mapping between handles and payload 1114 /// ops/values should be updated. This function should be called before the 1115 /// original op is erased. 
It fails if the operation could not be replaced, 1116 /// e.g., because the original operation is not tracked. 1117 /// 1118 /// Note: As long as IR modifications are performed through this rewriter, 1119 /// the transform state is usually updated automatically. This function should 1120 /// be used when unsupported rewriter API is used; e.g., updating all uses of 1121 /// a tracked operation one-by-one instead of using `RewriterBase::replaceOp`. 1122 LogicalResult notifyPayloadOperationReplaced(Operation *op, 1123 Operation *replacement); 1124 1125 private: 1126 ErrorCheckingTrackingListener *const listener; 1127 }; 1128 1129 /// This trait is supposed to be attached to Transform dialect operations that 1130 /// can be standalone top-level transforms. Such operations typically contain 1131 /// other Transform dialect operations that can be executed following some 1132 /// control flow logic specific to the current operation. The operations with 1133 /// this trait are expected to have at least one single-block region with at 1134 /// least one argument of type implementing TransformHandleTypeInterface. The 1135 /// operations are also expected to be valid without operands, in which case 1136 /// they are considered top-level, and with one or more arguments, in which case 1137 /// they are considered nested. Top-level operations have the block argument of 1138 /// the entry block in the Transform IR correspond to the root operation of 1139 /// Payload IR. Nested operations have the block argument of the entry block in 1140 /// the Transform IR correspond to a list of Payload IR operations mapped to the 1141 /// first operand of the Transform IR operation. The operation must implement 1142 /// TransformOpInterface. 1143 template <typename OpTy> 1144 class PossibleTopLevelTransformOpTrait 1145 : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> { 1146 public: 1147 /// Verifies that `op` satisfies the invariants of this trait. 
Not expected to 1148 /// be called directly. 1149 static LogicalResult verifyTrait(Operation *op) { 1150 return detail::verifyPossibleTopLevelTransformOpTrait(op); 1151 } 1152 1153 /// Returns the single block of the given region. 1154 Block *getBodyBlock(unsigned region = 0) { 1155 return &this->getOperation()->getRegion(region).front(); 1156 } 1157 1158 /// Populates `effects` with side effects implied by this trait. 1159 void getPotentialTopLevelEffects( 1160 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1161 detail::getPotentialTopLevelEffects( 1162 this->getOperation(), cast<OpTy>(this->getOperation()).getRoot(), 1163 *getBodyBlock(), effects); 1164 } 1165 1166 /// Sets up the mapping between the entry block of the given region of this op 1167 /// and the relevant list of Payload IR operations in the given state. The 1168 /// state is expected to be already scoped at the region of this operation. 1169 LogicalResult mapBlockArguments(TransformState &state, Region ®ion) { 1170 assert(region.getParentOp() == this->getOperation() && 1171 "op comes from the wrong region"); 1172 return detail::mapPossibleTopLevelTransformOpBlockArguments( 1173 state, this->getOperation(), region); 1174 } 1175 LogicalResult mapBlockArguments(TransformState &state) { 1176 assert( 1177 this->getOperation()->getNumRegions() == 1 && 1178 "must indicate the region to map if the operation has more than one"); 1179 return mapBlockArguments(state, this->getOperation()->getRegion(0)); 1180 } 1181 }; 1182 1183 class ApplyToEachResultList; 1184 1185 /// Trait implementing the TransformOpInterface for operations applying a 1186 /// transformation to a single operation handle and producing an arbitrary 1187 /// number of handles and parameter values. 
1188 /// The op must implement a method with the following signature: 1189 /// - DiagnosedSilenceableFailure applyToOne(OpTy, 1190 /// ApplyToEachResultList &results, TransformState &state) 1191 /// to perform a transformation that is applied in turn to all payload IR 1192 /// operations that correspond to the handle of the transform IR operation. 1193 /// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class 1194 /// that the transformation is applied to (and NOT the class of the transform IR 1195 /// op). 1196 /// The `applyToOne` method takes an empty `results` vector that it fills with 1197 /// zero, one or multiple operations depending on the number of results expected 1198 /// by the transform op. 1199 /// The number of results must match the number of results of the transform op. 1200 /// `applyToOne` is allowed to fill the `results` with all null elements to 1201 /// signify that the transformation did not apply to the payload IR operations. 1202 /// Such null elements are filtered out from results before return. 1203 /// 1204 /// The transform op having this trait is expected to have a single operand. 1205 template <typename OpTy> 1206 class TransformEachOpTrait 1207 : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> { 1208 public: 1209 /// Calls `applyToOne` for every payload operation associated with the operand 1210 /// of this transform IR op, the following case disjunction happens: 1211 /// 1. If not target payload ops are associated to the operand then fill the 1212 /// results vector with the expected number of null elements and return 1213 /// success. This is the corner case handling that allows propagating 1214 /// the "no-op" case gracefully to improve usability. 1215 /// 2. If any `applyToOne` returns definiteFailure, the transformation is 1216 /// immediately considered definitely failed and we return. 1217 /// 3. 
All applications of `applyToOne` are checked to return a number of 1218 /// results expected by the transform IR op. If not, this is a definite 1219 /// failure and we return early. 1220 /// 4. If `applyToOne` produces ops, associate them with the result of this 1221 /// transform op. 1222 /// 5. If any `applyToOne` return silenceableFailure, the transformation is 1223 /// considered silenceable. 1224 /// 6. Otherwise the transformation is considered successful. 1225 DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, 1226 TransformResults &transformResults, 1227 TransformState &state); 1228 1229 /// Checks that the op matches the expectations of this trait. 1230 static LogicalResult verifyTrait(Operation *op); 1231 }; 1232 1233 /// Side effect resource corresponding to the mapping between Transform IR 1234 /// values and Payload IR operations. An Allocate effect from this resource 1235 /// means creating a new mapping entry, it is always accompanied by a Write 1236 /// effect. A Read effect from this resource means accessing the mapping. A Free 1237 /// effect on this resource indicates the removal of the mapping entry, 1238 /// typically after a transformation that modifies the Payload IR operations 1239 /// associated with one of the Transform IR operation's operands. It is always 1240 /// accompanied by a Read effect. Read-after-Free and double-Free are not 1241 /// allowed (they would be problematic with "regular" memory effects too) as 1242 /// they indicate an attempt to access Payload IR operations that have been 1243 /// modified, potentially erased, by the previous transformations. 1244 // TODO: consider custom effects if these are not enabling generic passes such 1245 // as CSE/DCE to work. 
struct TransformMappingResource
    : public SideEffects::Resource::Base<TransformMappingResource> {
  /// Returns the name of this resource.
  StringRef getName() override { return "transform.mapping"; }
};

/// Side effect resource corresponding to the Payload IR itself. Only Read and
/// Write effects are expected on this resource, with Write always accompanied
/// by a Read (short of fully replacing the top-level Payload IR operation, one
/// cannot modify the Payload IR without reading it first). This is intended
/// to disallow reordering of Transform IR operations that mutate the Payload IR
/// while still allowing the reordering of those that only access it.
struct PayloadIRResource
    : public SideEffects::Resource::Base<PayloadIRResource> {
  /// Returns the name of this resource.
  StringRef getName() override { return "transform.payload_ir"; }
};

/// Populates `effects` with the memory effects indicating the operation on the
/// given handle value:
///   - consumes = Read + Free,
///   - produces = Allocate + Write,
///   - onlyReads = Read.
/// `producesHandle` is overloaded for op results and for block arguments.
void consumesHandle(MutableArrayRef<OpOperand> handles,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void producesHandle(ResultRange handles,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void producesHandle(MutableArrayRef<BlockArgument> handles,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsHandle(MutableArrayRef<OpOperand> handles,
                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects);

/// Checks whether the transform op consumes the given handle.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);

/// Populates `effects` with the memory effects indicating the access to payload
/// IR resource.
1281 void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects); 1282 void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects); 1283 1284 /// Checks whether the transform op modifies the payload. 1285 bool doesModifyPayload(transform::TransformOpInterface transform); 1286 /// Checks whether the transform op reads the payload. 1287 bool doesReadPayload(transform::TransformOpInterface transform); 1288 1289 /// Populates `consumedArguments` with positions of `block` arguments that are 1290 /// consumed by the operations in the `block`. 1291 void getConsumedBlockArguments( 1292 Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments); 1293 1294 /// Trait implementing the MemoryEffectOpInterface for operations that "consume" 1295 /// their operands and produce new results. 1296 template <typename OpTy> 1297 class FunctionalStyleTransformOpTrait 1298 : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> { 1299 public: 1300 /// This op "consumes" the operands by reading and freeing then, "produces" 1301 /// the results by allocating and writing it and reads/writes the payload IR 1302 /// in the process. 1303 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1304 consumesHandle(this->getOperation()->getOpOperands(), effects); 1305 producesHandle(this->getOperation()->getOpResults(), effects); 1306 modifiesPayload(effects); 1307 } 1308 1309 /// Checks that the op matches the expectations of this trait. 
1310 static LogicalResult verifyTrait(Operation *op) { 1311 if (!op->getName().getInterface<MemoryEffectOpInterface>()) { 1312 op->emitError() 1313 << "FunctionalStyleTransformOpTrait should only be attached to ops " 1314 "that implement MemoryEffectOpInterface"; 1315 } 1316 return success(); 1317 } 1318 }; 1319 1320 /// Trait implementing the MemoryEffectOpInterface for operations that use their 1321 /// operands without consuming and without modifying the Payload IR to 1322 /// potentially produce new handles. 1323 template <typename OpTy> 1324 class NavigationTransformOpTrait 1325 : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> { 1326 public: 1327 /// This op produces handles to the Payload IR without consuming the original 1328 /// handles and without modifying the IR itself. 1329 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1330 onlyReadsHandle(this->getOperation()->getOpOperands(), effects); 1331 producesHandle(this->getOperation()->getOpResults(), effects); 1332 if (llvm::any_of(this->getOperation()->getOperandTypes(), [](Type t) { 1333 return isa<TransformHandleTypeInterface, 1334 TransformValueHandleTypeInterface>(t); 1335 })) { 1336 onlyReadsPayload(effects); 1337 } 1338 } 1339 1340 /// Checks that the op matches the expectation of this trait. 1341 static LogicalResult verifyTrait(Operation *op) { 1342 if (!op->getName().getInterface<MemoryEffectOpInterface>()) { 1343 op->emitError() << "NavigationTransformOpTrait should only be attached " 1344 "to ops that implement MemoryEffectOpInterface"; 1345 } 1346 return success(); 1347 } 1348 }; 1349 1350 namespace detail { 1351 /// Non-template implementation of ParamProducerTransformOpTrait::getEffects(). 1352 void getParamProducerTransformOpTraitEffects( 1353 Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects); 1354 /// Non-template implementation of ParamProducerTransformOpTrait::verify(). 
1355 LogicalResult verifyParamProducerTransformOpTrait(Operation *op); 1356 } // namespace detail 1357 1358 /// Trait implementing the MemoryEffectsOpInterface for operations that produce 1359 /// transform dialect parameters. It marks all op results of 1360 /// TransformHandleTypeInterface as produced by the op, all operands as only 1361 /// read by the op and, if at least one of the operand is a handle to payload 1362 /// ops, the entire payload as potentially read. The op must only produce 1363 /// parameter-typed results. 1364 template <typename OpTy> 1365 class ParamProducerTransformOpTrait 1366 : public OpTrait::TraitBase<OpTy, ParamProducerTransformOpTrait> { 1367 public: 1368 /// Populates `effects` with effect instances described in the trait 1369 /// documentation. 1370 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1371 detail::getParamProducerTransformOpTraitEffects(this->getOperation(), 1372 effects); 1373 } 1374 1375 /// Checks that the op matches the expectation of this trait, i.e., that it 1376 /// implements the MemoryEffectsOpInterface and only produces parameter-typed 1377 /// results. 1378 static LogicalResult verifyTrait(Operation *op) { 1379 return detail::verifyParamProducerTransformOpTrait(op); 1380 } 1381 }; 1382 1383 /// `TrackingListener` failures are reported only for ops that have this trait. 1384 /// The purpose of this trait is to give users more time to update their custom 1385 /// transform ops to use the provided `TransformRewriter` for all IR 1386 /// modifications. This trait will eventually be removed, and failures will be 1387 /// reported for all transform ops. 1388 template <typename OpTy> 1389 class ReportTrackingListenerFailuresOpTrait 1390 : public OpTrait::TraitBase<OpTy, ReportTrackingListenerFailuresOpTrait> {}; 1391 1392 /// A single result of applying a transform op with `ApplyEachOpTrait` to a 1393 /// single payload operation. 
1394 using ApplyToEachResult = MappedValue; 1395 1396 /// A list of results of applying a transform op with `ApplyEachOpTrait` to a 1397 /// single payload operation, co-indexed with the results of the transform op. 1398 class ApplyToEachResultList { 1399 public: 1400 ApplyToEachResultList() = default; 1401 explicit ApplyToEachResultList(unsigned size) : results(size) {} 1402 1403 /// Sets the list of results to `size` null pointers. 1404 void assign(unsigned size, std::nullptr_t) { results.assign(size, nullptr); } 1405 1406 /// Sets the list of results to the given range of values. 1407 template <typename Range> 1408 void assign(Range &&range) { 1409 // This is roughly the implementation of SmallVectorImpl::assign. 1410 // Dispatching to it with map_range and template type inference would result 1411 // in more complex code here. 1412 results.clear(); 1413 results.reserve(llvm::size(range)); 1414 for (auto element : range) { 1415 if constexpr (std::is_convertible_v<decltype(*std::begin(range)), 1416 Operation *>) { 1417 results.push_back(static_cast<Operation *>(element)); 1418 } else if constexpr (std::is_convertible_v<decltype(*std::begin(range)), 1419 Value>) { 1420 results.push_back(element.template get<Value>()); 1421 } else { 1422 results.push_back(static_cast<Attribute>(element)); 1423 } 1424 } 1425 } 1426 1427 /// Appends an element to the list. 1428 // Using ApplyToEachResult that can be implicitly constructed from a Value but 1429 // not from a concrete Op that is implicitly convertible to a Value to avoid 1430 // ambiguity. 1431 void push_back(Operation *op) { results.push_back(op); } 1432 void push_back(Attribute attr) { results.push_back(attr); } 1433 void push_back(ApplyToEachResult r) { results.push_back(r); } 1434 1435 /// Reserves space for `size` elements in the list. 1436 void reserve(unsigned size) { results.reserve(size); } 1437 1438 /// Iterators over the list. 
1439 auto begin() { return results.begin(); } 1440 auto end() { return results.end(); } 1441 auto begin() const { return results.begin(); } 1442 auto end() const { return results.end(); } 1443 1444 /// Returns the number of elements in the list. 1445 size_t size() const { return results.size(); } 1446 1447 /// Element access. Expects the index to be in bounds. 1448 ApplyToEachResult &operator[](size_t index) { return results[index]; } 1449 const ApplyToEachResult &operator[](size_t index) const { 1450 return results[index]; 1451 } 1452 1453 private: 1454 /// Underlying storage. 1455 SmallVector<ApplyToEachResult> results; 1456 }; 1457 1458 namespace detail { 1459 1460 /// Check that the contents of `partialResult` matches the number, kind (payload 1461 /// op or parameter) and nullity (either all or none) requirements of 1462 /// `transformOp`. Report errors and return failure otherwise. 1463 LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, 1464 const ApplyToEachResultList &partialResult); 1465 1466 /// "Transpose" the results produced by individual applications, arranging them 1467 /// per result value of the transform op, and populate `transformResults` with 1468 /// that. The number, kind and nullity of per-application results are assumed to 1469 /// have been verified. 1470 void setApplyToOneResults(Operation *transformOp, 1471 TransformResults &transformResults, 1472 ArrayRef<ApplyToEachResultList> results); 1473 1474 /// Applies a one-to-one or a one-to-many transform to each of the given 1475 /// targets. Puts the results of transforms, if any, in `results` in the same 1476 /// order. Fails if any of the application fails. 
Individual transforms must be 1477 /// callable with the following signature: 1478 /// - DiagnosedSilenceableFailure(OpTy, 1479 /// SmallVector<Operation*> &results, state) 1480 /// where OpTy is either 1481 /// - Operation *, in which case the transform is always applied; 1482 /// - a concrete Op class, in which case a check is performed whether 1483 /// `targets` contains operations of the same class and a silenceable failure 1484 /// is reported if it does not. 1485 template <typename TransformOpTy, typename Range> 1486 DiagnosedSilenceableFailure applyTransformToEach( 1487 TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets, 1488 SmallVectorImpl<ApplyToEachResultList> &results, TransformState &state) { 1489 using OpTy = typename llvm::function_traits< 1490 decltype(&TransformOpTy::applyToOne)>::template arg_t<1>; 1491 static_assert(std::is_convertible<OpTy, Operation *>::value, 1492 "expected transform function to take an operation"); 1493 OpBuilder::InsertionGuard g(rewriter); 1494 1495 SmallVector<Diagnostic> silenceableStack; 1496 unsigned expectedNumResults = transformOp->getNumResults(); 1497 for (Operation *target : targets) { 1498 auto specificOp = dyn_cast<OpTy>(target); 1499 if (!specificOp) { 1500 Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error); 1501 diag << "transform applied to the wrong op kind"; 1502 diag.attachNote(target->getLoc()) << "when applied to this op"; 1503 silenceableStack.push_back(std::move(diag)); 1504 continue; 1505 } 1506 1507 ApplyToEachResultList partialResults; 1508 partialResults.reserve(expectedNumResults); 1509 Location specificOpLoc = specificOp->getLoc(); 1510 rewriter.setInsertionPoint(specificOp); 1511 DiagnosedSilenceableFailure res = 1512 transformOp.applyToOne(rewriter, specificOp, partialResults, state); 1513 if (res.isDefiniteFailure()) 1514 return DiagnosedSilenceableFailure::definiteFailure(); 1515 1516 if (res.isSilenceableFailure()) { 1517 
res.takeDiagnostics(silenceableStack); 1518 continue; 1519 } 1520 1521 if (failed(detail::checkApplyToOne(transformOp, specificOpLoc, 1522 partialResults))) { 1523 return DiagnosedSilenceableFailure::definiteFailure(); 1524 } 1525 results.push_back(std::move(partialResults)); 1526 } 1527 if (!silenceableStack.empty()) { 1528 return DiagnosedSilenceableFailure::silenceableFailure( 1529 std::move(silenceableStack)); 1530 } 1531 return DiagnosedSilenceableFailure::success(); 1532 } 1533 1534 /// Reports an error and returns failure if `targets` contains an ancestor 1535 /// operation before its descendant (or a copy of itself). Implementation detail 1536 /// for expensive checks during `TransformEachOpTrait::apply`. 1537 LogicalResult checkNestedConsumption(Location loc, 1538 ArrayRef<Operation *> targets); 1539 1540 } // namespace detail 1541 } // namespace transform 1542 } // namespace mlir 1543 1544 template <typename OpTy> 1545 mlir::DiagnosedSilenceableFailure 1546 mlir::transform::TransformEachOpTrait<OpTy>::apply( 1547 TransformRewriter &rewriter, TransformResults &transformResults, 1548 TransformState &state) { 1549 Value handle = this->getOperation()->getOperand(0); 1550 auto targets = state.getPayloadOps(handle); 1551 1552 // If the operand is consumed, check if it is associated with operations that 1553 // may be erased before their nested operations are. 1554 if (state.getOptions().getExpensiveChecksEnabled() && 1555 isHandleConsumed(handle, cast<transform::TransformOpInterface>( 1556 this->getOperation())) && 1557 failed(detail::checkNestedConsumption(this->getOperation()->getLoc(), 1558 llvm::to_vector(targets)))) { 1559 return DiagnosedSilenceableFailure::definiteFailure(); 1560 } 1561 1562 // Step 1. Handle the corner case where no target is specified. 1563 // This is typically the case when the matcher fails to apply and we need to 1564 // propagate gracefully. 1565 // In this case, we fill all results with an empty vector. 
1566 if (std::empty(targets)) { 1567 SmallVector<Operation *> emptyPayload; 1568 SmallVector<Attribute> emptyParams; 1569 for (OpResult r : this->getOperation()->getResults()) { 1570 if (isa<TransformParamTypeInterface>(r.getType())) 1571 transformResults.setParams(r, emptyParams); 1572 else if (isa<TransformValueHandleTypeInterface>(r.getType())) 1573 transformResults.setValues(r, ValueRange()); 1574 else 1575 transformResults.set(r, emptyPayload); 1576 } 1577 return DiagnosedSilenceableFailure::success(); 1578 } 1579 1580 // Step 2. Call applyToOne on each target and record newly produced ops in its 1581 // corresponding results entry. 1582 SmallVector<ApplyToEachResultList, 1> results; 1583 DiagnosedSilenceableFailure result = detail::applyTransformToEach( 1584 cast<OpTy>(this->getOperation()), rewriter, targets, results, state); 1585 1586 // Step 3. Propagate the definite failure if any and bail out. 1587 if (result.isDefiniteFailure()) 1588 return result; 1589 1590 // Step 4. "Transpose" the results produced by individual applications, 1591 // arranging them per result value of the transform op. The number, kind and 1592 // nullity of per-application results have been verified by the callback 1593 // above. 1594 detail::setApplyToOneResults(this->getOperation(), transformResults, results); 1595 1596 // Step 5. ApplyToOne may have returned silenceableFailure, propagate it. 1597 return result; 1598 } 1599 1600 template <typename OpTy> 1601 llvm::LogicalResult 1602 mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) { 1603 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(), 1604 "expected single-operand op"); 1605 if (!op->getName().getInterface<TransformOpInterface>()) { 1606 return op->emitError() << "TransformEachOpTrait should only be attached to " 1607 "ops that implement TransformOpInterface"; 1608 } 1609 1610 return success(); 1611 } 1612 1613 #endif // DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H 1614