xref: /llvm-project/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h (revision 6c7a3f80e75de36f2642110a077664e948d9e7e3)
1 //===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
11 
12 #include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h"
13 #include "mlir/Dialect/Transform/Utils/RaggedArray.h"
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Interfaces/SideEffectInterfaces.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 
19 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"
20 
21 namespace mlir {
22 namespace transform {
23 
24 class TransformOpInterface;
25 class TransformResults;
26 class TransformRewriter;
27 class TransformState;
28 
29 using Param = Attribute;
30 using MappedValue = llvm::PointerUnion<Operation *, Param, Value>;
31 
32 namespace detail {
33 /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
34 /// to either the list of operations associated with its operand or the root of
35 /// the payload IR, depending on what is available in the context.
36 LogicalResult
37 mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
38                                              Operation *op, Region &region);
39 
40 /// Verification hook for PossibleTopLevelTransformOpTrait.
41 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
42 
43 /// Populates `effects` with side effects implied by
44 /// PossibleTopLevelTransformOpTrait for the given operation. The operation may
45 /// have an optional `root` operand, indicating it is not in fact top-level. It
46 /// is also expected to have a single-block body.
47 void getPotentialTopLevelEffects(
48     Operation *operation, Value root, Block &body,
49     SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
50 
51 /// Verification hook for TransformOpInterface.
52 LogicalResult verifyTransformOpInterface(Operation *op);
53 
54 /// Appends the entities associated with the given transform values in `state`
55 /// to the pre-existing list of mappings. The array of mappings must have as
56 /// many elements as values. If `flatten` is set, multiple values may be
57 /// associated with each transform value, and this always succeeds. Otherwise,
58 /// checks that each value has exactly one mapping associated and return failure
59 /// otherwise.
60 LogicalResult appendValueMappings(
61     MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
62     ValueRange values, const transform::TransformState &state,
63     bool flatten = true);
64 
65 /// Populates `mappings` with mapped values associated with the given transform
66 /// IR values in the given `state`.
67 void prepareValueMappings(
68     SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
69     ValueRange values, const transform::TransformState &state);
70 
71 /// Populates `results` with payload associations that match exactly those of
72 /// the operands to `block`'s terminator.
73 void forwardTerminatorOperands(Block *block, transform::TransformState &state,
74                                transform::TransformResults &results);
75 
76 /// Make a dummy transform state for testing purposes. This MUST NOT be used
77 /// outside of test cases.
78 TransformState makeTransformStateForTesting(Region *region,
79                                             Operation *payloadRoot);
80 
81 /// Returns all operands that are handles and being consumed by the given op.
82 SmallVector<OpOperand *>
83 getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
84 } // namespace detail
85 } // namespace transform
86 } // namespace mlir
87 
88 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"
89 
90 namespace mlir {
91 namespace transform {
92 
/// Options controlling the application of transform operations by the
/// TransformState. Every check is enabled by default. The setters return
/// `*this`, so calls can be chained.
class TransformOptions {
public:
  // Copy construction/assignment are implicitly provided (Rule of Zero): the
  // class only holds plain flags.
  TransformOptions() = default;

  /// Requests computationally expensive checks of the transform and payload IR
  /// well-formedness to be performed before each transformation. In particular,
  /// these ensure that the handles still point to valid operations when used.
  TransformOptions &enableExpensiveChecks(bool enable = true) {
    expensiveChecksEnabled = enable;
    return *this;
  }

  /// Requests that only a single top-level transform op be present in the IR.
  TransformOptions &enableEnforceSingleToplevelTransformOp(bool enable = true) {
    enforceSingleToplevelTransformOp = enable;
    return *this;
  }

  /// Returns true if the expensive checks are requested.
  bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }

  /// Returns true if enforcing a single top-level transform op is requested.
  bool getEnforceSingleToplevelTransformOp() const {
    return enforceSingleToplevelTransformOp;
  }

private:
  bool expensiveChecksEnabled = true;
  bool enforceSingleToplevelTransformOp = true;
};
127 
128 /// Entry point to the Transform dialect infrastructure. Applies the
129 /// transformation specified by `transform` to payload IR contained in
130 /// `payloadRoot`. The `transform` operation may contain other operations that
131 /// will be executed following the internal logic of the operation. It must
132 /// have the `PossibleTopLevelTransformOp` trait and not have any operands.
133 /// This function internally keeps track of the transformation state.
134 LogicalResult applyTransforms(
135     Operation *payloadRoot, TransformOpInterface transform,
136     const RaggedArray<MappedValue> &extraMapping = {},
137     const TransformOptions &options = TransformOptions(),
138     bool enforceToplevelTransformOp = true,
139     function_ref<void(TransformState &)> stateInitializer = nullptr,
140     function_ref<LogicalResult(TransformState &)> stateExporter = nullptr);
141 
142 /// The state maintained across applications of various ops implementing the
143 /// TransformOpInterface. The operations implementing this interface and the
144 /// surrounding structure are referred to as transform IR. The operations to
145 /// which transformations apply are referred to as payload IR. Transform IR
146 /// operates on values that can be associated either with a list of payload IR
147 /// operations (such values are referred to as handles) or with a list of
148 /// parameters represented as attributes. The state thus contains the mapping
149 /// between values defined in the transform IR ops and either payload IR ops or
150 /// parameters. For payload ops, the mapping is many-to-many and the reverse
/// mapping is also stored. The "expensive-checks" option can be passed to the
/// constructor to verify, at transformation execution time, that transform IR
/// values used as operands by a transform IR operation are not associated with
/// dangling pointers to payload IR operations that are known to have been
/// erased by a previous transformation through the same or a different
/// transform IR value.
156 ///
157 /// A reference to this class is passed as an argument to "apply" methods of the
158 /// transform op interface. Thus the "apply" method can call either
159 /// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
160 /// or `state.getParams( getSomeOperand() )` to obtain the list of parameters
161 /// associated with its operand. The method is expected to populate the
162 /// `TransformResults` class instance in order to update the mapping. The
163 /// `applyTransform` method takes care of propagating the state of
164 /// `TransformResults` into the instance of this class.
165 ///
166 /// When applying transform IR operations with regions, the client is expected
167 /// to create a `RegionScope` RAII object to create a new "stack frame" for
168 /// values defined inside the region. The mappings from and to these values will
169 /// be automatically dropped when the object goes out of scope, typically at the
170 /// end of the `apply` function of the parent operation. If a region contains
171 /// blocks with arguments, the client can map those arguments to payload IR ops
172 /// using `mapBlockArguments`.
173 class TransformState {
174 public:
175   using Param = transform::Param;
176 
177 private:
178   /// Mapping between a Value in the transform IR and the corresponding set of
179   /// operations in the payload IR.
180   using TransformOpMapping = DenseMap<Value, SmallVector<Operation *, 2>>;
181 
182   /// Mapping between a payload IR operation and the transform IR values it is
183   /// associated with.
184   using TransformOpReverseMapping =
185       DenseMap<Operation *, SmallVector<Value, 2>>;
186 
187   /// Mapping between a Value in the transform IR and the corresponding list of
188   /// parameters.
189   using ParamMapping = DenseMap<Value, SmallVector<Param>>;
190 
191   /// Mapping between a Value in the transform IR and the corrsponding list of
192   /// values in the payload IR. Also works for reverse mappings.
193   using ValueMapping = DenseMap<Value, SmallVector<Value>>;
194 
195   /// Mapping between a Value in the transform IR and an error message that
196   /// should be emitted when the value is used.
197   using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
198 
199 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
200   /// Debug only: A timestamp is associated with each transform IR value, so
201   /// that invalid iterator usage can be detected more reliably.
202   using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
203 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
204 
205   /// The bidirectional mappings between transform IR values and payload IR
206   /// operations, and the mapping between transform IR values and parameters.
207   struct Mappings {
208     TransformOpMapping direct;
209     TransformOpReverseMapping reverse;
210     ParamMapping params;
211     ValueMapping values;
212     ValueMapping reverseValues;
213 
214 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
215     TransformIRTimestampMapping timestamps;
216     void incrementTimestamp(Value value) { ++timestamps[value]; }
217 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
218   };
219 
220   friend LogicalResult
221   applyTransforms(Operation *, TransformOpInterface,
222                   const RaggedArray<MappedValue> &, const TransformOptions &,
223                   bool, function_ref<void(TransformState &)>,
224                   function_ref<LogicalResult(TransformState &)>);
225 
226   friend TransformState
227   detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
228 
229 public:
230   const TransformOptions &getOptions() const { return options; }
231 
232   /// Returns the op at which the transformation state is rooted. This is
233   /// typically helpful for transformations that apply globally.
234   Operation *getTopLevel() const;
235 
236   /// Returns the number of extra mappings for the top-level operation.
237   size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); }
238 
239   /// Returns the position-th extra mapping for the top-level operation.
240   ArrayRef<MappedValue> getTopLevelMapping(size_t position) const {
241     return topLevelMappedValues[position];
242   }
243 
244   /// Returns an iterator that enumerates all ops that the given transform IR
245   /// value corresponds to. Ops may be erased while iterating; erased ops are
246   /// not enumerated. This function is helpful for transformations that apply to
247   /// a particular handle.
248   auto getPayloadOps(Value value) const {
249     ArrayRef<Operation *> view = getPayloadOpsView(value);
250 
251 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
252     // Memorize the current timestamp and make sure that it has not changed
253     // when incrementing or dereferencing the iterator returned by this
254     // function. The timestamp is incremented when the "direct" mapping is
255     // resized; this would invalidate the iterator returned by this function.
256     int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
257 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
258 
259     // When ops are replaced/erased, they are replaced with nullptr (until
260     // the data structure is compacted). Do not enumerate these ops.
261     return llvm::make_filter_range(view, [=](Operation *op) {
262 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
263       [[maybe_unused]] bool sameTimestamp =
264           currentTimestamp == this->getMapping(value).timestamps.lookup(value);
265       assert(sameTimestamp && "iterator was invalidated during iteration");
266 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
267       return op != nullptr;
268     });
269   }
270 
271   /// Returns the list of parameters that the given transform IR value
272   /// corresponds to.
273   ArrayRef<Attribute> getParams(Value value) const;
274 
275   /// Returns an iterator that enumerates all payload IR values that the given
276   /// transform IR value corresponds to.
277   auto getPayloadValues(Value handleValue) const {
278     ArrayRef<Value> view = getPayloadValuesView(handleValue);
279 
280 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
281     // Memorize the current timestamp and make sure that it has not changed
282     // when incrementing or dereferencing the iterator returned by this
283     // function. The timestamp is incremented when the "values" mapping is
284     // resized; this would invalidate the iterator returned by this function.
285     int64_t currentTimestamp =
286         getMapping(handleValue).timestamps.lookup(handleValue);
287     return llvm::make_filter_range(view, [=](Value v) {
288       [[maybe_unused]] bool sameTimestamp =
289           currentTimestamp ==
290           this->getMapping(handleValue).timestamps.lookup(handleValue);
291       assert(sameTimestamp && "iterator was invalidated during iteration");
292       return true;
293     });
294 #else
295     return llvm::make_range(view.begin(), view.end());
296 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
297   }
298 
299   /// Populates `handles` with all handles pointing to the given Payload IR op.
300   /// Returns success if such handles exist, failure otherwise.
301   /// If `includeOutOfScope` is set to "true", handles that are defined in
302   /// regions beyond the most recent isolated from above region are included.
303   LogicalResult getHandlesForPayloadOp(Operation *op,
304                                        SmallVectorImpl<Value> &handles,
305                                        bool includeOutOfScope = false) const;
306 
307   /// Populates `handles` with all handles pointing to the given payload IR
308   /// value. Returns success if such handles exist, failure otherwise.
309   /// If `includeOutOfScope` is set to "true", handles that are defined in
310   /// regions beyond the most recent isolated from above region are included.
311   LogicalResult getHandlesForPayloadValue(Value payloadValue,
312                                           SmallVectorImpl<Value> &handles,
313                                           bool includeOutOfScope = false) const;
314 
315   /// Applies the transformation specified by the given transform op and updates
316   /// the state accordingly.
317   DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform);
318 
319   /// Records the mapping between a block argument in the transform IR and a
320   /// list of operations in the payload IR. The arguments must be defined in
321   /// blocks of the currently processed transform IR region, typically after a
322   /// region scope is defined.
323   ///
324   /// Returns failure if the payload does not satisfy the conditions associated
325   /// with the type of the handle value.
326   LogicalResult mapBlockArguments(BlockArgument argument,
327                                   ArrayRef<Operation *> operations) {
328     assert(argument.getParentRegion() == regionStack.back()->region &&
329            "mapping block arguments from a region other than the active one");
330     return setPayloadOps(argument, operations);
331   }
332   LogicalResult mapBlockArgument(BlockArgument argument,
333                                  ArrayRef<MappedValue> values);
334   LogicalResult mapBlockArguments(Block::BlockArgListType arguments,
335                                   ArrayRef<SmallVector<MappedValue>> mapping);
336 
337   // Forward declarations to support limited visibility.
338   class RegionScope;
339 
340   /// Creates a new region scope for the given region. The region is expected to
341   /// be nested in the currently processed region.
342   // Implementation note: this method is inline but implemented outside of the
343   // class body to comply with visibility and full-declaration requirements.
344   inline RegionScope make_region_scope(Region &region);
345 
346   /// A RAII object maintaining a "stack frame" for a transform IR region. When
347   /// applying a transform IR operation that contains a region, the caller is
348   /// expected to create a RegionScope before applying the ops contained in the
349   /// region. This ensures that the mappings between values defined in the
350   /// transform IR region and payload IR operations are cleared when the region
351   /// processing ends; such values cannot be accessed outside the region.
352   class RegionScope {
353   public:
354     /// Forgets the mapping from or to values defined in the associated
355     /// transform IR region, and restores the mapping that existed before
356     /// entering this scope.
357     ~RegionScope();
358 
359   private:
360     /// Creates a new scope for mappings between values defined in the given
361     /// transform IR region and payload IR objects.
362     RegionScope(TransformState &state, Region &region)
363         : state(state), region(&region) {
364       auto res = state.mappings.insert(
365           std::make_pair(&region, std::make_unique<Mappings>()));
366       assert(res.second && "the region scope is already present");
367       (void)res;
368       state.regionStack.push_back(this);
369     }
370 
371     /// Back-reference to the transform state.
372     TransformState &state;
373 
374     /// The region this scope is associated with.
375     Region *region;
376 
377     /// The transform op within this region that is currently being applied.
378     TransformOpInterface currentTransform;
379 
380     friend class transform::TransformState;
381   };
382   friend class RegionScope;
383 
384   /// Base class for TransformState extensions that allow TransformState to
385   /// contain user-specified information in the state object. Clients are
386   /// expected to derive this class, add the desired fields, and make the
387   /// derived class compatible with the MLIR TypeID mechanism:
388   ///
389   /// ```mlir
390   /// class MyExtension final : public TransformState::Extension {
391   /// public:
392   ///   MyExtension(TranfsormState &state, int myData)
393   ///     : Extension(state) {...}
394   /// private:
395   ///   int mySupplementaryData;
396   /// };
397   /// ```
398   ///
399   /// Instances of this and derived classes are not expected to be created by
400   /// the user, instead they are directly constructed within a TransformState. A
401   /// TransformState can only contain one extension with the given TypeID.
402   /// Extensions can be obtained from a TransformState instance, and can be
403   /// removed when they are no longer required.
404   ///
405   /// ```mlir
406   /// transformState.addExtension<MyExtension>(/*myData=*/42);
407   /// MyExtension *ext = transformState.getExtension<MyExtension>();
408   /// ext->doSomething();
409   /// ```
410   class Extension {
411     // Allow TransformState to allocate Extensions.
412     friend class TransformState;
413 
414   public:
415     /// Base virtual destructor.
416     // Out-of-line definition ensures symbols are emitted in a single object
417     // file.
418     virtual ~Extension();
419 
420   protected:
421     /// Constructs an extension of the given TransformState object.
422     Extension(TransformState &state) : state(state) {}
423 
424     /// Provides read-only access to the parent TransformState object.
425     const TransformState &getTransformState() const { return state; }
426 
427     /// Replaces the given payload op with another op. If the replacement op is
428     /// null, removes the association of the payload op with its handle. Returns
429     /// failure if the op is not associated with any handle.
430     ///
431     /// Note: This function does not update value handles. None of the original
432     /// op's results are allowed to be mapped to any value handle.
433     LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
434 
435     /// Replaces the given payload value with another value. If the replacement
436     /// value is null, removes the association of the payload value with its
437     /// handle. Returns failure if the value is not associated with any handle.
438     LogicalResult replacePayloadValue(Value value, Value replacement);
439 
440   private:
441     /// Back-reference to the state that is being extended.
442     TransformState &state;
443   };
444 
445   /// Adds a new Extension of the type specified as template parameter,
446   /// constructing it with the arguments provided. The extension is owned by the
447   /// TransformState. It is expected that the state does not already have an
448   /// extension of the same type. Extension constructors are expected to take
449   /// a reference to TransformState as first argument, automatically supplied
450   /// by this call.
451   template <typename Ty, typename... Args>
452   Ty &addExtension(Args &&...args) {
453     static_assert(
454         std::is_base_of<Extension, Ty>::value,
455         "only an class derived from TransformState::Extension is allowed here");
456     auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
457     auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
458     assert(result.second && "extension already added");
459     return *static_cast<Ty *>(result.first->second.get());
460   }
461 
462   /// Returns the extension of the specified type.
463   template <typename Ty>
464   Ty *getExtension() {
465     static_assert(
466         std::is_base_of<Extension, Ty>::value,
467         "only an class derived from TransformState::Extension is allowed here");
468     auto iter = extensions.find(TypeID::get<Ty>());
469     if (iter == extensions.end())
470       return nullptr;
471     return static_cast<Ty *>(iter->second.get());
472   }
473 
474   /// Removes the extension of the specified type.
475   template <typename Ty>
476   void removeExtension() {
477     static_assert(
478         std::is_base_of<Extension, Ty>::value,
479         "only an class derived from TransformState::Extension is allowed here");
480     extensions.erase(TypeID::get<Ty>());
481   }
482 
483 private:
484   /// Identifier for storing top-level value in the `operations` mapping.
485   static constexpr Value kTopLevelValue = Value();
486 
487   /// Creates a state for transform ops living in the given region. The second
488   /// argument points to the root operation in the payload IR being transformed,
489   /// which may or may not contain the region with transform ops. Additional
490   /// options can be provided through the trailing configuration object.
491   TransformState(Region *region, Operation *payloadRoot,
492                  const RaggedArray<MappedValue> &extraMappings = {},
493                  const TransformOptions &options = TransformOptions());
494 
495   /// Returns the mappings frame for the region in which the value is defined.
496   /// If `allowOutOfScope` is set to "false", asserts that the value is in
497   /// scope, based on the current stack of frames.
498   const Mappings &getMapping(Value value, bool allowOutOfScope = false) const {
499     return const_cast<TransformState *>(this)->getMapping(value,
500                                                           allowOutOfScope);
501   }
502   Mappings &getMapping(Value value, bool allowOutOfScope = false) {
503     Region *region = value.getParentRegion();
504     auto it = mappings.find(region);
505     assert(it != mappings.end() &&
506            "trying to find a mapping for a value from an unmapped region");
507 #ifndef NDEBUG
508     if (!allowOutOfScope) {
509       for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
510         if (r == region)
511           break;
512         if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
513           llvm_unreachable("trying to get mapping beyond region that is "
514                            "isolated from above");
515       }
516     }
517 #endif // NDEBUG
518     return *it->second;
519   }
520 
521   /// Returns the mappings frame for the region in which the operation resides.
522   /// If `allowOutOfScope` is set to "false", asserts that the operation is in
523   /// scope, based on the current stack of frames.
524   const Mappings &getMapping(Operation *operation,
525                              bool allowOutOfScope = false) const {
526     return const_cast<TransformState *>(this)->getMapping(operation,
527                                                           allowOutOfScope);
528   }
529   Mappings &getMapping(Operation *operation, bool allowOutOfScope = false) {
530     Region *region = operation->getParentRegion();
531     auto it = mappings.find(region);
532     assert(it != mappings.end() &&
533            "trying to find a mapping for an operation from an unmapped region");
534 #ifndef NDEBUG
535     if (!allowOutOfScope) {
536       for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
537         if (r == region)
538           break;
539         if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
540           llvm_unreachable("trying to get mapping beyond region that is "
541                            "isolated from above");
542       }
543     }
544 #endif // NDEBUG
545     return *it->second;
546   }
547 
548   /// Updates the state to include the associations between op results and the
549   /// provided result of applying a transform op.
550   LogicalResult updateStateFromResults(const TransformResults &results,
551                                        ResultRange opResults);
552 
553   /// Returns a list of all ops that the given transform IR value corresponds
554   /// to. In case an op was erased, the returned list contains nullptr. This
555   /// function is helpful for transformations that apply to a particular handle.
556   ArrayRef<Operation *> getPayloadOpsView(Value value) const;
557 
558   /// Returns a list of payload IR values that the given transform IR value
559   /// corresponds to.
560   ArrayRef<Value> getPayloadValuesView(Value handleValue) const;
561 
562   /// Sets the payload IR ops associated with the given transform IR value
563   /// (handle). A payload op may be associated multiple handles as long as
564   /// at most one of them gets consumed by further transformations.
565   /// For example, a hypothetical "find function by name" may be called twice in
566   /// a row to produce two handles pointing to the same function:
567   ///
568   ///   %0 = transform.find_func_by_name { name = "myfunc" }
569   ///   %1 = transform.find_func_by_name { name = "myfunc" }
570   ///
571   /// which is valid by itself. However, calling a hypothetical "rewrite and
572   /// rename function" transform on both handles:
573   ///
574   ///   transform.rewrite_and_rename %0 { new_name = "func" }
575   ///   transform.rewrite_and_rename %1 { new_name = "func" }
576   ///
577   /// is invalid given the transformation "consumes" the handle as expressed
578   /// by side effects. Practically, a transformation consuming a handle means
579   /// that the associated payload operation may no longer exist.
580   ///
581   /// Similarly, operation handles may be invalidate and should not be used
582   /// after a transform that consumed a value handle pointing to a payload value
583   /// defined by the operation as either block argument or op result. For
584   /// example, in the following sequence, the last transform operation rewrites
585   /// the callee to not return a specified result:
586   ///
587   ///   %0 = transform.find_call "myfunc"
588   ///   %1 = transform.find_results_of_calling "myfunc"
589   ///   transform.drop_call_result_from_signature %1[0]
590   ///
591   /// which requires the call operations to be recreated. Therefore, the handle
592   /// %0 becomes associated with a dangling pointer and should not be used.
593   ///
594   /// Returns failure if the payload does not satisfy the conditions associated
595   /// with the type of the handle value. The value is expected to have a type
596   /// implementing TransformHandleTypeInterface.
597   LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
598 
599   /// Sets the payload IR values association with the given transform IR value
  /// (handle). A payload value may be associated with multiple handles as long
  /// as at most one of them is consumed by further transformations. For
  /// example, a hypothetical "get results of calls to function with the given
  /// name" transform may be performed twice in a row producing handles pointing
  /// to the same values:
  ///
  ///   %0 = transform.find_results_of_calling "myfunc"
  ///   %1 = transform.find_results_of_calling "myfunc"
  ///
  /// which is valid by itself. However, calling a hypothetical "erase value
  /// producer" transform on both handles:
  ///
  ///   transform.erase_value_producer %0
  ///   transform.erase_value_producer %1
  ///
  /// is invalid provided the transformation "consumes" the handle as expressed
  /// by side effects (which themselves reflect the semantics of the transform
  /// erasing the producer and making the handle dangling). Practically, a
  /// transformation consuming a handle means the associated payload value may
  /// no longer exist.
  ///
  /// Similarly, value handles are invalidated and should not be used after a
  /// transform that consumed an operation handle pointing to the payload IR
  /// operation defining the values associated with the value handle, as either
  /// block arguments or op results, or any ancestor operation. For example,
  ///
  ///   %0 = transform.find_call "myfunc"
  ///   %1 = transform.find_results_of_calling "myfunc"
  ///   transform.rewrite_and_rename %0 { new_name = "func" }
  ///
  /// makes %1 unusable after the last transformation if it consumes %0. When an
  /// operation handle is consumed, it usually indicates that the operation was
  /// destroyed or heavily modified, meaning that the values it defines may no
  /// longer exist.
  ///
  /// Returns failure if the payload values do not satisfy the conditions
  /// associated with the type of the handle value. The value is expected to
  /// have a type implementing TransformValueHandleTypeInterface.
  LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
638   LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
639 
640   /// Sets the parameters associated with the given transform IR value. Returns
641   /// failure if the parameters do not satisfy the conditions associated with
642   /// the type of the value. The value is expected to have a type implementing
643   /// TransformParamTypeInterface.
644   LogicalResult setParams(Value value, ArrayRef<Param> params);
645 
646   /// Forgets the payload IR ops associated with the given transform IR value,
647   /// as well as any association between value handles and the results of said
648   /// payload IR op.
649   ///
650   /// If `allowOutOfScope` is set to "false", asserts that the handle is in
651   /// scope, based on the current stack of frames.
652   void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
653                      bool allowOutOfScope = false);
654 
655   void forgetValueMapping(Value valueHandle,
656                           ArrayRef<Operation *> payloadOperations);
657 
658   /// Replaces the given payload op with another op. If the replacement op is
659   /// null, removes the association of the payload op with its handle. Returns
660   /// failure if the op is not associated with any handle.
661   ///
662   /// Note: This function does not update value handles. None of the original
663   /// op's results are allowed to be mapped to any value handle.
664   LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
665 
666   /// Replaces the given payload value with another value. If the replacement
667   /// value is null, removes the association of the payload value with its
668   /// handle. Returns failure if the value is not associated with any handle.
669   LogicalResult replacePayloadValue(Value value, Value replacement);
670 
671   /// Records handle invalidation reporters into `newlyInvalidated`.
672   /// Specifically,
673   ///  - `handle` is the op operand that consumes the handle,
674   ///  - `potentialAncestors` is a list of ancestors of the payload operation
675   ///     that the consumed handle is associated with, including itself,
676   ///  - `throughValue` is the payload value the handle to which is consumed,
677   ///     when it is the case, null when the operation handle is consumed
678   ///     directly.
679   /// Iterates over all known operation and value handles and records reporters
680   /// for any potential future use of `handle` or any other handle that is
681   /// invalidated by its consumption, i.e., any handle pointing to any payload
682   /// IR entity (operation or value) associated with the same payload IR entity
683   /// as the consumed handle, or any nested payload IR entity. If
684   /// `potentialAncestors` is empty, records the reporter anyway. Does not
685   /// override existing reporters. This must remain a const method so it doesn't
686   /// inadvertently mutate `invalidatedHandles` too early.
687   void recordOpHandleInvalidation(OpOperand &consumingHandle,
688                                   ArrayRef<Operation *> potentialAncestors,
689                                   Value throughValue,
690                                   InvalidatedHandleMap &newlyInvalidated) const;
691 
692   /// Records handle invalidation reporters into `newlyInvalidated`.
693   /// Specifically,
694   ///  - `consumingHandle` is the op operand that consumes the handle,
695   ///  - `potentialAncestors` is a list of ancestors of the payload operation
696   ///     that the consumed handle is associated with, including itself,
697   ///  - `payloadOp` is the operation itself,
698   ///  - `otherHandle` is another that may be associated with the affected
699   ///     payload operations
700   ///  - `throughValue` is the payload value the handle to which is consumed,
701   ///     when it is the case, null when the operation handle is consumed
702   ///     directly.
703   /// Looks at the payload opreations associated with `otherHandle` and if any
704   /// of these operations has an ancestor (or is itself) listed in
705   /// `potentialAncestors`, records the error message describing the use of the
706   /// invalidated handle. Does nothing if `otherHandle` already has a reporter
707   /// associated with it. This must remain a const method so it doesn't
708   /// inadvertently mutate `invalidatedHandles` too early.
709   void recordOpHandleInvalidationOne(
710       OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
711       Operation *payloadOp, Value otherHandle, Value throughValue,
712       InvalidatedHandleMap &newlyInvalidated) const;
713 
714   /// Records handle invalidation reporters into `newlyInvalidated`.
715   /// Specifically,
716   ///  - `opHandle` is the op operand that consumes the handle;
717   ///  - `potentialAncestors` is a list of ancestors of the payload operation
718   ///     that the consumed handle is associated with, including itself;
719   ///  - `payloadValue` is the value defined by the operation associated with
720   ///     the consuming handle as either op result or block argument;
721   ///  - `valueHandle` is another that may be associated with the payload value.
722   /// Looks at the payload values associated with `valueHandle` and if any of
723   /// these values is defined, as op result or block argument, by an operation
724   /// whose ancestor (or the operation itself) is listed in
725   /// `potentialAncestors`, records the error message describing the use of the
726   /// invalidated handle. Does nothing if `valueHandle` already has a reporter
727   /// associated with it. This must remain a const method so it doesn't
728   /// inadvertently mutate `invalidatedHandles` too early.
729   void recordValueHandleInvalidationByOpHandleOne(
730       OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
731       Value payloadValue, Value valueHandle,
732       InvalidatedHandleMap &newlyInvalidated) const;
733 
734   /// Records handle invalidation reporters into `newlyInvalidated`.
735   /// Specifically,
736   ///  - `valueHandle` is the op operand that consumes the handle,
737   ///  - `throughValue` is the payload value the handle to which is consumed,
738   ///     when it is the case, null when the operation handle is consumed
739   ///     directly.
740   /// Iterates over all known operation and value handles and records reporters
741   /// for any potential future use of `handle` or any other handle that is
742   /// invalidated by its consumption, i.e., any handle pointing to any payload
743   /// IR entity (operation or value) associated with the same payload IR entity
744   /// as the consumed handle, or any nested payload IR entity. Does not override
745   /// existing reporters. This must remain a const method so it doesn't
746   /// inadvertently mutate `invalidatedHandles` too early.
747   void
748   recordValueHandleInvalidation(OpOperand &valueHandle,
749                                 InvalidatedHandleMap &newlyInvalidated) const;
750 
751   /// Checks that the operation does not use invalidated handles as operands.
752   /// Reports errors and returns failure if it does. Otherwise, invalidates the
753   /// handles consumed by the operation as well as any handles pointing to
754   /// payload IR operations nested in the operations associated with the
755   /// consumed handles.
756   LogicalResult
757   checkAndRecordHandleInvalidation(TransformOpInterface transform);
758 
759   /// Implementation of the checkAndRecordHandleInvalidation. This must remain a
760   /// const method so it doesn't inadvertently mutate `invalidatedHandles` too
761   /// early.
762   LogicalResult checkAndRecordHandleInvalidationImpl(
763       transform::TransformOpInterface transform,
764       transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const;
765 
766   /// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
767   void compactOpHandles();
768 
769   /// A stack of mappings between transform IR values and payload IR ops,
770   /// aggregated by the region in which the transform IR values are defined.
771   /// We use a pointer to the Mappings struct so that reallocations inside
772   /// MapVector don't invalidate iterators when we apply nested transform ops
773   /// while also iterating over the mappings.
774   llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
775 
776   /// Op handles may be temporarily mapped to nullptr to avoid invalidating
777   /// payload op iterators. This set contains all op handles with nullptrs.
778   /// These handles are "compacted" (i.e., nullptrs removed) at the end of each
779   /// transform.
780   DenseSet<Value> opHandlesToCompact;
781 
782   /// Extensions attached to the TransformState, identified by the TypeID of
783   /// their type. Only one extension of any given type is allowed.
784   DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
785 
786   /// The top-level operation that contains all payload IR, typically a module.
787   Operation *topLevel;
788 
789   /// Extra mapped values (payload operations, values or parameters) to be
790   /// associated with additional entry block arguments of the top-level
791   /// transform operation.
792   RaggedArray<MappedValue> topLevelMappedValues;
793 
794   /// Additional options controlling the transformation state behavior.
795   TransformOptions options;
796 
797   /// The mapping from invalidated handles to the error-reporting functions that
798   /// describe when the handles were invalidated. Calling such a function emits
799   /// a user-visible diagnostic with an additional note pointing to the given
800   /// location.
801   InvalidatedHandleMap invalidatedHandles;
802 
803   /// A stack of nested regions that are being processed in the transform IR.
804   /// Each region must be an ancestor of the following regions in this list.
805   /// These are also the keys for "mappings".
806   SmallVector<RegionScope *> regionStack;
807 
808   /// The top-level region scope. The first (bottom) element of `regionStack`
809   /// is the top-level region scope object.
810   std::unique_ptr<RegionScope> topLevelRegionScope;
811 };
812 
813 /// Local mapping between values defined by a specific op implementing the
814 /// TransformOpInterface and the payload IR ops they correspond to.
815 class TransformResults {
816   friend class TransformState;
817 
818 public:
819   /// Indicates that the result of the transform IR op at the given position
820   /// corresponds to the given list of payload IR ops. Each result must be set
821   /// by the transformation exactly once in case of transformation succeeding.
822   /// The value must have a type implementing TransformHandleTypeInterface.
823   template <typename Range>
824   void set(OpResult value, Range &&ops) {
825     int64_t position = value.getResultNumber();
826     assert(position < static_cast<int64_t>(operations.size()) &&
827            "setting results for a non-existent handle");
828     assert(operations[position].data() == nullptr && "results already set");
829     assert(params[position].data() == nullptr &&
830            "another kind of results already set");
831     assert(values[position].data() == nullptr &&
832            "another kind of results already set");
833     operations.replace(position, std::forward<Range>(ops));
834   }
835 
836   /// Indicates that the result of the transform IR op at the given position
837   /// corresponds to the given list of payload IR ops. Each result must be set
838   /// by the transformation exactly once in case of transformation succeeding.
839   /// The value must have a type implementing TransformHandleTypeInterface.
840   void set(OpResult value, std::initializer_list<Operation *> ops) {
841     set(value, ArrayRef<Operation *>(ops));
842   }
843 
844   /// Indicates that the result of the transform IR op at the given position
845   /// corresponds to the given list of parameters. Each result must be set by
846   /// the transformation exactly once in case of transformation succeeding. The
847   /// value must have a type implementing TransformParamTypeInterface.
848   void setParams(OpResult value, ArrayRef<TransformState::Param> params);
849 
850   /// Indicates that the result of the transform IR op at the given position
851   /// corresponds to the given range of payload IR values. Each result must be
852   /// set by the transformation exactly once in case of transformation
853   /// succeeding. The value must have a type implementing
854   /// TransformValueHandleTypeInterface.
855   template <typename Range>
856   void setValues(OpResult handle, Range &&values) {
857     int64_t position = handle.getResultNumber();
858     assert(position < static_cast<int64_t>(this->values.size()) &&
859            "setting values for a non-existent handle");
860     assert(this->values[position].data() == nullptr && "values already set");
861     assert(operations[position].data() == nullptr &&
862            "another kind of results already set");
863     assert(params[position].data() == nullptr &&
864            "another kind of results already set");
865     this->values.replace(position, std::forward<Range>(values));
866   }
867 
868   /// Indicates that the result of the transform IR op at the given position
869   /// corresponds to the given range of payload IR values. Each result must be
870   /// set by the transformation exactly once in case of transformation
871   /// succeeding. The value must have a type implementing
872   /// TransformValueHandleTypeInterface.
873   void setValues(OpResult handle, std::initializer_list<Value> values) {
874     setValues(handle, ArrayRef<Value>(values));
875   }
876 
877   /// Indicates that the result of the transform IR op at the given position
878   /// corresponds to the given range of mapped values. All mapped values are
879   /// expected to be compatible with the type of the result, e.g., if the result
880   /// is an operation handle, all mapped values are expected to be payload
881   /// operations.
882   void setMappedValues(OpResult handle, ArrayRef<MappedValue> values);
883 
884   /// Sets the currently unset results to empty lists of the kind expected by
885   /// the corresponding results of the given `transform` op.
886   void setRemainingToEmpty(TransformOpInterface transform);
887 
888 private:
889   /// Creates an instance of TransformResults that expects mappings for
890   /// `numSegments` values, which may be associated with payload operations or
891   /// parameters.
892   explicit TransformResults(unsigned numSegments);
893 
894   /// Gets the list of operations associated with the result identified by its
895   /// number in the list of operation results. The result must have been set to
896   /// be associated with payload IR operations.
897   ArrayRef<Operation *> get(unsigned resultNumber) const;
898 
899   /// Gets the list of parameters associated with the result identified by its
900   /// number in the list of operation results. The result must have been set to
901   /// be associated with parameters.
902   ArrayRef<TransformState::Param> getParams(unsigned resultNumber) const;
903 
904   /// Gets the list of payload IR values associated with the result identified
905   /// by its number in the list of operation results. The result must have been
906   /// set to be associated with payload IR values.
907   ArrayRef<Value> getValues(unsigned resultNumber) const;
908 
909   /// Returns `true` if the result identified by its number in the list of
910   /// operation results is associated with a list of parameters, `false`
911   /// otherwise.
912   bool isParam(unsigned resultNumber) const;
913 
914   /// Returns `true` if the result identified by its number in the list of
915   /// operation results is associated with a list of payload IR value, `false`
916   /// otherwise.
917   bool isValue(unsigned resultNumber) const;
918 
919   /// Returns `true` if the result identified by its number in the list of
920   /// operation results is associated with something.
921   bool isSet(unsigned resultNumber) const;
922 
923   /// Pointers to payload IR ops that are associated with results of a transform
924   /// IR op.
925   RaggedArray<Operation *> operations;
926 
927   /// Parameters that are associated with results of the transform IR op.
928   RaggedArray<Param> params;
929 
930   /// Payload IR values that are associated with results of a transform IR op.
931   RaggedArray<Value> values;
932 };
933 
934 /// Creates a RAII object the lifetime of which corresponds to the new mapping
935 /// for transform IR values defined in the given region. Values defined in
936 /// surrounding regions remain accessible.
937 TransformState::RegionScope TransformState::make_region_scope(Region &region) {
938   return RegionScope(*this, region);
939 }
940 
941 /// A configuration object for customizing a `TrackingListener`.
942 struct TrackingListenerConfig {
943   using SkipHandleFn = std::function<bool(Value)>;
944 
945   /// An optional function that returns "true" for handles that do not have to
946   /// be updated. These are typically dead or consumed handles.
947   SkipHandleFn skipHandleFn = nullptr;
948 
949   /// If set to "true", the name of a replacement op must match the name of the
950   /// original op. If set to "false", the names of the payload ops tracked in a
951   /// handle may change as the tracking listener updates the transform state.
952   bool requireMatchingReplacementOpName = true;
953 
954   /// If set to "true", cast ops (that implement the CastOpInterface) are
955   /// skipped and the replacement op search continues with the operands of the
956   /// cast op.
957   bool skipCastOps = true;
958 };
959 
960 /// A listener that updates a TransformState based on IR modifications. This
961 /// listener can be used during a greedy pattern rewrite to keep the transform
962 /// state up-to-date.
963 class TrackingListener : public RewriterBase::Listener,
964                          public TransformState::Extension {
965 public:
966   /// Create a new TrackingListener for usage in the specified transform op.
967   /// Optionally, a function can be specified to identify handles that should
968   /// do not have to be updated.
969   TrackingListener(TransformState &state, TransformOpInterface op,
970                    TrackingListenerConfig config = TrackingListenerConfig());
971 
972 protected:
973   /// Return a replacement payload op for the given op, which is going to be
974   /// replaced with the given values. By default, if all values are defined by
975   /// the same op, which also has the same type as the given op, that defining
976   /// op is used as a replacement.
977   ///
978   /// A "failure" return value indicates that no replacement operation could be
979   /// found. A "nullptr" return value indicates that no replacement op is needed
980   /// (e.g., handle is dead or was consumed) and that the payload op should
981   /// be dropped from the mapping.
982   ///
983   /// Example: A tracked "linalg.generic" with two results is replaced with two
984   /// values defined by (another) "linalg.generic". It is reasonable to assume
985   /// that the replacement "linalg.generic" represents the same "computation".
986   /// Therefore, the payload op mapping is updated to the defining op of the
987   /// replacement values.
988   ///
989   /// Counter Example: A "linalg.generic" is replaced with values defined by an
990   /// "scf.for". Without further investigation, the relationship between the
991   /// "linalg.generic" and the "scf.for" is unclear. They may not represent the
992   /// same computation; e.g., there may be tiled "linalg.generic" inside the
993   /// loop body that represents the original computation. Therefore, the
994   /// TrackingListener is conservative by default: it drops the mapping and
995   /// triggers the "payload replacement not found" notification. This default
996   /// behavior can be customized in `TrackingListenerConfig`.
997   ///
998   /// If no replacement op could be found according to the rules mentioned
999   /// above, this function tries to skip over cast-like ops that implement
1000   /// `CastOpInterface`.
1001   ///
1002   /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
1003   /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is
1004   /// reasonable to assume that the wrapped "linalg.generic" represents the same
1005   /// computation as the original "linalg.generic". The mapping is updated
1006   /// accordingly.
1007   ///
1008   /// Certain ops (typically also metadata-only ops) are not considered casts,
1009   /// but should be skipped nonetheless. Such ops should implement
1010   /// `FindPayloadReplacementOpInterface` to specify with which operands the
1011   /// lookup should continue.
1012   ///
1013   /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
1014   /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but
1015   /// not cast. (Implementing `CastOpInterface` would be incorrect and cause
1016   /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface`
1017   /// implementation, the replacement op lookup continues with the wrapped
1018   /// "linalg.generic" and the mapping is updated accordingly.
1019   ///
1020   /// Derived classes may override `findReplacementOp` to specify custom
1021   /// replacement rules.
1022   virtual DiagnosedSilenceableFailure
1023   findReplacementOp(Operation *&result, Operation *op,
1024                     ValueRange newValues) const;
1025 
1026   /// Notify the listener that the pattern failed to match the given operation,
1027   /// and provide a callback to populate a diagnostic with the reason why the
1028   /// failure occurred.
1029   void
1030   notifyMatchFailure(Location loc,
1031                      function_ref<void(Diagnostic &)> reasonCallback) override;
1032 
1033   /// This function is called when a tracked payload op is dropped because no
1034   /// replacement op was found. Derived classes can implement this function for
1035   /// custom error handling.
1036   virtual void
1037   notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
1038                                    DiagnosedSilenceableFailure &&diag) {}
1039 
1040   /// Return the single op that defines all given values (if any).
1041   static Operation *getCommonDefiningOp(ValueRange values);
1042 
1043   /// Return the transform op in which this TrackingListener is used.
1044   TransformOpInterface getTransformOp() const { return transformOp; }
1045 
1046 private:
1047   friend class TransformRewriter;
1048 
1049   void notifyOperationErased(Operation *op) override;
1050 
1051   void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
1052   using Listener::notifyOperationReplaced;
1053 
1054   /// The transform op in which this TrackingListener is used.
1055   TransformOpInterface transformOp;
1056 
1057   /// The handles that are consumed by the transform op.
1058   DenseSet<Value> consumedHandles;
1059 
1060   /// Tracking listener configuration.
1061   TrackingListenerConfig config;
1062 };
1063 
1064 /// A specialized listener that keeps track of cases in which no replacement
1065 /// payload could be found. The error state of this listener must be checked
1066 /// before the end of its lifetime.
1067 class ErrorCheckingTrackingListener : public TrackingListener {
1068 public:
1069   using transform::TrackingListener::TrackingListener;
1070 
1071   ~ErrorCheckingTrackingListener() override;
1072 
1073   /// Check and return the current error state of this listener. Afterwards,
1074   /// resets the error state to "success".
1075   DiagnosedSilenceableFailure checkAndResetError();
1076 
1077   /// Return "true" if this tracking listener had a failure.
1078   bool failed() const;
1079 
1080 protected:
1081   void
1082   notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
1083                                    DiagnosedSilenceableFailure &&diag) override;
1084 
1085 private:
1086   /// The error state of this listener. "Success" indicates that no error
1087   /// happened so far.
1088   DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
1089 
1090   /// The number of errors that have been encountered.
1091   int64_t errorCounter = 0;
1092 };
1093 
1094 /// This is a special rewriter to be used in transform op implementations,
1095 /// providing additional helper functions to update the transform state, etc.
1096 // TODO: Helper functions will be added in a subsequent change.
1097 class TransformRewriter : public RewriterBase {
1098 protected:
1099   friend class TransformState;
1100 
1101   /// Create a new TransformRewriter.
1102   explicit TransformRewriter(MLIRContext *ctx,
1103                              ErrorCheckingTrackingListener *listener);
1104 
1105 public:
1106   /// Return "true" if the tracking listener had failures.
1107   bool hasTrackingFailures() const;
1108 
1109   /// Silence all tracking failures that have been encountered so far.
1110   void silenceTrackingFailure();
1111 
1112   /// Notify the transform dialect interpreter that the given op has been
1113   /// replaced with another op and that the mapping between handles and payload
1114   /// ops/values should be updated. This function should be called before the
1115   /// original op is erased. It fails if the operation could not be replaced,
1116   /// e.g., because the original operation is not tracked.
1117   ///
1118   /// Note: As long as IR modifications are performed through this rewriter,
1119   /// the transform state is usually updated automatically. This function should
1120   /// be used when unsupported rewriter API is used; e.g., updating all uses of
1121   /// a tracked operation one-by-one instead of using `RewriterBase::replaceOp`.
1122   LogicalResult notifyPayloadOperationReplaced(Operation *op,
1123                                                Operation *replacement);
1124 
1125 private:
1126   ErrorCheckingTrackingListener *const listener;
1127 };
1128 
1129 /// This trait is supposed to be attached to Transform dialect operations that
1130 /// can be standalone top-level transforms. Such operations typically contain
1131 /// other Transform dialect operations that can be executed following some
1132 /// control flow logic specific to the current operation. The operations with
1133 /// this trait are expected to have at least one single-block region with at
1134 /// least one argument of type implementing TransformHandleTypeInterface. The
1135 /// operations are also expected to be valid without operands, in which case
1136 /// they are considered top-level, and with one or more arguments, in which case
1137 /// they are considered nested. Top-level operations have the block argument of
1138 /// the entry block in the Transform IR correspond to the root operation of
1139 /// Payload IR. Nested operations have the block argument of the entry block in
1140 /// the Transform IR correspond to a list of Payload IR operations mapped to the
1141 /// first operand of the Transform IR operation. The operation must implement
1142 /// TransformOpInterface.
1143 template <typename OpTy>
1144 class PossibleTopLevelTransformOpTrait
1145     : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
1146 public:
1147   /// Verifies that `op` satisfies the invariants of this trait. Not expected to
1148   /// be called directly.
1149   static LogicalResult verifyTrait(Operation *op) {
1150     return detail::verifyPossibleTopLevelTransformOpTrait(op);
1151   }
1152 
1153   /// Returns the single block of the given region.
1154   Block *getBodyBlock(unsigned region = 0) {
1155     return &this->getOperation()->getRegion(region).front();
1156   }
1157 
1158   /// Populates `effects` with side effects implied by this trait.
1159   void getPotentialTopLevelEffects(
1160       SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1161     detail::getPotentialTopLevelEffects(
1162         this->getOperation(), cast<OpTy>(this->getOperation()).getRoot(),
1163         *getBodyBlock(), effects);
1164   }
1165 
1166   /// Sets up the mapping between the entry block of the given region of this op
1167   /// and the relevant list of Payload IR operations in the given state. The
1168   /// state is expected to be already scoped at the region of this operation.
1169   LogicalResult mapBlockArguments(TransformState &state, Region &region) {
1170     assert(region.getParentOp() == this->getOperation() &&
1171            "op comes from the wrong region");
1172     return detail::mapPossibleTopLevelTransformOpBlockArguments(
1173         state, this->getOperation(), region);
1174   }
1175   LogicalResult mapBlockArguments(TransformState &state) {
1176     assert(
1177         this->getOperation()->getNumRegions() == 1 &&
1178         "must indicate the region to map if the operation has more than one");
1179     return mapBlockArguments(state, this->getOperation()->getRegion(0));
1180   }
1181 };
1182 
1183 class ApplyToEachResultList;
1184 
1185 /// Trait implementing the TransformOpInterface for operations applying a
1186 /// transformation to a single operation handle and producing an arbitrary
1187 /// number of handles and parameter values.
1188 /// The op must implement a method with the following signature:
1189 ///   - DiagnosedSilenceableFailure applyToOne(OpTy,
1190 ///       ApplyToEachResultList &results, TransformState &state)
1191 /// to perform a transformation that is applied in turn to all payload IR
1192 /// operations that correspond to the handle of the transform IR operation.
1193 /// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
1194 /// that the transformation is applied to (and NOT the class of the transform IR
1195 /// op).
1196 /// The `applyToOne` method takes an empty `results` vector that it fills with
1197 /// zero, one or multiple operations depending on the number of results expected
1198 /// by the transform op.
1199 /// The number of results must match the number of results of the transform op.
1200 /// `applyToOne` is allowed to fill the `results` with all null elements to
1201 /// signify that the transformation did not apply to the payload IR operations.
1202 /// Such null elements are filtered out from results before return.
1203 ///
1204 /// The transform op having this trait is expected to have a single operand.
1205 template <typename OpTy>
1206 class TransformEachOpTrait
1207     : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
1208 public:
1209   /// Calls `applyToOne` for every payload operation associated with the operand
1210   /// of this transform IR op, the following case disjunction happens:
1211   ///   1. If not target payload ops are associated to the operand then fill the
1212   ///      results vector with the expected number of null elements and return
1213   ///      success. This is the corner case handling that allows propagating
1214   ///      the "no-op" case gracefully to improve usability.
1215   ///   2. If any `applyToOne` returns definiteFailure, the transformation is
1216   ///      immediately considered definitely failed and we return.
1217   ///   3. All applications of `applyToOne` are checked to return a number of
1218   ///      results expected by the transform IR op. If not, this is a definite
1219   ///      failure and we return early.
1220   ///   4. If `applyToOne` produces ops, associate them with the result of this
1221   ///      transform op.
1222   ///   5. If any `applyToOne` return silenceableFailure, the transformation is
1223   ///      considered silenceable.
1224   ///   6. Otherwise the transformation is considered successful.
1225   DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
1226                                     TransformResults &transformResults,
1227                                     TransformState &state);
1228 
1229   /// Checks that the op matches the expectations of this trait.
1230   static LogicalResult verifyTrait(Operation *op);
1231 };
1232 
/// Side effect resource corresponding to the mapping between Transform IR
/// values and Payload IR operations. An Allocate effect from this resource
/// means creating a new mapping entry; it is always accompanied by a Write
/// effect. A Read effect from this resource means accessing the mapping. A Free
/// effect on this resource indicates the removal of the mapping entry,
/// typically after a transformation that modifies the Payload IR operations
/// associated with one of the Transform IR operation's operands. It is always
/// accompanied by a Read effect. Read-after-Free and double-Free are not
/// allowed (they would be problematic with "regular" memory effects too) as
/// they indicate an attempt to access Payload IR operations that have been
/// modified, potentially erased, by the previous transformations.
// TODO: consider custom effects if these are not enabling generic passes such
// as CSE/DCE to work.
struct TransformMappingResource
    : public SideEffects::Resource::Base<TransformMappingResource> {
  /// Name identifying this side effect resource.
  StringRef getName() override { return "transform.mapping"; }
};
1250 
1251 /// Side effect resource corresponding to the Payload IR itself. Only Read and
1252 /// Write effects are expected on this resource, with Write always accompanied
1253 /// by a Read (short of fully replacing the top-level Payload IR operation, one
1254 /// cannot modify the Payload IR without reading it first). This is intended
1255 /// to disallow reordering of Transform IR operations that mutate the Payload IR
1256 /// while still allowing the reordering of those that only access it.
1257 struct PayloadIRResource
1258     : public SideEffects::Resource::Base<PayloadIRResource> {
1259   StringRef getName() override { return "transform.payload_ir"; }
1260 };
1261 
1262 /// Populates `effects` with the memory effects indicating the operation on the
1263 /// given handle value:
1264 ///   - consumes = Read + Free,
1265 ///   - produces = Allocate + Write,
1266 ///   - onlyReads = Read.
1267 void consumesHandle(MutableArrayRef<OpOperand> handles,
1268                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1269 void producesHandle(ResultRange handles,
1270                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1271 void producesHandle(MutableArrayRef<BlockArgument> handles,
1272                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1273 void onlyReadsHandle(MutableArrayRef<OpOperand> handles,
1274                      SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1275 
1276 /// Checks whether the transform op consumes the given handle.
1277 bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
1278 
1279 /// Populates `effects` with the memory effects indicating the access to payload
1280 /// IR resource.
1281 void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1282 void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1283 
1284 /// Checks whether the transform op modifies the payload.
1285 bool doesModifyPayload(transform::TransformOpInterface transform);
1286 /// Checks whether the transform op reads the payload.
1287 bool doesReadPayload(transform::TransformOpInterface transform);
1288 
1289 /// Populates `consumedArguments` with positions of `block` arguments that are
1290 /// consumed by the operations in the `block`.
1291 void getConsumedBlockArguments(
1292     Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1293 
1294 /// Trait implementing the MemoryEffectOpInterface for operations that "consume"
1295 /// their operands and produce new results.
1296 template <typename OpTy>
1297 class FunctionalStyleTransformOpTrait
1298     : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
1299 public:
1300   /// This op "consumes" the operands by reading and freeing then, "produces"
1301   /// the results by allocating and writing it and reads/writes the payload IR
1302   /// in the process.
1303   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1304     consumesHandle(this->getOperation()->getOpOperands(), effects);
1305     producesHandle(this->getOperation()->getOpResults(), effects);
1306     modifiesPayload(effects);
1307   }
1308 
1309   /// Checks that the op matches the expectations of this trait.
1310   static LogicalResult verifyTrait(Operation *op) {
1311     if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1312       op->emitError()
1313           << "FunctionalStyleTransformOpTrait should only be attached to ops "
1314              "that implement MemoryEffectOpInterface";
1315     }
1316     return success();
1317   }
1318 };
1319 
1320 /// Trait implementing the MemoryEffectOpInterface for operations that use their
1321 /// operands without consuming and without modifying the Payload IR to
1322 /// potentially produce new handles.
1323 template <typename OpTy>
1324 class NavigationTransformOpTrait
1325     : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {
1326 public:
1327   /// This op produces handles to the Payload IR without consuming the original
1328   /// handles and without modifying the IR itself.
1329   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1330     onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
1331     producesHandle(this->getOperation()->getOpResults(), effects);
1332     if (llvm::any_of(this->getOperation()->getOperandTypes(), [](Type t) {
1333           return isa<TransformHandleTypeInterface,
1334                      TransformValueHandleTypeInterface>(t);
1335         })) {
1336       onlyReadsPayload(effects);
1337     }
1338   }
1339 
1340   /// Checks that the op matches the expectation of this trait.
1341   static LogicalResult verifyTrait(Operation *op) {
1342     if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1343       op->emitError() << "NavigationTransformOpTrait should only be attached "
1344                          "to ops that implement MemoryEffectOpInterface";
1345     }
1346     return success();
1347   }
1348 };
1349 
1350 namespace detail {
1351 /// Non-template implementation of ParamProducerTransformOpTrait::getEffects().
1352 void getParamProducerTransformOpTraitEffects(
1353     Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1354 /// Non-template implementation of ParamProducerTransformOpTrait::verify().
1355 LogicalResult verifyParamProducerTransformOpTrait(Operation *op);
1356 } // namespace detail
1357 
1358 /// Trait implementing the MemoryEffectsOpInterface for operations that produce
1359 /// transform dialect parameters. It marks all op results of
1360 /// TransformHandleTypeInterface as produced by the op, all operands as only
1361 /// read by the op and, if at least one of the operand is a handle to payload
1362 /// ops, the entire payload as potentially read. The op must only produce
1363 /// parameter-typed results.
1364 template <typename OpTy>
1365 class ParamProducerTransformOpTrait
1366     : public OpTrait::TraitBase<OpTy, ParamProducerTransformOpTrait> {
1367 public:
1368   /// Populates `effects` with effect instances described in the trait
1369   /// documentation.
1370   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1371     detail::getParamProducerTransformOpTraitEffects(this->getOperation(),
1372                                                     effects);
1373   }
1374 
1375   /// Checks that the op matches the expectation of this trait, i.e., that it
1376   /// implements the MemoryEffectsOpInterface and only produces parameter-typed
1377   /// results.
1378   static LogicalResult verifyTrait(Operation *op) {
1379     return detail::verifyParamProducerTransformOpTrait(op);
1380   }
1381 };
1382 
1383 /// `TrackingListener` failures are reported only for ops that have this trait.
1384 /// The purpose of this trait is to give users more time to update their custom
1385 /// transform ops to use the provided `TransformRewriter` for all IR
1386 /// modifications. This trait will eventually be removed, and failures will be
1387 /// reported for all transform ops.
1388 template <typename OpTy>
1389 class ReportTrackingListenerFailuresOpTrait
1390     : public OpTrait::TraitBase<OpTy, ReportTrackingListenerFailuresOpTrait> {};
1391 
1392 /// A single result of applying a transform op with `ApplyEachOpTrait` to a
1393 /// single payload operation.
1394 using ApplyToEachResult = MappedValue;
1395 
1396 /// A list of results of applying a transform op with `ApplyEachOpTrait` to a
1397 /// single payload operation, co-indexed with the results of the transform op.
1398 class ApplyToEachResultList {
1399 public:
1400   ApplyToEachResultList() = default;
1401   explicit ApplyToEachResultList(unsigned size) : results(size) {}
1402 
1403   /// Sets the list of results to `size` null pointers.
1404   void assign(unsigned size, std::nullptr_t) { results.assign(size, nullptr); }
1405 
1406   /// Sets the list of results to the given range of values.
1407   template <typename Range>
1408   void assign(Range &&range) {
1409     // This is roughly the implementation of SmallVectorImpl::assign.
1410     // Dispatching to it with map_range and template type inference would result
1411     // in more complex code here.
1412     results.clear();
1413     results.reserve(llvm::size(range));
1414     for (auto element : range) {
1415       if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1416                                           Operation *>) {
1417         results.push_back(static_cast<Operation *>(element));
1418       } else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1419                                                  Value>) {
1420         results.push_back(element.template get<Value>());
1421       } else {
1422         results.push_back(static_cast<Attribute>(element));
1423       }
1424     }
1425   }
1426 
1427   /// Appends an element to the list.
1428   // Using ApplyToEachResult that can be implicitly constructed from a Value but
1429   // not from a concrete Op that is implicitly convertible to a Value to avoid
1430   // ambiguity.
1431   void push_back(Operation *op) { results.push_back(op); }
1432   void push_back(Attribute attr) { results.push_back(attr); }
1433   void push_back(ApplyToEachResult r) { results.push_back(r); }
1434 
1435   /// Reserves space for `size` elements in the list.
1436   void reserve(unsigned size) { results.reserve(size); }
1437 
1438   /// Iterators over the list.
1439   auto begin() { return results.begin(); }
1440   auto end() { return results.end(); }
1441   auto begin() const { return results.begin(); }
1442   auto end() const { return results.end(); }
1443 
1444   /// Returns the number of elements in the list.
1445   size_t size() const { return results.size(); }
1446 
1447   /// Element access. Expects the index to be in bounds.
1448   ApplyToEachResult &operator[](size_t index) { return results[index]; }
1449   const ApplyToEachResult &operator[](size_t index) const {
1450     return results[index];
1451   }
1452 
1453 private:
1454   /// Underlying storage.
1455   SmallVector<ApplyToEachResult> results;
1456 };
1457 
1458 namespace detail {
1459 
1460 /// Check that the contents of `partialResult` matches the number, kind (payload
1461 /// op or parameter) and nullity (either all or none) requirements of
1462 /// `transformOp`. Report errors and return failure otherwise.
1463 LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc,
1464                               const ApplyToEachResultList &partialResult);
1465 
1466 /// "Transpose" the results produced by individual applications, arranging them
1467 /// per result value of the transform op, and populate `transformResults` with
1468 /// that. The number, kind and nullity of per-application results are assumed to
1469 /// have been verified.
1470 void setApplyToOneResults(Operation *transformOp,
1471                           TransformResults &transformResults,
1472                           ArrayRef<ApplyToEachResultList> results);
1473 
1474 /// Applies a one-to-one or a one-to-many transform to each of the given
1475 /// targets. Puts the results of transforms, if any, in `results` in the same
1476 /// order. Fails if any of the application fails. Individual transforms must be
1477 /// callable with the following signature:
1478 ///   - DiagnosedSilenceableFailure(OpTy,
1479 ///       SmallVector<Operation*> &results, state)
1480 /// where OpTy is either
1481 ///   - Operation *, in which case the transform is always applied;
1482 ///   - a concrete Op class, in which case a check is performed whether
1483 ///   `targets` contains operations of the same class and a silenceable failure
1484 ///   is reported if it does not.
1485 template <typename TransformOpTy, typename Range>
1486 DiagnosedSilenceableFailure applyTransformToEach(
1487     TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets,
1488     SmallVectorImpl<ApplyToEachResultList> &results, TransformState &state) {
1489   using OpTy = typename llvm::function_traits<
1490       decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1491   static_assert(std::is_convertible<OpTy, Operation *>::value,
1492                 "expected transform function to take an operation");
1493   OpBuilder::InsertionGuard g(rewriter);
1494 
1495   SmallVector<Diagnostic> silenceableStack;
1496   unsigned expectedNumResults = transformOp->getNumResults();
1497   for (Operation *target : targets) {
1498     auto specificOp = dyn_cast<OpTy>(target);
1499     if (!specificOp) {
1500       Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
1501       diag << "transform applied to the wrong op kind";
1502       diag.attachNote(target->getLoc()) << "when applied to this op";
1503       silenceableStack.push_back(std::move(diag));
1504       continue;
1505     }
1506 
1507     ApplyToEachResultList partialResults;
1508     partialResults.reserve(expectedNumResults);
1509     Location specificOpLoc = specificOp->getLoc();
1510     rewriter.setInsertionPoint(specificOp);
1511     DiagnosedSilenceableFailure res =
1512         transformOp.applyToOne(rewriter, specificOp, partialResults, state);
1513     if (res.isDefiniteFailure())
1514       return DiagnosedSilenceableFailure::definiteFailure();
1515 
1516     if (res.isSilenceableFailure()) {
1517       res.takeDiagnostics(silenceableStack);
1518       continue;
1519     }
1520 
1521     if (failed(detail::checkApplyToOne(transformOp, specificOpLoc,
1522                                        partialResults))) {
1523       return DiagnosedSilenceableFailure::definiteFailure();
1524     }
1525     results.push_back(std::move(partialResults));
1526   }
1527   if (!silenceableStack.empty()) {
1528     return DiagnosedSilenceableFailure::silenceableFailure(
1529         std::move(silenceableStack));
1530   }
1531   return DiagnosedSilenceableFailure::success();
1532 }
1533 
1534 /// Reports an error and returns failure if `targets` contains an ancestor
1535 /// operation before its descendant (or a copy of itself). Implementation detail
1536 /// for expensive checks during `TransformEachOpTrait::apply`.
1537 LogicalResult checkNestedConsumption(Location loc,
1538                                      ArrayRef<Operation *> targets);
1539 
1540 } // namespace detail
1541 } // namespace transform
1542 } // namespace mlir
1543 
1544 template <typename OpTy>
1545 mlir::DiagnosedSilenceableFailure
1546 mlir::transform::TransformEachOpTrait<OpTy>::apply(
1547     TransformRewriter &rewriter, TransformResults &transformResults,
1548     TransformState &state) {
1549   Value handle = this->getOperation()->getOperand(0);
1550   auto targets = state.getPayloadOps(handle);
1551 
1552   // If the operand is consumed, check if it is associated with operations that
1553   // may be erased before their nested operations are.
1554   if (state.getOptions().getExpensiveChecksEnabled() &&
1555       isHandleConsumed(handle, cast<transform::TransformOpInterface>(
1556                                    this->getOperation())) &&
1557       failed(detail::checkNestedConsumption(this->getOperation()->getLoc(),
1558                                             llvm::to_vector(targets)))) {
1559     return DiagnosedSilenceableFailure::definiteFailure();
1560   }
1561 
1562   // Step 1. Handle the corner case where no target is specified.
1563   // This is typically the case when the matcher fails to apply and we need to
1564   // propagate gracefully.
1565   // In this case, we fill all results with an empty vector.
1566   if (std::empty(targets)) {
1567     SmallVector<Operation *> emptyPayload;
1568     SmallVector<Attribute> emptyParams;
1569     for (OpResult r : this->getOperation()->getResults()) {
1570       if (isa<TransformParamTypeInterface>(r.getType()))
1571         transformResults.setParams(r, emptyParams);
1572       else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1573         transformResults.setValues(r, ValueRange());
1574       else
1575         transformResults.set(r, emptyPayload);
1576     }
1577     return DiagnosedSilenceableFailure::success();
1578   }
1579 
1580   // Step 2. Call applyToOne on each target and record newly produced ops in its
1581   // corresponding results entry.
1582   SmallVector<ApplyToEachResultList, 1> results;
1583   DiagnosedSilenceableFailure result = detail::applyTransformToEach(
1584       cast<OpTy>(this->getOperation()), rewriter, targets, results, state);
1585 
1586   // Step 3. Propagate the definite failure if any and bail out.
1587   if (result.isDefiniteFailure())
1588     return result;
1589 
1590   // Step 4. "Transpose" the results produced by individual applications,
1591   // arranging them per result value of the transform op. The number, kind and
1592   // nullity of per-application results have been verified by the callback
1593   // above.
1594   detail::setApplyToOneResults(this->getOperation(), transformResults, results);
1595 
1596   // Step 5. ApplyToOne may have returned silenceableFailure, propagate it.
1597   return result;
1598 }
1599 
1600 template <typename OpTy>
1601 llvm::LogicalResult
1602 mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) {
1603   static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1604                 "expected single-operand op");
1605   if (!op->getName().getInterface<TransformOpInterface>()) {
1606     return op->emitError() << "TransformEachOpTrait should only be attached to "
1607                               "ops that implement TransformOpInterface";
1608   }
1609 
1610   return success();
1611 }
1612 
1613 #endif // DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
1614