1 //===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===// 2 // 3 // Licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Note: The 1:N dialect conversion is deprecated and will be removed soon. 10 // 1:N support has been added to the regular dialect conversion driver. 11 // 12 // This file provides utils for implementing (poor-man's) dialect conversion 13 // passes with 1:N type conversions. 14 // 15 // The main function, `applyPartialOneToNConversion`, first applies a set of 16 // `RewritePattern`s, which produce unrealized casts to convert the operands and 17 // results from and to the source types, and then replaces all newly added 18 // unrealized casts by user-provided materializations. For this to work, the 19 // main function requires a special `TypeConverter`, a special 20 // `PatternRewriter`, and special RewritePattern`s, which extend their 21 // respective base classes for 1:N type converions. 22 // 23 // Note that this is much more simple-minded than the "real" dialect conversion, 24 // which checks for legality before applying patterns and does probably many 25 // other additional things. Ideally, some of the extensions here could be 26 // integrated there. 27 // 28 //===----------------------------------------------------------------------===// 29 30 #ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H 31 #define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H 32 33 #include "mlir/IR/PatternMatch.h" 34 #include "mlir/Transforms/DialectConversion.h" 35 #include "llvm/ADT/SmallVector.h" 36 37 namespace mlir { 38 39 /// Stores a 1:N mapping of types and provides several useful accessors. This 40 /// class extends `SignatureConversion`, which already supports 1:N type 41 /// mappings but lacks some accessors into the mapping as well as access to the 42 /// original types. 43 class OneToNTypeMapping : public TypeConverter::SignatureConversion { 44 public: 45 OneToNTypeMapping(TypeRange originalTypes) 46 : TypeConverter::SignatureConversion(originalTypes.size()), 47 originalTypes(originalTypes) {} 48 49 using TypeConverter::SignatureConversion::getConvertedTypes; 50 51 /// Returns the list of types that corresponds to the original type at the 52 /// given index. 53 TypeRange getConvertedTypes(unsigned originalTypeNo) const; 54 55 /// Returns the list of original types. 56 TypeRange getOriginalTypes() const { return originalTypes; } 57 58 /// Returns the slice of converted values that corresponds the original value 59 /// at the given index. 60 ValueRange getConvertedValues(ValueRange convertedValues, 61 unsigned originalValueNo) const; 62 63 /// Fills the given result vector with as many copies of the location of the 64 /// original value as the number of values it is converted to. 65 void convertLocation(Value originalValue, unsigned originalValueNo, 66 llvm::SmallVectorImpl<Location> &result) const; 67 68 /// Fills the given result vector with as many copies of the lociation of each 69 /// original value as the number of values they are respectively converted to. 70 void convertLocations(ValueRange originalValues, 71 llvm::SmallVectorImpl<Location> &result) const; 72 73 /// Returns true iff at least one type conversion maps an input type to a type 74 /// that is different from itself. 75 bool hasNonIdentityConversion() const; 76 77 private: 78 llvm::SmallVector<Type> originalTypes; 79 }; 80 81 /// Extends the basic `RewritePattern` class with a type converter member and 82 /// some accessors to it. This is useful for patterns that are not 83 /// `ConversionPattern`s but still require access to a type converter. 84 class RewritePatternWithConverter : public mlir::RewritePattern { 85 public: 86 /// Construct a conversion pattern with the given converter, and forward the 87 /// remaining arguments to RewritePattern. 88 template <typename... Args> 89 RewritePatternWithConverter(const TypeConverter &typeConverter, 90 Args &&...args) 91 : RewritePattern(std::forward<Args>(args)...), 92 typeConverter(&typeConverter) {} 93 94 /// Return the type converter held by this pattern, or nullptr if the pattern 95 /// does not require type conversion. 96 const TypeConverter *getTypeConverter() const { return typeConverter; } 97 98 template <typename ConverterTy> 99 std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value, 100 const ConverterTy *> 101 getTypeConverter() const { 102 return static_cast<const ConverterTy *>(typeConverter); 103 } 104 105 protected: 106 /// A type converter for use by this pattern. 107 const TypeConverter *const typeConverter; 108 }; 109 110 /// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The 111 /// class provides additional rewrite methods that are specific to 1:N type 112 /// conversions. 113 class OneToNPatternRewriter : public PatternRewriter { 114 public: 115 OneToNPatternRewriter(MLIRContext *context, 116 OpBuilder::Listener *listener = nullptr) 117 : PatternRewriter(context, listener) {} 118 119 /// Replaces the results of the operation with the specified list of values 120 /// mapped back to the original types as specified in the provided type 121 /// mapping. That type mapping must match the replaced op (i.e., the original 122 /// types must be the same as the result types of the op) and the new values 123 /// (i.e., the converted types must be the same as the types of the new 124 /// values). 125 /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. 126 /// Use replaceOpWithMultiple() instead. 127 void replaceOp(Operation *op, ValueRange newValues, 128 const OneToNTypeMapping &resultMapping); 129 using PatternRewriter::replaceOp; 130 131 /// Applies the given argument conversion to the given block. This consists of 132 /// replacing each original argument with N arguments as specified in the 133 /// argument conversion and inserting unrealized casts from the converted 134 /// values to the original types, which are then used in lieu of the original 135 /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts 136 /// with a user-provided argument materialization if necessary.) This is 137 /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N 138 /// type conversion properly and probably (2) doesn't handle many other edge 139 /// cases. 140 Block *applySignatureConversion(Block *block, 141 OneToNTypeMapping &argumentConversion); 142 }; 143 144 /// Base class for patterns with 1:N type conversions. Derived classes have to 145 /// overwrite the `matchAndRewrite` overlaod that provides additional 146 /// information for 1:N type conversions. 147 class OneToNConversionPattern : public RewritePatternWithConverter { 148 public: 149 using RewritePatternWithConverter::RewritePatternWithConverter; 150 151 /// This function has to be implemented by derived classes and is called from 152 /// the usual overloads. Like in "normal" `DialectConversion`, the function is 153 /// provided with the converted operands (which thus have target types). Since 154 /// 1:N conversions are supported, there is usually no 1:1 relationship 155 /// between the original and the converted operands. Instead, the provided 156 /// `operandMapping` can be used to access the converted operands that 157 /// correspond to a particular original operand. Similarly, `resultMapping` 158 /// is provided to help with assembling the result values, which may have 1:N 159 /// correspondences as well. In that case, the original op should be replaced 160 /// with the overload of `replaceOp` that takes the provided `resultMapping` 161 /// in order to deal with the mapping of converted result values to their 162 /// usages in the original types correctly. 163 virtual LogicalResult matchAndRewrite(Operation *op, 164 OneToNPatternRewriter &rewriter, 165 const OneToNTypeMapping &operandMapping, 166 const OneToNTypeMapping &resultMapping, 167 ValueRange convertedOperands) const = 0; 168 169 LogicalResult matchAndRewrite(Operation *op, 170 PatternRewriter &rewriter) const final; 171 }; 172 173 /// This class is a wrapper around `OneToNConversionPattern` for matching 174 /// against instances of a particular op class. 175 template <typename SourceOp> 176 class OneToNOpConversionPattern : public OneToNConversionPattern { 177 public: 178 OneToNOpConversionPattern(const TypeConverter &typeConverter, 179 MLIRContext *context, PatternBenefit benefit = 1, 180 ArrayRef<StringRef> generatedNames = {}) 181 : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(), 182 benefit, context, generatedNames) {} 183 /// Generic adaptor around the root op of this pattern using the converted 184 /// operands. Importantly, each operand is represented as a *range* of values, 185 /// namely the N values each original operand gets converted to. Concretely, 186 /// this makes the result type of the accessor functions of the adaptor class 187 /// be a `ValueRange`. 188 class OpAdaptor 189 : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> { 190 public: 191 using RangeT = ArrayRef<ValueRange>; 192 using BaseT = typename SourceOp::template GenericAdaptor<RangeT>; 193 using Properties = typename SourceOp::template InferredProperties<SourceOp>; 194 195 OpAdaptor(const OneToNTypeMapping *operandMapping, 196 const OneToNTypeMapping *resultMapping, 197 const ValueRange *convertedOperands, RangeT values, SourceOp op) 198 : BaseT(values, op), operandMapping(operandMapping), 199 resultMapping(resultMapping), convertedOperands(convertedOperands) {} 200 201 /// Get the type mapping of the original operands to the converted operands. 202 const OneToNTypeMapping &getOperandMapping() const { 203 return *operandMapping; 204 } 205 206 /// Get the type mapping of the original results to the converted results. 207 const OneToNTypeMapping &getResultMapping() const { return *resultMapping; } 208 209 /// Get a flat range of all converted operands. Unlike `getOperands`, which 210 /// returns an `ArrayRef` with one `ValueRange` for each original operand, 211 /// this function returns a `ValueRange` that contains all converted 212 /// operands irrespectively of which operand they originated from. 213 ValueRange getFlatOperands() const { return *convertedOperands; } 214 215 private: 216 const OneToNTypeMapping *operandMapping; 217 const OneToNTypeMapping *resultMapping; 218 const ValueRange *convertedOperands; 219 }; 220 221 using OneToNConversionPattern::matchAndRewrite; 222 223 /// Overload that derived classes have to override for their op type. 224 virtual LogicalResult 225 matchAndRewrite(SourceOp op, OpAdaptor adaptor, 226 OneToNPatternRewriter &rewriter) const = 0; 227 228 LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, 229 const OneToNTypeMapping &operandMapping, 230 const OneToNTypeMapping &resultMapping, 231 ValueRange convertedOperands) const final { 232 // Wrap converted operands and type mappings into an adaptor. 233 SmallVector<ValueRange> valueRanges; 234 for (int64_t i = 0; i < op->getNumOperands(); i++) { 235 auto values = operandMapping.getConvertedValues(convertedOperands, i); 236 valueRanges.push_back(values); 237 } 238 OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands, 239 valueRanges, cast<SourceOp>(op)); 240 241 // Call overload implemented by the derived class. 242 return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter); 243 } 244 }; 245 246 /// Applies the given set of patterns recursively on the given op and adds user 247 /// materializations where necessary. The patterns are expected to be 248 /// `OneToNConversionPattern`, which help converting the types of the operands 249 /// and results of the matched ops. The provided type converter is used to 250 /// convert the operands of matched ops from their original types to operands 251 /// with different types. Unlike in `DialectConversion`, this supports 1:N type 252 /// conversions. Those conversions at the "boundary" of the pattern application, 253 /// where converted results are not consumed by replaced ops that expect the 254 /// converted operands or vice versa, the function inserts user materializations 255 /// from the type converter. Also unlike `DialectConversion`, there are no legal 256 /// or illegal types; the function simply applies the given patterns and does 257 /// not fail if some ops or types remain unconverted (i.e., the conversion is 258 /// only "partial"). 259 /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. 260 /// 1:N support has been added to the regular dialect conversion driver. 261 /// Use applyPartialConversion() instead. 262 LogicalResult 263 applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, 264 const FrozenRewritePatternSet &patterns); 265 266 /// Add a pattern to the given pattern list to convert the signature of a 267 /// FunctionOpInterface op with the given type converter. This only supports 268 /// ops which use FunctionType to represent their type. This is intended to be 269 /// used with the 1:N dialect conversion. 270 /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. 271 /// 1:N support has been added to the regular dialect conversion driver. 272 /// Use populateFunctionOpInterfaceTypeConversionPattern() instead. 273 void populateOneToNFunctionOpInterfaceTypeConversionPattern( 274 StringRef functionLikeOpName, const TypeConverter &converter, 275 RewritePatternSet &patterns); 276 template <typename FuncOpT> 277 void populateOneToNFunctionOpInterfaceTypeConversionPattern( 278 const TypeConverter &converter, RewritePatternSet &patterns) { 279 populateOneToNFunctionOpInterfaceTypeConversionPattern( 280 FuncOpT::getOperationName(), converter, patterns); 281 } 282 283 } // namespace mlir 284 285 #endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H 286