xref: /llvm-project/mlir/include/mlir/Transforms/DialectConversion.h (revision 3ace685105d3b50bca68328bf0c945af22d70f23)
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Config/mlir-config.h"
17 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
// Forward declarations of MLIR entities named in this header.
class Attribute;
class Block;
class ConversionPatternRewriter;
class MLIRContext;
class Operation;
class Type;
class Value;

// These two are declared with the `struct` class-key; keep it consistent with
// their definitions.
struct ConversionConfig;
struct OperationConverter;
34 
35 //===----------------------------------------------------------------------===//
36 // Type Conversion
37 //===----------------------------------------------------------------------===//
38 
39 /// Type conversion class. Specific conversions and materializations can be
40 /// registered using addConversion and addMaterialization, respectively.
41 class TypeConverter {
42 public:
43   virtual ~TypeConverter() = default;
44   TypeConverter() = default;
45   // Copy the registered conversions, but not the caches
46   TypeConverter(const TypeConverter &other)
47       : conversions(other.conversions),
48         argumentMaterializations(other.argumentMaterializations),
49         sourceMaterializations(other.sourceMaterializations),
50         targetMaterializations(other.targetMaterializations),
51         typeAttributeConversions(other.typeAttributeConversions) {}
52   TypeConverter &operator=(const TypeConverter &other) {
53     conversions = other.conversions;
54     argumentMaterializations = other.argumentMaterializations;
55     sourceMaterializations = other.sourceMaterializations;
56     targetMaterializations = other.targetMaterializations;
57     typeAttributeConversions = other.typeAttributeConversions;
58     return *this;
59   }
60 
61   /// This class provides all of the information necessary to convert a type
62   /// signature.
63   class SignatureConversion {
64   public:
65     SignatureConversion(unsigned numOrigInputs)
66         : remappedInputs(numOrigInputs) {}
67 
68     /// This struct represents a range of new types or a single value that
69     /// remaps an existing signature input.
70     struct InputMapping {
71       size_t inputNo, size;
72       Value replacementValue;
73     };
74 
75     /// Return the argument types for the new signature.
76     ArrayRef<Type> getConvertedTypes() const { return argTypes; }
77 
78     /// Get the input mapping for the given argument.
79     std::optional<InputMapping> getInputMapping(unsigned input) const {
80       return remappedInputs[input];
81     }
82 
83     //===------------------------------------------------------------------===//
84     // Conversion Hooks
85     //===------------------------------------------------------------------===//
86 
87     /// Remap an input of the original signature with a new set of types. The
88     /// new types are appended to the new signature conversion.
89     void addInputs(unsigned origInputNo, ArrayRef<Type> types);
90 
91     /// Append new input types to the signature conversion, this should only be
92     /// used if the new types are not intended to remap an existing input.
93     void addInputs(ArrayRef<Type> types);
94 
95     /// Remap an input of the original signature to another `replacement`
96     /// value. This drops the original argument.
97     void remapInput(unsigned origInputNo, Value replacement);
98 
99   private:
100     /// Remap an input of the original signature with a range of types in the
101     /// new signature.
102     void remapInput(unsigned origInputNo, unsigned newInputNo,
103                     unsigned newInputCount = 1);
104 
105     /// The remapping information for each of the original arguments.
106     SmallVector<std::optional<InputMapping>, 4> remappedInputs;
107 
108     /// The set of new argument types.
109     SmallVector<Type, 4> argTypes;
110   };
111 
112   /// The general result of a type attribute conversion callback, allowing
113   /// for early termination. The default constructor creates the na case.
114   class AttributeConversionResult {
115   public:
116     constexpr AttributeConversionResult() : impl() {}
117     AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
118 
119     static AttributeConversionResult result(Attribute attr);
120     static AttributeConversionResult na();
121     static AttributeConversionResult abort();
122 
123     bool hasResult() const;
124     bool isNa() const;
125     bool isAbort() const;
126 
127     Attribute getResult() const;
128 
129   private:
130     AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
131 
132     llvm::PointerIntPair<Attribute, 2> impl;
133     // Note that na is 0 so that we can use PointerIntPair's default
134     // constructor.
135     static constexpr unsigned naTag = 0;
136     static constexpr unsigned resultTag = 1;
137     static constexpr unsigned abortTag = 2;
138   };
139 
140   /// Register a conversion function. A conversion function must be convertible
141   /// to any of the following forms (where `T` is a class derived from `Type`):
142   ///
143   ///   * std::optional<Type>(T)
144   ///     - This form represents a 1-1 type conversion. It should return nullptr
145   ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned,
146   ///       the converter is allowed to try another conversion function to
147   ///       perform the conversion.
148   ///   * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
149   ///     - This form represents a 1-N type conversion. It should return
150   ///       `failure` or `std::nullopt` to signify a failed conversion. If the
151   ///       new set of types is empty, the type is removed and any usages of the
152   ///       existing value are expected to be removed during conversion. If
153   ///       `std::nullopt` is returned, the converter is allowed to try another
154   ///       conversion function to perform the conversion.
155   ///
156   /// Note: When attempting to convert a type, e.g. via 'convertType', the
157   ///       mostly recently added conversions will be invoked first.
158   template <typename FnT, typename T = typename llvm::function_traits<
159                               std::decay_t<FnT>>::template arg_t<0>>
160   void addConversion(FnT &&callback) {
161     registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
162   }
163 
164   /// All of the following materializations require function objects that are
165   /// convertible to the following form:
166   ///   `Value(OpBuilder &, T, ValueRange, Location)`,
167   /// where `T` is any subclass of `Type`. This function is responsible for
168   /// creating an operation, using the OpBuilder and Location provided, that
169   /// "casts" a range of values into a single value of the given type `T`. It
170   /// must return a Value of the type `T` on success and `nullptr` if
171   /// it failed but other materialization should be attempted. Materialization
172   /// functions must be provided when a type conversion may persist after the
173   /// conversion has finished.
174   ///
175   /// Note: Target materializations may optionally accept an additional Type
176   /// parameter, which is the original type of the SSA value. Furthermore, `T`
177   /// can be a TypeRange; in that case, the function must return a
178   /// SmallVector<Value>.
179 
180   /// This method registers a materialization that will be called when
181   /// converting (potentially multiple) block arguments that were the result of
182   /// a signature conversion of a single block argument, to a single SSA value
183   /// with the old block argument type.
184   ///
185   /// Note: Argument materializations are used only with the 1:N dialect
186   /// conversion driver. The 1:N dialect conversion driver will be removed soon
187   /// and so will be argument materializations.
188   template <typename FnT, typename T = typename llvm::function_traits<
189                               std::decay_t<FnT>>::template arg_t<1>>
190   void addArgumentMaterialization(FnT &&callback) {
191     argumentMaterializations.emplace_back(
192         wrapMaterialization<T>(std::forward<FnT>(callback)));
193   }
194 
195   /// This method registers a materialization that will be called when
196   /// converting a replacement value back to its original source type.
197   /// This is used when some uses of the original value persist beyond the main
198   /// conversion.
199   template <typename FnT, typename T = typename llvm::function_traits<
200                               std::decay_t<FnT>>::template arg_t<1>>
201   void addSourceMaterialization(FnT &&callback) {
202     sourceMaterializations.emplace_back(
203         wrapMaterialization<T>(std::forward<FnT>(callback)));
204   }
205 
206   /// This method registers a materialization that will be called when
207   /// converting a value to a target type according to a pattern's type
208   /// converter.
209   ///
210   /// Note: Target materializations can optionally inspect the "original"
211   /// type. This type may be different from the type of the input value.
212   /// For example, let's assume that a conversion pattern "P1" replaced an SSA
213   /// value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion
214   /// pattern "P2" matches an op that has "v1" as an operand. Let's furthermore
215   /// assume that "P2" determines that the converted target type of "t1" is
216   /// "t3", which may be different from "t2". In this example, the target
217   /// materialization will be invoked with: outputType = "t3", inputs = "v2",
218   /// originalType = "t1". Note that the original type "t1" cannot be recovered
219   /// from just "t3" and "v2"; that's why the originalType parameter exists.
220   ///
221   /// Note: During a 1:N conversion, the result types can be a TypeRange. In
222   /// that case the materialization produces a SmallVector<Value>.
223   template <typename FnT, typename T = typename llvm::function_traits<
224                               std::decay_t<FnT>>::template arg_t<1>>
225   void addTargetMaterialization(FnT &&callback) {
226     targetMaterializations.emplace_back(
227         wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
228   }
229 
230   /// Register a conversion function for attributes within types. Type
231   /// converters may call this function in order to allow hoking into the
232   /// translation of attributes that exist within types. For example, a type
233   /// converter for the `memref` type could use these conversions to convert
234   /// memory spaces or layouts in an extensible way.
235   ///
236   /// The conversion functions take a non-null Type or subclass of Type and a
237   /// non-null Attribute (or subclass of Attribute), and returns a
238   /// `AttributeConversionResult`. This result can either contan an `Attribute`,
239   /// which may be `nullptr`, representing the conversion's success,
240   /// `AttributeConversionResult::na()` (the default empty value), indicating
241   /// that the conversion function did not apply and that further conversion
242   /// functions should be checked, or `AttributeConversionResult::abort()`
243   /// indicating that the conversion process should be aborted.
244   ///
245   /// Registered conversion functions are callled in the reverse of the order in
246   /// which they were registered.
247   template <
248       typename FnT,
249       typename T =
250           typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
251       typename A =
252           typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
253   void addTypeAttributeConversion(FnT &&callback) {
254     registerTypeAttributeConversion(
255         wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
256   }
257 
258   /// Convert the given type. This function should return failure if no valid
259   /// conversion exists, success otherwise. If the new set of types is empty,
260   /// the type is removed and any usages of the existing value are expected to
261   /// be removed during conversion.
262   LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
263 
264   /// This hook simplifies defining 1-1 type conversions. This function returns
265   /// the type to convert to on success, and a null type on failure.
266   Type convertType(Type t) const;
267 
268   /// Attempts a 1-1 type conversion, expecting the result type to be
269   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
270   /// and a null type on conversion or cast failure.
271   template <typename TargetType>
272   TargetType convertType(Type t) const {
273     return dyn_cast_or_null<TargetType>(convertType(t));
274   }
275 
276   /// Convert the given set of types, filling 'results' as necessary. This
277   /// returns failure if the conversion of any of the types fails, success
278   /// otherwise.
279   LogicalResult convertTypes(TypeRange types,
280                              SmallVectorImpl<Type> &results) const;
281 
282   /// Return true if the given type is legal for this type converter, i.e. the
283   /// type converts to itself.
284   bool isLegal(Type type) const;
285 
286   /// Return true if all of the given types are legal for this type converter.
287   template <typename RangeT>
288   std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
289                        !std::is_convertible<RangeT, Operation *>::value,
290                    bool>
291   isLegal(RangeT &&range) const {
292     return llvm::all_of(range, [this](Type type) { return isLegal(type); });
293   }
294   /// Return true if the given operation has legal operand and result types.
295   bool isLegal(Operation *op) const;
296 
297   /// Return true if the types of block arguments within the region are legal.
298   bool isLegal(Region *region) const;
299 
300   /// Return true if the inputs and outputs of the given function type are
301   /// legal.
302   bool isSignatureLegal(FunctionType ty) const;
303 
304   /// This method allows for converting a specific argument of a signature. It
305   /// takes as inputs the original argument input number, type.
306   /// On success, it populates 'result' with any new mappings.
307   LogicalResult convertSignatureArg(unsigned inputNo, Type type,
308                                     SignatureConversion &result) const;
309   LogicalResult convertSignatureArgs(TypeRange types,
310                                      SignatureConversion &result,
311                                      unsigned origInputOffset = 0) const;
312 
313   /// This function converts the type signature of the given block, by invoking
314   /// 'convertSignatureArg' for each argument. This function should return a
315   /// valid conversion for the signature on success, std::nullopt otherwise.
316   std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
317 
318   /// Materialize a conversion from a set of types into one result type by
319   /// generating a cast sequence of some kind. See the respective
320   /// `add*Materialization` for more information on the context for these
321   /// methods.
322   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
323                                       Type resultType, ValueRange inputs) const;
324   Value materializeSourceConversion(OpBuilder &builder, Location loc,
325                                     Type resultType, ValueRange inputs) const;
326   Value materializeTargetConversion(OpBuilder &builder, Location loc,
327                                     Type resultType, ValueRange inputs,
328                                     Type originalType = {}) const;
329   SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
330                                                  Location loc,
331                                                  TypeRange resultType,
332                                                  ValueRange inputs,
333                                                  Type originalType = {}) const;
334 
335   /// Convert an attribute present `attr` from within the type `type` using
336   /// the registered conversion functions. If no applicable conversion has been
337   /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
338   /// is a valid return value for this function.
339   std::optional<Attribute> convertTypeAttribute(Type type,
340                                                 Attribute attr) const;
341 
342 private:
343   /// The signature of the callback used to convert a type. If the new set of
344   /// types is empty, the type is removed and any usages of the existing value
345   /// are expected to be removed during conversion.
346   using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
347       Type, SmallVectorImpl<Type> &)>;
348 
349   /// The signature of the callback used to materialize a source/argument
350   /// conversion.
351   ///
352   /// Arguments: builder, result type, inputs, location
353   using MaterializationCallbackFn =
354       std::function<Value(OpBuilder &, Type, ValueRange, Location)>;
355 
356   /// The signature of the callback used to materialize a target conversion.
357   ///
358   /// Arguments: builder, result types, inputs, location, original type
359   using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
360       OpBuilder &, TypeRange, ValueRange, Location, Type)>;
361 
362   /// The signature of the callback used to convert a type attribute.
363   using TypeAttributeConversionCallbackFn =
364       std::function<AttributeConversionResult(Type, Attribute)>;
365 
366   /// Generate a wrapper for the given callback. This allows for accepting
367   /// different callback forms, that all compose into a single version.
368   /// With callback of form: `std::optional<Type>(T)`
369   template <typename T, typename FnT>
370   std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
371   wrapCallback(FnT &&callback) const {
372     return wrapCallback<T>([callback = std::forward<FnT>(callback)](
373                                T type, SmallVectorImpl<Type> &results) {
374       if (std::optional<Type> resultOpt = callback(type)) {
375         bool wasSuccess = static_cast<bool>(*resultOpt);
376         if (wasSuccess)
377           results.push_back(*resultOpt);
378         return std::optional<LogicalResult>(success(wasSuccess));
379       }
380       return std::optional<LogicalResult>();
381     });
382   }
383   /// With callback of form: `std::optional<LogicalResult>(
384   ///     T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
385   template <typename T, typename FnT>
386   std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
387                    ConversionCallbackFn>
388   wrapCallback(FnT &&callback) const {
389     return [callback = std::forward<FnT>(callback)](
390                Type type,
391                SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
392       T derivedType = dyn_cast<T>(type);
393       if (!derivedType)
394         return std::nullopt;
395       return callback(derivedType, results);
396     };
397   }
398 
399   /// Register a type conversion.
400   void registerConversion(ConversionCallbackFn callback) {
401     conversions.emplace_back(std::move(callback));
402     cachedDirectConversions.clear();
403     cachedMultiConversions.clear();
404   }
405 
406   /// Generate a wrapper for the given argument/source materialization
407   /// callback. The callback may take any subclass of `Type` and the
408   /// wrapper will check for the target type to be of the expected class
409   /// before calling the callback.
410   template <typename T, typename FnT>
411   MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
412     return [callback = std::forward<FnT>(callback)](
413                OpBuilder &builder, Type resultType, ValueRange inputs,
414                Location loc) -> Value {
415       if (T derivedType = dyn_cast<T>(resultType))
416         return callback(builder, derivedType, inputs, loc);
417       return Value();
418     };
419   }
420 
421   /// Generate a wrapper for the given target materialization callback.
422   /// The callback may take any subclass of `Type` and the wrapper will check
423   /// for the target type to be of the expected class before calling the
424   /// callback.
425   ///
426   /// With callback of form:
427   /// - Value(OpBuilder &, T, ValueRange, Location, Type)
428   /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
429   template <typename T, typename FnT>
430   std::enable_if_t<
431       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
432       TargetMaterializationCallbackFn>
433   wrapTargetMaterialization(FnT &&callback) const {
434     return [callback = std::forward<FnT>(callback)](
435                OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
436                Location loc, Type originalType) -> SmallVector<Value> {
437       SmallVector<Value> result;
438       if constexpr (std::is_same<T, TypeRange>::value) {
439         // This is a 1:N target materialization. Return the produces values
440         // directly.
441         result = callback(builder, resultTypes, inputs, loc, originalType);
442       } else if constexpr (std::is_assignable<Type, T>::value) {
443         // This is a 1:1 target materialization. Invoke the callback only if a
444         // single SSA value is requested.
445         if (resultTypes.size() == 1) {
446           // Invoke the callback only if the type class of the callback matches
447           // the requested result type.
448           if (T derivedType = dyn_cast<T>(resultTypes.front())) {
449             // 1:1 materializations produce single values, but we store 1:N
450             // target materialization functions in the type converter. Wrap the
451             // result value in a SmallVector<Value>.
452             Value val =
453                 callback(builder, derivedType, inputs, loc, originalType);
454             if (val)
455               result.push_back(val);
456           }
457         }
458       } else {
459         static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
460       }
461       return result;
462     };
463   }
464   /// With callback of form:
465   /// - Value(OpBuilder &, T, ValueRange, Location)
466   /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
467   template <typename T, typename FnT>
468   std::enable_if_t<
469       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
470       TargetMaterializationCallbackFn>
471   wrapTargetMaterialization(FnT &&callback) const {
472     return wrapTargetMaterialization<T>(
473         [callback = std::forward<FnT>(callback)](
474             OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
475             Type originalType) {
476           return callback(builder, resultTypes, inputs, loc);
477         });
478   }
479 
480   /// Generate a wrapper for the given memory space conversion callback. The
481   /// callback may take any subclass of `Attribute` and the wrapper will check
482   /// for the target attribute to be of the expected class before calling the
483   /// callback.
484   template <typename T, typename A, typename FnT>
485   TypeAttributeConversionCallbackFn
486   wrapTypeAttributeConversion(FnT &&callback) const {
487     return [callback = std::forward<FnT>(callback)](
488                Type type, Attribute attr) -> AttributeConversionResult {
489       if (T derivedType = dyn_cast<T>(type)) {
490         if (A derivedAttr = dyn_cast_or_null<A>(attr))
491           return callback(derivedType, derivedAttr);
492       }
493       return AttributeConversionResult::na();
494     };
495   }
496 
497   /// Register a memory space conversion, clearing caches.
498   void
499   registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
500     typeAttributeConversions.emplace_back(std::move(callback));
501     // Clear type conversions in case a memory space is lingering inside.
502     cachedDirectConversions.clear();
503     cachedMultiConversions.clear();
504   }
505 
506   /// The set of registered conversion functions.
507   SmallVector<ConversionCallbackFn, 4> conversions;
508 
509   /// The list of registered materialization functions.
510   SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
511   SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
512   SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
513 
514   /// The list of registered type attribute conversion functions.
515   SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
516 
517   /// A set of cached conversions to avoid recomputing in the common case.
518   /// Direct 1-1 conversions are the most common, so this cache stores the
519   /// successful 1-1 conversions as well as all failed conversions.
520   mutable DenseMap<Type, Type> cachedDirectConversions;
521   /// This cache stores the successful 1->N conversions, where N != 1.
522   mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
523   /// A mutex used for cache access
524   mutable llvm::sys::SmartRWMutex<true> cacheMutex;
525 };
526 
527 //===----------------------------------------------------------------------===//
528 // Conversion Patterns
529 //===----------------------------------------------------------------------===//
530 
531 /// Base class for the conversion patterns. This pattern class enables type
532 /// conversions, and other uses specific to the conversion framework. As such,
533 /// patterns of this type can only be used with the 'apply*' methods below.
534 class ConversionPattern : public RewritePattern {
535 public:
536   /// Hook for derived classes to implement rewriting. `op` is the (first)
537   /// operation matched by the pattern, `operands` is a list of the rewritten
538   /// operand values that are passed to `op`, `rewriter` can be used to emit the
539   /// new operations. This function should not fail. If some specific cases of
540   /// the operation are not supported, these cases should not be matched.
541   virtual void rewrite(Operation *op, ArrayRef<Value> operands,
542                        ConversionPatternRewriter &rewriter) const {
543     llvm_unreachable("unimplemented rewrite");
544   }
545   virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
546                        ConversionPatternRewriter &rewriter) const {
547     rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
548   }
549 
550   /// Hook for derived classes to implement combined matching and rewriting.
551   /// This overload supports only 1:1 replacements. The 1:N overload is called
552   /// by the driver. By default, it calls this 1:1 overload or reports a fatal
553   /// error if 1:N replacements were found.
554   virtual LogicalResult
555   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
556                   ConversionPatternRewriter &rewriter) const {
557     if (failed(match(op)))
558       return failure();
559     rewrite(op, operands, rewriter);
560     return success();
561   }
562 
563   /// Hook for derived classes to implement combined matching and rewriting.
564   /// This overload supports 1:N replacements.
565   virtual LogicalResult
566   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
567                   ConversionPatternRewriter &rewriter) const {
568     return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
569   }
570 
571   /// Attempt to match and rewrite the IR root at the specified operation.
572   LogicalResult matchAndRewrite(Operation *op,
573                                 PatternRewriter &rewriter) const final;
574 
575   /// Return the type converter held by this pattern, or nullptr if the pattern
576   /// does not require type conversion.
577   const TypeConverter *getTypeConverter() const { return typeConverter; }
578 
579   template <typename ConverterTy>
580   std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
581                    const ConverterTy *>
582   getTypeConverter() const {
583     return static_cast<const ConverterTy *>(typeConverter);
584   }
585 
586 protected:
587   /// See `RewritePattern::RewritePattern` for information on the other
588   /// available constructors.
589   using RewritePattern::RewritePattern;
590   /// Construct a conversion pattern with the given converter, and forward the
591   /// remaining arguments to RewritePattern.
592   template <typename... Args>
593   ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
594       : RewritePattern(std::forward<Args>(args)...),
595         typeConverter(&typeConverter) {}
596 
597   /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
598   /// try to extract the single value of each range to construct a the inputs
599   /// for a 1:1 adaptor.
600   ///
601   /// This function produces a fatal error if at least one range has 0 or
602   /// more than 1 value: "pattern 'name' does not support 1:N conversion"
603   SmallVector<Value>
604   getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
605 
606 protected:
607   /// An optional type converter for use by this pattern.
608   const TypeConverter *typeConverter = nullptr;
609 
610 private:
611   using RewritePattern::rewrite;
612 };
613 
614 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
615 /// matching and rewriting against an instance of a derived operation class as
616 /// opposed to a raw Operation.
617 template <typename SourceOp>
618 class OpConversionPattern : public ConversionPattern {
619 public:
620   using OpAdaptor = typename SourceOp::Adaptor;
621   using OneToNOpAdaptor =
622       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
623 
624   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
625       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
626   OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context,
627                       PatternBenefit benefit = 1)
628       : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
629                           context) {}
630 
631   /// Wrappers around the ConversionPattern methods that pass the derived op
632   /// type.
633   LogicalResult match(Operation *op) const final {
634     return match(cast<SourceOp>(op));
635   }
636   void rewrite(Operation *op, ArrayRef<Value> operands,
637                ConversionPatternRewriter &rewriter) const final {
638     auto sourceOp = cast<SourceOp>(op);
639     rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
640   }
641   void rewrite(Operation *op, ArrayRef<ValueRange> operands,
642                ConversionPatternRewriter &rewriter) const final {
643     auto sourceOp = cast<SourceOp>(op);
644     rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
645   }
646   LogicalResult
647   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
648                   ConversionPatternRewriter &rewriter) const final {
649     auto sourceOp = cast<SourceOp>(op);
650     return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
651   }
652   LogicalResult
653   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
654                   ConversionPatternRewriter &rewriter) const final {
655     auto sourceOp = cast<SourceOp>(op);
656     return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
657                            rewriter);
658   }
659 
660   /// Rewrite and Match methods that operate on the SourceOp type. These must be
661   /// overridden by the derived pattern class.
662   virtual LogicalResult match(SourceOp op) const {
663     llvm_unreachable("must override match or matchAndRewrite");
664   }
665   virtual void rewrite(SourceOp op, OpAdaptor adaptor,
666                        ConversionPatternRewriter &rewriter) const {
667     llvm_unreachable("must override matchAndRewrite or a rewrite method");
668   }
669   virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
670                        ConversionPatternRewriter &rewriter) const {
671     SmallVector<Value> oneToOneOperands =
672         getOneToOneAdaptorOperands(adaptor.getOperands());
673     rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
674   }
675   virtual LogicalResult
676   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
677                   ConversionPatternRewriter &rewriter) const {
678     if (failed(match(op)))
679       return failure();
680     rewrite(op, adaptor, rewriter);
681     return success();
682   }
683   virtual LogicalResult
684   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
685                   ConversionPatternRewriter &rewriter) const {
686     SmallVector<Value> oneToOneOperands =
687         getOneToOneAdaptorOperands(adaptor.getOperands());
688     return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
689   }
690 
691 private:
692   using ConversionPattern::matchAndRewrite;
693 };
694 
695 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
696 /// allows for matching and rewriting against an instance of an OpInterface
697 /// class as opposed to a raw Operation.
698 template <typename SourceOp>
699 class OpInterfaceConversionPattern : public ConversionPattern {
700 public:
701   OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
702       : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
703                           SourceOp::getInterfaceID(), benefit, context) {}
704   OpInterfaceConversionPattern(const TypeConverter &typeConverter,
705                                MLIRContext *context, PatternBenefit benefit = 1)
706       : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
707                           SourceOp::getInterfaceID(), benefit, context) {}
708 
709   /// Wrappers around the ConversionPattern methods that pass the derived op
710   /// type.
711   void rewrite(Operation *op, ArrayRef<Value> operands,
712                ConversionPatternRewriter &rewriter) const final {
713     rewrite(cast<SourceOp>(op), operands, rewriter);
714   }
715   void rewrite(Operation *op, ArrayRef<ValueRange> operands,
716                ConversionPatternRewriter &rewriter) const final {
717     rewrite(cast<SourceOp>(op), operands, rewriter);
718   }
719   LogicalResult
720   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
721                   ConversionPatternRewriter &rewriter) const final {
722     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
723   }
724   LogicalResult
725   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
726                   ConversionPatternRewriter &rewriter) const final {
727     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
728   }
729 
730   /// Rewrite and Match methods that operate on the SourceOp type. These must be
731   /// overridden by the derived pattern class.
732   virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
733                        ConversionPatternRewriter &rewriter) const {
734     llvm_unreachable("must override matchAndRewrite or a rewrite method");
735   }
736   virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
737                        ConversionPatternRewriter &rewriter) const {
738     rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
739   }
740   virtual LogicalResult
741   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
742                   ConversionPatternRewriter &rewriter) const {
743     if (failed(match(op)))
744       return failure();
745     rewrite(op, operands, rewriter);
746     return success();
747   }
748   virtual LogicalResult
749   matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
750                   ConversionPatternRewriter &rewriter) const {
751     return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
752   }
753 
754 private:
755   using ConversionPattern::matchAndRewrite;
756 };
757 
758 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
759 /// for matching and rewriting against instances of an operation that possess a
760 /// given trait.
761 template <template <typename> class TraitType>
762 class OpTraitConversionPattern : public ConversionPattern {
763 public:
764   OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
765       : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
766                           TypeID::get<TraitType>(), benefit, context) {}
767   OpTraitConversionPattern(const TypeConverter &typeConverter,
768                            MLIRContext *context, PatternBenefit benefit = 1)
769       : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
770                           TypeID::get<TraitType>(), benefit, context) {}
771 };
772 
/// Generic utility to convert op result types according to type converter
/// without knowing exact op type.
/// Clones existing op with new result types and returns it.
///
/// `op` is the operation to clone and `converter` supplies the result type
/// conversions. NOTE(review): `operands` presumably become the operands of the
/// clone, `rewriter` is presumably used to create it, and failure is
/// presumably returned when a result type cannot be converted -- confirm
/// against the implementation.
FailureOr<Operation *>
convertOpResultTypes(Operation *op, ValueRange operands,
                     const TypeConverter &converter,
                     ConversionPatternRewriter &rewriter);
780 
/// Add a pattern to the given pattern list to convert the signature of a
/// FunctionOpInterface op with the given type converter. This only supports
/// ops which use FunctionType to represent their type.
///
/// `functionLikeOpName` is the name of the operation the pattern applies to
/// (the templated overload passes `FuncOpT::getOperationName()` here).
void populateFunctionOpInterfaceTypeConversionPattern(
    StringRef functionLikeOpName, RewritePatternSet &patterns,
    const TypeConverter &converter);
787 
788 template <typename FuncOpT>
789 void populateFunctionOpInterfaceTypeConversionPattern(
790     RewritePatternSet &patterns, const TypeConverter &converter) {
791   populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
792                                                    patterns, converter);
793 }
794 
/// Counterpart of populateFunctionOpInterfaceTypeConversionPattern that takes
/// no operation name: adds a signature-conversion pattern using `converter` to
/// `patterns`. NOTE(review): presumably the added pattern applies to any op
/// implementing FunctionOpInterface (as the name suggests) -- confirm against
/// the implementation.
void populateAnyFunctionOpInterfaceTypeConversionPattern(
    RewritePatternSet &patterns, const TypeConverter &converter);
797 
798 //===----------------------------------------------------------------------===//
799 // Conversion PatternRewriter
800 //===----------------------------------------------------------------------===//
801 
802 namespace detail {
803 struct ConversionPatternRewriterImpl;
804 } // namespace detail
805 
/// This class implements a pattern rewriter for use with ConversionPatterns. It
/// extends the base PatternRewriter and provides special conversion specific
/// hooks.
class ConversionPatternRewriter final : public PatternRewriter {
public:
  // Defined out of line: this header only forward-declares
  // detail::ConversionPatternRewriterImpl, and destroying the owning
  // `std::unique_ptr` member requires the complete type.
  ~ConversionPatternRewriter() override;

  /// Apply a signature conversion to the given block. This replaces the block
  /// with a new block containing the updated signature. The operations of the
  /// given block are inlined into the newly-created block, which is returned.
  ///
  /// If no block argument types are changing, the original block will be
  /// left in place and returned.
  ///
  /// A signature conversion must be provided. (Type converters can construct
  /// a signature conversion with `convertBlockSignature`.)
  ///
  /// Optionally, a type converter can be provided to build materializations.
  /// Note: If no type converter was provided or the type converter does not
  /// specify any suitable argument/target materialization rules, the dialect
  /// conversion may fail to legalize unresolved materializations.
  Block *
  applySignatureConversion(Block *block,
                           TypeConverter::SignatureConversion &conversion,
                           const TypeConverter *converter = nullptr);

  /// Apply a signature conversion to each block in the given region. This
  /// replaces each block with a new block containing the updated signature. If
  /// an updated signature would match the current signature, the respective
  /// block is left in place as is. (See `applySignatureConversion` for
  /// details.) The new entry block of the region is returned.
  ///
  /// SignatureConversions are computed with the specified type converter.
  /// This function returns "failure" if the type converter failed to compute
  /// a SignatureConversion for at least one block.
  ///
  /// Optionally, a special SignatureConversion can be specified for the entry
  /// block. This is because the types of the entry block arguments are often
  /// tied semantically to the operation.
  FailureOr<Block *> convertRegionTypes(
      Region *region, const TypeConverter &converter,
      TypeConverter::SignatureConversion *entryConversion = nullptr);

  /// Replace all the uses of the block argument `from` with value `to`.
  void replaceUsesOfBlockArgument(BlockArgument from, Value to);

  /// Return the converted value of 'key' with a type defined by the type
  /// converter of the currently executing pattern. Return nullptr in the case
  /// of failure, the remapped value otherwise.
  Value getRemappedValue(Value key);

  /// Return the converted values that replace 'keys' with types defined by the
  /// type converter of the currently executing pattern. Returns failure if the
  /// remap failed, success otherwise.
  LogicalResult getRemappedValues(ValueRange keys,
                                  SmallVectorImpl<Value> &results);

  //===--------------------------------------------------------------------===//
  // PatternRewriter Hooks
  //===--------------------------------------------------------------------===//

  /// Indicate that the conversion rewriter can recover from rewrite failure.
  /// Recovery is supported via rollback, allowing for continued processing of
  /// patterns even if a failure is encountered during the rewrite step.
  bool canRecoverFromRewriteFailure() const override { return true; }

  /// Replace the given operation with the new values. The number of op results
  /// and replacement values must match. The types may differ: the dialect
  /// conversion driver will reconcile any surviving type mismatches at the end
  /// of the conversion process with source materializations. The given
  /// operation is erased.
  void replaceOp(Operation *op, ValueRange newValues) override;

  /// Replace the given operation with the results of the new op. The number of
  /// op results must match. The types may differ: the dialect conversion
  /// driver will reconcile any surviving type mismatches at the end of the
  /// conversion process with source materializations. The original operation
  /// is erased.
  void replaceOp(Operation *op, Operation *newOp) override;

  /// Replace the given operation with the new value ranges (one range per op
  /// result). The number of op results and value ranges must match. The given
  /// operation is erased.
  void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);

  /// PatternRewriter hook for erasing a dead operation. The uses of this
  /// operation *must* be made dead by the end of the conversion process,
  /// otherwise an assert will be issued.
  void eraseOp(Operation *op) override;

  /// PatternRewriter hook for erasing all operations in a block. This is not
  /// yet implemented for dialect conversion.
  void eraseBlock(Block *block) override;

  /// PatternRewriter hook for inlining the ops of a block into another block.
  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
                         ValueRange argValues = std::nullopt) override;
  // Keep the other inlineBlockBefore overloads of the base class visible.
  using PatternRewriter::inlineBlockBefore;

  /// PatternRewriter hook for updating the given operation in-place.
  /// Note: These methods only track updates to the given operation itself,
  /// and not nested regions. Updates to regions will still require notification
  /// through other more specific hooks above.
  void startOpModification(Operation *op) override;

  /// PatternRewriter hook for updating the given operation in-place.
  void finalizeOpModification(Operation *op) override;

  /// PatternRewriter hook for updating the given operation in-place.
  void cancelOpModification(Operation *op) override;

  /// Return a reference to the internal implementation.
  detail::ConversionPatternRewriterImpl &getImpl();

private:
  // Allow OperationConverter to construct new rewriters.
  friend struct OperationConverter;

  /// Conversion pattern rewriters must not be used outside of dialect
  /// conversions. They apply some IR rewrites in a delayed fashion and could
  /// bring the IR into an inconsistent state when used standalone.
  explicit ConversionPatternRewriter(MLIRContext *ctx,
                                     const ConversionConfig &config);

  // Hide unsupported pattern rewriter API.
  using OpBuilder::setListener;

  /// Owning pointer to the internal implementation state (see `getImpl`).
  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
934 
935 //===----------------------------------------------------------------------===//
936 // ConversionTarget
937 //===----------------------------------------------------------------------===//
938 
/// This class describes a specific conversion target: which operations and
/// dialects are legal, dynamically legal, or illegal after conversion.
class ConversionTarget {
public:
  /// This enumeration corresponds to the specific action to take when
  /// considering an operation legal for this conversion target.
  enum class LegalizationAction {
    /// The target supports this operation.
    Legal,

    /// This operation has dynamic legalization constraints that must be checked
    /// by the target.
    Dynamic,

    /// The target explicitly does not support this operation.
    Illegal,
  };

  /// A structure containing additional information describing a specific legal
  /// operation instance.
  struct LegalOpDetails {
    /// A flag that indicates if this operation is 'recursively' legal. This
    /// means that if an operation is legal, either statically or dynamically,
    /// all of the operations nested within are also considered legal.
    bool isRecursivelyLegal = false;
  };

  /// The signature of the callback used to determine if an operation is
  /// dynamically legal on the target. Returning std::nullopt leaves the
  /// legality of the operation undecided (see `isLegal`).
  using DynamicLegalityCallbackFn =
      std::function<std::optional<bool>(Operation *)>;

  /// Create a target for `ctx`. The templated helpers below resolve operation
  /// names (`OpT::getOperationName()`) in this context.
  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
  virtual ~ConversionTarget() = default;

  //===--------------------------------------------------------------------===//
  // Legality Registration
  //===--------------------------------------------------------------------===//

  /// Register a legality action for the given operation.
  void setOpAction(OperationName op, LegalizationAction action);
  template <typename OpT>
  void setOpAction(LegalizationAction action) {
    setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
  }

  /// Register the given operations as legal.
  void addLegalOp(OperationName op) {
    setOpAction(op, LegalizationAction::Legal);
  }
  template <typename OpT>
  void addLegalOp() {
    addLegalOp(OperationName(OpT::getOperationName(), &ctx));
  }
  template <typename OpT, typename OpT2, typename... OpTs>
  void addLegalOp() {
    addLegalOp<OpT>();
    addLegalOp<OpT2, OpTs...>();
  }

  /// Register the given operation as dynamically legal and set the dynamic
  /// legalization callback to the one provided.
  void addDynamicallyLegalOp(OperationName op,
                             const DynamicLegalityCallbackFn &callback) {
    setOpAction(op, LegalizationAction::Dynamic);
    setLegalityCallback(op, callback);
  }
  template <typename OpT>
  void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
    addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
                          callback);
  }
  template <typename OpT, typename OpT2, typename... OpTs>
  void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
    addDynamicallyLegalOp<OpT>(callback);
    addDynamicallyLegalOp<OpT2, OpTs...>(callback);
  }
  /// Overload for callbacks that are not invocable with `Operation *` (e.g.
  /// ones taking `OpT` directly). `callback` is captured by copy and invoked
  /// with the operation cast to `OpT`.
  template <typename OpT, class Callable>
  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
  addDynamicallyLegalOp(Callable &&callback) {
    addDynamicallyLegalOp<OpT>(
        [=](Operation *op) { return callback(cast<OpT>(op)); });
  }

  /// Register the given operation as illegal, i.e. this operation is known to
  /// not be supported by this target.
  void addIllegalOp(OperationName op) {
    setOpAction(op, LegalizationAction::Illegal);
  }
  template <typename OpT>
  void addIllegalOp() {
    addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
  }
  template <typename OpT, typename OpT2, typename... OpTs>
  void addIllegalOp() {
    addIllegalOp<OpT>();
    addIllegalOp<OpT2, OpTs...>();
  }

  /// Mark an operation, that *must* have either been set as `Legal` or
  /// `Dynamic`, as being recursively legal. This means that in addition to
  /// the operation itself, all of the operations nested within are also
  /// considered legal. An optional dynamic legality callback may be provided
  /// to mark subsets of legal instances as recursively legal.
  void markOpRecursivelyLegal(OperationName name,
                              const DynamicLegalityCallbackFn &callback);
  template <typename OpT>
  void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
    OperationName opName(OpT::getOperationName(), &ctx);
    markOpRecursivelyLegal(opName, callback);
  }
  template <typename OpT, typename OpT2, typename... OpTs>
  void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
    markOpRecursivelyLegal<OpT>(callback);
    markOpRecursivelyLegal<OpT2, OpTs...>(callback);
  }
  /// Overload for callbacks that are not invocable with `Operation *`; the
  /// operation is cast to `OpT` before `callback` (captured by copy) runs.
  template <typename OpT, class Callable>
  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
  markOpRecursivelyLegal(Callable &&callback) {
    markOpRecursivelyLegal<OpT>(
        [=](Operation *op) { return callback(cast<OpT>(op)); });
  }

  /// Register a legality action for the given dialects.
  void setDialectAction(ArrayRef<StringRef> dialectNames,
                        LegalizationAction action);

  /// Register the operations of the given dialects as legal.
  template <typename... Names>
  void addLegalDialect(StringRef name, Names... names) {
    SmallVector<StringRef, 2> dialectNames({name, names...});
    setDialectAction(dialectNames, LegalizationAction::Legal);
  }
  /// Same as above, with the dialects given as types whose namespaces come
  /// from `getDialectNamespace()`.
  template <typename... Args>
  void addLegalDialect() {
    SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
    setDialectAction(dialectNames, LegalizationAction::Legal);
  }

  /// Register the operations of the given dialects as dynamically legal, i.e.
  /// requiring custom handling by the callback.
  template <typename... Names>
  void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback,
                                  StringRef name, Names... names) {
    SmallVector<StringRef, 2> dialectNames({name, names...});
    setDialectAction(dialectNames, LegalizationAction::Dynamic);
    setLegalityCallback(dialectNames, callback);
  }
  /// Same as above, with the dialects given as types whose namespaces come
  /// from `getDialectNamespace()`.
  template <typename... Args>
  void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
    addDynamicallyLegalDialect(std::move(callback),
                               Args::getDialectNamespace()...);
  }

  /// Register unknown operations as dynamically legal. For operations (and
  /// dialects) that do not have a set legalization action, treat them as
  /// dynamically legal and invoke the given callback.
  void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
    setLegalityCallback(fn);
  }

  /// Register the operations of the given dialects as illegal, i.e.
  /// operations of these dialects are not supported by the target.
  template <typename... Names>
  void addIllegalDialect(StringRef name, Names... names) {
    SmallVector<StringRef, 2> dialectNames({name, names...});
    setDialectAction(dialectNames, LegalizationAction::Illegal);
  }
  /// Same as above, with the dialects given as types whose namespaces come
  /// from `getDialectNamespace()`.
  template <typename... Args>
  void addIllegalDialect() {
    SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
    setDialectAction(dialectNames, LegalizationAction::Illegal);
  }

  //===--------------------------------------------------------------------===//
  // Legality Querying
  //===--------------------------------------------------------------------===//

  /// Get the legality action for the given operation.
  std::optional<LegalizationAction> getOpAction(OperationName op) const;

  /// If the given operation instance is legal on this target, a structure
  /// containing legality information is returned. If the operation is not
  /// legal, std::nullopt is returned. Also returns std::nullopt if operation
  /// legality wasn't registered by user or dynamic legality callbacks returned
  /// std::nullopt.
  ///
  /// Note: Legality is actually a 4-state: Legal(recursive=true),
  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
  /// either as Legal or Illegal depending on context.
  std::optional<LegalOpDetails> isLegal(Operation *op) const;

  /// Returns true if the operation instance is illegal on this target. Returns
  /// false if the operation is legal, if its legality wasn't registered by the
  /// user, or if dynamic legality callbacks returned std::nullopt.
  bool isIllegal(Operation *op) const;

private:
  /// Set the dynamic legality callback for the given operation.
  void setLegalityCallback(OperationName name,
                           const DynamicLegalityCallbackFn &callback);

  /// Set the dynamic legality callback for the given dialects.
  void setLegalityCallback(ArrayRef<StringRef> dialects,
                           const DynamicLegalityCallbackFn &callback);

  /// Set the dynamic legality callback for the unknown ops.
  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);

  /// The set of information that configures the legalization of an operation.
  struct LegalizationInfo {
    /// The legality action this operation was given.
    LegalizationAction action = LegalizationAction::Illegal;

    /// Whether some legal instances of this operation may also be recursively
    /// legal.
    bool isRecursivelyLegal = false;

    /// The legality callback if this operation is dynamically legal.
    DynamicLegalityCallbackFn legalityFn;
  };

  /// Get the legalization information for the given operation.
  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;

  /// A deterministic mapping of operation name and its respective legality
  /// information.
  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;

  /// A set of legality callbacks for given operation names that are used to
  /// check if an operation instance is recursively legal.
  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;

  /// A deterministic mapping of dialect name to the specific legality action to
  /// take.
  llvm::StringMap<LegalizationAction> legalDialects;

  /// A set of dynamic legality callbacks for given dialect names.
  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;

  /// An optional legality callback for unknown operations.
  DynamicLegalityCallbackFn unknownLegalityFn;

  /// The current context this target applies to. Held by reference, so the
  /// context must outlive this target.
  MLIRContext &ctx;
};
1183 
1184 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1185 //===----------------------------------------------------------------------===//
1186 // PDL Configuration
1187 //===----------------------------------------------------------------------===//
1188 
/// A PDL configuration that is used to support dialect conversion
/// functionality.
class PDLConversionConfig final
    : public PDLPatternConfigBase<PDLConversionConfig> {
public:
  /// Create a configuration that uses `converter`, which may be nullptr.
  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
  ~PDLConversionConfig() final = default;

  /// Return the type converter used by this configuration, which may be nullptr
  /// if no type conversions are expected.
  const TypeConverter *getTypeConverter() const { return converter; }

  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
  /// pattern.
  void notifyRewriteBegin(PatternRewriter &rewriter) final;
  void notifyRewriteEnd(PatternRewriter &rewriter) final;

private:
  /// An optional type converter to use for the pattern. Not owned by this
  /// configuration (the destructor is defaulted and never deletes it).
  const TypeConverter *converter;
};
1210 
/// Register the dialect conversion PDL functions with the given pattern set.
/// When MLIR_ENABLE_PDL_IN_PATTERNMATCH is disabled, an inline no-op stub with
/// the same signature is declared instead.
void registerConversionPDLFunctions(RewritePatternSet &patterns);
1213 
1214 #else
1215 
1216 // Stubs for when PDL in rewriting is not enabled.
1217 
1218 inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
1219 
1220 class PDLConversionConfig final {
1221 public:
1222   PDLConversionConfig(const TypeConverter * /*converter*/) {}
1223 };
1224 
1225 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1226 
1227 //===----------------------------------------------------------------------===//
1228 // ConversionConfig
1229 //===----------------------------------------------------------------------===//
1230 
/// Dialect conversion configuration.
struct ConversionConfig {
  /// An optional callback used to notify about match failure diagnostics during
  /// the conversion. Diagnostics reported to this callback may only be
  /// available in debug mode.
  ///
  /// Note: `function_ref` does not own the callable it refers to; the callable
  /// must stay alive for as long as this configuration is in use.
  function_ref<void(Diagnostic &)> notifyCallback = nullptr;

  /// Partial conversion only. All operations that are found not to be
  /// legalizable are placed in this caller-provided set. (Note that if there
  /// is an op explicitly marked as illegal, the conversion terminates and the
  /// set will not necessarily be complete.)
  DenseSet<Operation *> *unlegalizedOps = nullptr;

  /// Analysis conversion only. All operations that are found to be legalizable
  /// are placed in this caller-provided set. Note that no actual rewrites are
  /// applied to the IR during an analysis conversion and only pre-existing
  /// operations are added to the set.
  DenseSet<Operation *> *legalizableOps = nullptr;

  /// An optional listener that is notified about all IR modifications in case
  /// dialect conversion succeeds. If the dialect conversion fails and no IR
  /// modifications are visible (i.e., they were all rolled back), or if the
  /// dialect conversion is an "analysis conversion", no notifications are
  /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
  ///
  /// Note: Notifications are sent in a delayed fashion, when the dialect
  /// conversion is guaranteed to succeed. At that point, some IR modifications
  /// may already have been materialized. Consequently, operations/blocks that
  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
  /// guaranteed to be valid pointers and accessing op names is allowed. But
  /// there are no guarantees about the state of ops/blocks at the time that a
  /// callback is triggered.)
  ///
  /// Example: Consider a dialect conversion in which a new op ("test.foo") is
  /// created and inserted, and later moved to another block. (Moving ops also
  /// triggers "notifyOperationInserted".)
  ///
  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
  ///
  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
  /// immediately performed by the dialect conversion (and rolled back upon
  /// failure).
  //
  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
  // callback, the previous region/block is provided to the callback, but not
  // the iterator pointing to the exact location within the region/block. That
  // is because these notifications are sent with a delay (after the IR has
  // already been modified) and iterators into past IR state cannot be
  // represented at the moment.
  RewriterBase::Listener *listener = nullptr;

  /// If set to "true", the dialect conversion attempts to build source/target
  /// materializations through the type converter API in lieu of
  /// "builtin.unrealized_conversion_cast" ops. The conversion process fails if
  /// at least one materialization could not be built.
  ///
  /// If set to "false", the dialect conversion does not build any custom
  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
  /// ops to ensure that the resulting IR is valid.
  bool buildMaterializations = true;
};
1294 
1295 //===----------------------------------------------------------------------===//
1296 // Reconcile Unrealized Casts
1297 //===----------------------------------------------------------------------===//
1298 
/// Try to reconcile all given UnrealizedConversionCastOps and store the
/// left-over ops in `remainingCastOps` (if provided).
///
/// This function processes cast ops in a worklist-driven fashion. For each
/// cast op, if the chain of input casts eventually reaches a cast op where the
/// input types match the output types of the matched op, replace the matched
/// op with the inputs.
///
/// Example:
/// %1 = unrealized_conversion_cast %0 : !A to !B
/// %2 = unrealized_conversion_cast %1 : !B to !C
/// %3 = unrealized_conversion_cast %2 : !C to !A
///
/// In the above example, %0 can be used instead of %3 and all cast ops are
/// folded away.
///
/// `castOps` are the casts to process. When `remainingCastOps` is non-null it
/// receives the left-over casts; pass nullptr (the default) to discard that
/// information.
void reconcileUnrealizedCasts(
    ArrayRef<UnrealizedConversionCastOp> castOps,
    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1317 
1318 //===----------------------------------------------------------------------===//
1319 // Op Conversion Entry Points
1320 //===----------------------------------------------------------------------===//
1321 
1322 /// Below we define several entry points for operation conversion. Note that
1323 /// the patterns provided to the conversion framework may be subject to
1324 /// additional constraints. See the `PatternRewriter Hooks` section of
1325 /// ConversionPatternRewriter for the additional constraints imposed on the
1326 /// use of the PatternRewriter.
1327 
1328 /// Apply a partial conversion on the given operations (or `op`) and all nested
1329 /// operations. This method converts as many operations to the target as
1330 /// possible, ignoring operations that failed to legalize, and only returns
1331 /// failure if operations explicitly marked as illegal could not be legalized.
1332 LogicalResult
1333 applyPartialConversion(ArrayRef<Operation *> ops,
1334                        const ConversionTarget &target,
1335                        const FrozenRewritePatternSet &patterns,
1336                        ConversionConfig config = ConversionConfig());
1337 LogicalResult
1338 applyPartialConversion(Operation *op, const ConversionTarget &target,
1339                        const FrozenRewritePatternSet &patterns,
1340                        ConversionConfig config = ConversionConfig());
1341 
1342 /// Apply a complete conversion on the given operations and all nested
1343 /// operations. This method returns failure if the conversion of any operation
1344 /// fails, or if there are unreachable blocks in any of the regions nested
1345 /// within `ops` (or `op`, for the single-operation overload).
1346 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1347                                   const ConversionTarget &target,
1348                                   const FrozenRewritePatternSet &patterns,
1349                                   ConversionConfig config = ConversionConfig());
1350 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1351                                   const FrozenRewritePatternSet &patterns,
1352                                   ConversionConfig config = ConversionConfig());
1353 
1354 /// Apply an analysis conversion on the given operations and all nested
1355 /// operations. This method analyzes which operations would be successfully
1356 /// converted to the target if a conversion were applied. All operations that
1357 /// were found to be legalizable to the given `target` are placed within the
1358 /// provided `config.legalizableOps` set; note that no actual rewrites are
1359 /// applied to the operations on success. This method only returns failure if
1360 /// there are unreachable blocks in any of the regions nested within `ops`/`op`.
1361 LogicalResult
1362 applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
1363                         const FrozenRewritePatternSet &patterns,
1364                         ConversionConfig config = ConversionConfig());
1365 LogicalResult
1366 applyAnalysisConversion(Operation *op, ConversionTarget &target,
1367                         const FrozenRewritePatternSet &patterns,
1368                         ConversionConfig config = ConversionConfig());
1369 } // namespace mlir
1370 
1371 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
1372