xref: /llvm-project/mlir/include/mlir/Transforms/OneToNTypeConversion.h (revision 4751f47c7af63315565891a1d112376b52e6b826)
1 //===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Note: The 1:N dialect conversion is deprecated and will be removed soon.
10 // 1:N support has been added to the regular dialect conversion driver.
11 //
12 // This file provides utils for implementing (poor-man's) dialect conversion
13 // passes with 1:N type conversions.
14 //
// The main function, `applyPartialOneToNConversion`, first applies a set of
// `RewritePattern`s, which produce unrealized casts to convert the operands and
// results from and to the source types, and then replaces all newly added
// unrealized casts by user-provided materializations. For this to work, the
// main function requires a special `TypeConverter`, a special
// `PatternRewriter`, and special `RewritePattern`s, which extend their
// respective base classes for 1:N type conversions.
22 //
// Note that this is much more simple-minded than the "real" dialect conversion,
// which checks for legality before applying patterns and probably does many
// other things as well. Ideally, some of the extensions here could be
// integrated there.
27 //
28 //===----------------------------------------------------------------------===//
29 
30 #ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
31 #define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
32 
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Transforms/DialectConversion.h"
35 #include "llvm/ADT/SmallVector.h"
36 
37 namespace mlir {
38 
39 /// Stores a 1:N mapping of types and provides several useful accessors. This
40 /// class extends `SignatureConversion`, which already supports 1:N type
41 /// mappings but lacks some accessors into the mapping as well as access to the
42 /// original types.
43 class OneToNTypeMapping : public TypeConverter::SignatureConversion {
44 public:
45   OneToNTypeMapping(TypeRange originalTypes)
46       : TypeConverter::SignatureConversion(originalTypes.size()),
47         originalTypes(originalTypes) {}
48 
49   using TypeConverter::SignatureConversion::getConvertedTypes;
50 
51   /// Returns the list of types that corresponds to the original type at the
52   /// given index.
53   TypeRange getConvertedTypes(unsigned originalTypeNo) const;
54 
55   /// Returns the list of original types.
56   TypeRange getOriginalTypes() const { return originalTypes; }
57 
58   /// Returns the slice of converted values that corresponds the original value
59   /// at the given index.
60   ValueRange getConvertedValues(ValueRange convertedValues,
61                                 unsigned originalValueNo) const;
62 
63   /// Fills the given result vector with as many copies of the location of the
64   /// original value as the number of values it is converted to.
65   void convertLocation(Value originalValue, unsigned originalValueNo,
66                        llvm::SmallVectorImpl<Location> &result) const;
67 
68   /// Fills the given result vector with as many copies of the lociation of each
69   /// original value as the number of values they are respectively converted to.
70   void convertLocations(ValueRange originalValues,
71                         llvm::SmallVectorImpl<Location> &result) const;
72 
73   /// Returns true iff at least one type conversion maps an input type to a type
74   /// that is different from itself.
75   bool hasNonIdentityConversion() const;
76 
77 private:
78   llvm::SmallVector<Type> originalTypes;
79 };
80 
81 /// Extends the basic `RewritePattern` class with a type converter member and
82 /// some accessors to it. This is useful for patterns that are not
83 /// `ConversionPattern`s but still require access to a type converter.
84 class RewritePatternWithConverter : public mlir::RewritePattern {
85 public:
86   /// Construct a conversion pattern with the given converter, and forward the
87   /// remaining arguments to RewritePattern.
88   template <typename... Args>
89   RewritePatternWithConverter(const TypeConverter &typeConverter,
90                               Args &&...args)
91       : RewritePattern(std::forward<Args>(args)...),
92         typeConverter(&typeConverter) {}
93 
94   /// Return the type converter held by this pattern, or nullptr if the pattern
95   /// does not require type conversion.
96   const TypeConverter *getTypeConverter() const { return typeConverter; }
97 
98   template <typename ConverterTy>
99   std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
100                    const ConverterTy *>
101   getTypeConverter() const {
102     return static_cast<const ConverterTy *>(typeConverter);
103   }
104 
105 protected:
106   /// A type converter for use by this pattern.
107   const TypeConverter *const typeConverter;
108 };
109 
110 /// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
111 /// class provides additional rewrite methods that are specific to 1:N type
112 /// conversions.
113 class OneToNPatternRewriter : public PatternRewriter {
114 public:
115   OneToNPatternRewriter(MLIRContext *context,
116                         OpBuilder::Listener *listener = nullptr)
117       : PatternRewriter(context, listener) {}
118 
119   /// Replaces the results of the operation with the specified list of values
120   /// mapped back to the original types as specified in the provided type
121   /// mapping. That type mapping must match the replaced op (i.e., the original
122   /// types must be the same as the result types of the op) and the new values
123   /// (i.e., the converted types must be the same as the types of the new
124   /// values).
125   /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
126   /// Use replaceOpWithMultiple() instead.
127   void replaceOp(Operation *op, ValueRange newValues,
128                  const OneToNTypeMapping &resultMapping);
129   using PatternRewriter::replaceOp;
130 
131   /// Applies the given argument conversion to the given block. This consists of
132   /// replacing each original argument with N arguments as specified in the
133   /// argument conversion and inserting unrealized casts from the converted
134   /// values to the original types, which are then used in lieu of the original
135   /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
136   /// with a user-provided argument materialization if necessary.) This is
137   /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
138   /// type conversion properly and probably (2) doesn't handle many other edge
139   /// cases.
140   Block *applySignatureConversion(Block *block,
141                                   OneToNTypeMapping &argumentConversion);
142 };
143 
144 /// Base class for patterns with 1:N type conversions. Derived classes have to
145 /// overwrite the `matchAndRewrite` overlaod that provides additional
146 /// information for 1:N type conversions.
147 class OneToNConversionPattern : public RewritePatternWithConverter {
148 public:
149   using RewritePatternWithConverter::RewritePatternWithConverter;
150 
151   /// This function has to be implemented by derived classes and is called from
152   /// the usual overloads. Like in "normal" `DialectConversion`, the function is
153   /// provided with the converted operands (which thus have target types). Since
154   /// 1:N conversions are supported, there is usually no 1:1 relationship
155   /// between the original and the converted operands. Instead, the provided
156   /// `operandMapping` can be used to access the converted operands that
157   /// correspond to a particular original operand. Similarly, `resultMapping`
158   /// is provided to help with assembling the result values, which may have 1:N
159   /// correspondences as well. In that case, the original op should be replaced
160   /// with the overload of `replaceOp` that takes the provided `resultMapping`
161   /// in order to deal with the mapping of converted result values to their
162   /// usages in the original types correctly.
163   virtual LogicalResult matchAndRewrite(Operation *op,
164                                         OneToNPatternRewriter &rewriter,
165                                         const OneToNTypeMapping &operandMapping,
166                                         const OneToNTypeMapping &resultMapping,
167                                         ValueRange convertedOperands) const = 0;
168 
169   LogicalResult matchAndRewrite(Operation *op,
170                                 PatternRewriter &rewriter) const final;
171 };
172 
173 /// This class is a wrapper around `OneToNConversionPattern` for matching
174 /// against instances of a particular op class.
175 template <typename SourceOp>
176 class OneToNOpConversionPattern : public OneToNConversionPattern {
177 public:
178   OneToNOpConversionPattern(const TypeConverter &typeConverter,
179                             MLIRContext *context, PatternBenefit benefit = 1,
180                             ArrayRef<StringRef> generatedNames = {})
181       : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
182                                 benefit, context, generatedNames) {}
183   /// Generic adaptor around the root op of this pattern using the converted
184   /// operands. Importantly, each operand is represented as a *range* of values,
185   /// namely the N values each original operand gets converted to. Concretely,
186   /// this makes the result type of the accessor functions of the adaptor class
187   /// be a `ValueRange`.
188   class OpAdaptor
189       : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
190   public:
191     using RangeT = ArrayRef<ValueRange>;
192     using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
193     using Properties = typename SourceOp::template InferredProperties<SourceOp>;
194 
195     OpAdaptor(const OneToNTypeMapping *operandMapping,
196               const OneToNTypeMapping *resultMapping,
197               const ValueRange *convertedOperands, RangeT values, SourceOp op)
198         : BaseT(values, op), operandMapping(operandMapping),
199           resultMapping(resultMapping), convertedOperands(convertedOperands) {}
200 
201     /// Get the type mapping of the original operands to the converted operands.
202     const OneToNTypeMapping &getOperandMapping() const {
203       return *operandMapping;
204     }
205 
206     /// Get the type mapping of the original results to the converted results.
207     const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }
208 
209     /// Get a flat range of all converted operands. Unlike `getOperands`, which
210     /// returns an `ArrayRef` with one `ValueRange` for each original operand,
211     /// this function returns a `ValueRange` that contains all converted
212     /// operands irrespectively of which operand they originated from.
213     ValueRange getFlatOperands() const { return *convertedOperands; }
214 
215   private:
216     const OneToNTypeMapping *operandMapping;
217     const OneToNTypeMapping *resultMapping;
218     const ValueRange *convertedOperands;
219   };
220 
221   using OneToNConversionPattern::matchAndRewrite;
222 
223   /// Overload that derived classes have to override for their op type.
224   virtual LogicalResult
225   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
226                   OneToNPatternRewriter &rewriter) const = 0;
227 
228   LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
229                                 const OneToNTypeMapping &operandMapping,
230                                 const OneToNTypeMapping &resultMapping,
231                                 ValueRange convertedOperands) const final {
232     // Wrap converted operands and type mappings into an adaptor.
233     SmallVector<ValueRange> valueRanges;
234     for (int64_t i = 0; i < op->getNumOperands(); i++) {
235       auto values = operandMapping.getConvertedValues(convertedOperands, i);
236       valueRanges.push_back(values);
237     }
238     OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
239                       valueRanges, cast<SourceOp>(op));
240 
241     // Call overload implemented by the derived class.
242     return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
243   }
244 };
245 
246 /// Applies the given set of patterns recursively on the given op and adds user
247 /// materializations where necessary. The patterns are expected to be
248 /// `OneToNConversionPattern`, which help converting the types of the operands
249 /// and results of the matched ops. The provided type converter is used to
250 /// convert the operands of matched ops from their original types to operands
251 /// with different types. Unlike in `DialectConversion`, this supports 1:N type
252 /// conversions. Those conversions at the "boundary" of the pattern application,
253 /// where converted results are not consumed by replaced ops that expect the
254 /// converted operands or vice versa, the function inserts user materializations
255 /// from the type converter. Also unlike `DialectConversion`, there are no legal
256 /// or illegal types; the function simply applies the given patterns and does
257 /// not fail if some ops or types remain unconverted (i.e., the conversion is
258 /// only "partial").
259 /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
260 /// 1:N support has been added to the regular dialect conversion driver.
261 /// Use applyPartialConversion() instead.
262 LogicalResult
263 applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
264                              const FrozenRewritePatternSet &patterns);
265 
266 /// Add a pattern to the given pattern list to convert the signature of a
267 /// FunctionOpInterface op with the given type converter. This only supports
268 /// ops which use FunctionType to represent their type. This is intended to be
269 /// used with the 1:N dialect conversion.
270 /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
271 /// 1:N support has been added to the regular dialect conversion driver.
272 /// Use populateFunctionOpInterfaceTypeConversionPattern() instead.
273 void populateOneToNFunctionOpInterfaceTypeConversionPattern(
274     StringRef functionLikeOpName, const TypeConverter &converter,
275     RewritePatternSet &patterns);
276 template <typename FuncOpT>
277 void populateOneToNFunctionOpInterfaceTypeConversionPattern(
278     const TypeConverter &converter, RewritePatternSet &patterns) {
279   populateOneToNFunctionOpInterfaceTypeConversionPattern(
280       FuncOpT::getOperationName(), converter, patterns);
281 }
282 
283 } // namespace mlir
284 
285 #endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
286