1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H 10 #define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H 11 12 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 namespace mlir { 18 class CallOpInterface; 19 20 namespace LLVM { 21 namespace detail { 22 /// Handle generically setting flags as native properties on LLVM operations. 23 void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags); 24 25 /// Replaces the given operation "op" with a new operation of type "targetOp" 26 /// and given operands. 27 LogicalResult oneToOneRewrite( 28 Operation *op, StringRef targetOp, ValueRange operands, 29 ArrayRef<NamedAttribute> targetAttrs, 30 const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, 31 IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); 32 33 } // namespace detail 34 } // namespace LLVM 35 36 /// Base class for operation conversions targeting the LLVM IR dialect. It 37 /// provides the conversion patterns with access to the LLVMTypeConverter and 38 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the 39 /// LowerToLLVMOptions by reference meaning the references have to remain alive 40 /// during the entire pattern lifetime. 41 class ConvertToLLVMPattern : public ConversionPattern { 42 public: 43 ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, 44 const LLVMTypeConverter &typeConverter, 45 PatternBenefit benefit = 1); 46 47 protected: 48 /// Returns the LLVM dialect. 49 LLVM::LLVMDialect &getDialect() const; 50 51 const LLVMTypeConverter *getTypeConverter() const; 52 53 /// Gets the MLIR type wrapping the LLVM integer type whose bit width is 54 /// defined by the used type converter. 55 Type getIndexType() const; 56 57 /// Gets the MLIR type wrapping the LLVM integer type whose bit width 58 /// corresponds to that of a LLVM pointer type. 59 Type getIntPtrType(unsigned addressSpace = 0) const; 60 61 /// Gets the MLIR type wrapping the LLVM void type. 62 Type getVoidType() const; 63 64 /// Get the MLIR type wrapping the LLVM i8* type. 65 Type getVoidPtrType() const; 66 67 /// Create a constant Op producing a value of `resultType` from an index-typed 68 /// integer attribute. 69 static Value createIndexAttrConstant(OpBuilder &builder, Location loc, 70 Type resultType, int64_t value); 71 72 // This is a strided getElementPtr variant that linearizes subscripts as: 73 // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. 74 Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, 75 ValueRange indices, 76 ConversionPatternRewriter &rewriter) const; 77 78 /// Returns if the given memref has identity maps and the element type is 79 /// convertible to LLVM. 80 bool isConvertibleAndHasIdentityMaps(MemRefType type) const; 81 82 /// Returns the type of a pointer to an element of the memref. 83 Type getElementPtrType(MemRefType type) const; 84 85 /// Computes sizes, strides and buffer size of `memRefType` with identity 86 /// layout. Emits constant ops for the static sizes of `memRefType`, and uses 87 /// `dynamicSizes` for the others. Emits instructions to compute strides and 88 /// buffer size from these sizes. 89 /// 90 /// For example, memref<4x?xf32> with `sizeInBytes = true` emits: 91 /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64 92 /// `sizes[1]` = `dynamicSizes[0]` 93 /// `strides[1]` = llvm.mlir.constant(1 : index) : i64 94 /// `strides[0]` = `sizes[0]` 95 /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64 96 /// %nullptr = llvm.mlir.zero : !llvm.ptr 97 /// %gep = llvm.getelementptr %nullptr[%size] 98 /// : (!llvm.ptr, i64) -> !llvm.ptr, f32 99 /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr to i64 100 /// 101 /// If `sizeInBytes = false`, memref<4x?xf32> emits: 102 /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64 103 /// `sizes[1]` = `dynamicSizes[0]` 104 /// `strides[1]` = llvm.mlir.constant(1 : index) : i64 105 /// `strides[0]` = `sizes[0]` 106 /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64 107 void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, 108 ValueRange dynamicSizes, 109 ConversionPatternRewriter &rewriter, 110 SmallVectorImpl<Value> &sizes, 111 SmallVectorImpl<Value> &strides, Value &size, 112 bool sizeInBytes = true) const; 113 114 /// Computes the size of type in bytes. 115 Value getSizeInBytes(Location loc, Type type, 116 ConversionPatternRewriter &rewriter) const; 117 118 /// Computes total number of elements for the given MemRef and dynamicSizes. 119 Value getNumElements(Location loc, MemRefType memRefType, 120 ValueRange dynamicSizes, 121 ConversionPatternRewriter &rewriter) const; 122 123 /// Creates and populates a canonical memref descriptor struct. 124 MemRefDescriptor 125 createMemRefDescriptor(Location loc, MemRefType memRefType, 126 Value allocatedPtr, Value alignedPtr, 127 ArrayRef<Value> sizes, ArrayRef<Value> strides, 128 ConversionPatternRewriter &rewriter) const; 129 130 /// Copies the memory descriptor for any operands that were unranked 131 /// descriptors originally to heap-allocated memory (if toDynamic is true) or 132 /// to stack-allocated memory (otherwise). Also frees the previously used 133 /// memory (that is assumed to be heap-allocated) if toDynamic is false. 134 LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, 135 TypeRange origTypes, 136 SmallVectorImpl<Value> &operands, 137 bool toDynamic) const; 138 }; 139 140 /// Utility class for operation conversions targeting the LLVM dialect that 141 /// match exactly one source operation. 142 template <typename SourceOp> 143 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { 144 public: 145 using OpAdaptor = typename SourceOp::Adaptor; 146 using OneToNOpAdaptor = 147 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>; 148 149 explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, 150 PatternBenefit benefit = 1) 151 : ConvertToLLVMPattern(SourceOp::getOperationName(), 152 &typeConverter.getContext(), typeConverter, 153 benefit) {} 154 155 /// Wrappers around the RewritePattern methods that pass the derived op type. 156 void rewrite(Operation *op, ArrayRef<Value> operands, 157 ConversionPatternRewriter &rewriter) const final { 158 auto sourceOp = cast<SourceOp>(op); 159 rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); 160 } 161 void rewrite(Operation *op, ArrayRef<ValueRange> operands, 162 ConversionPatternRewriter &rewriter) const final { 163 auto sourceOp = cast<SourceOp>(op); 164 rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); 165 } 166 LogicalResult match(Operation *op) const final { 167 return match(cast<SourceOp>(op)); 168 } 169 LogicalResult 170 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 171 ConversionPatternRewriter &rewriter) const final { 172 auto sourceOp = cast<SourceOp>(op); 173 return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); 174 } 175 LogicalResult 176 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 177 ConversionPatternRewriter &rewriter) const final { 178 auto sourceOp = cast<SourceOp>(op); 179 return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), 180 rewriter); 181 } 182 183 /// Rewrite and Match methods that operate on the SourceOp type. These must be 184 /// overridden by the derived pattern class. 185 virtual LogicalResult match(SourceOp op) const { 186 llvm_unreachable("must override match or matchAndRewrite"); 187 } 188 virtual void rewrite(SourceOp op, OpAdaptor adaptor, 189 ConversionPatternRewriter &rewriter) const { 190 llvm_unreachable("must override rewrite or matchAndRewrite"); 191 } 192 virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, 193 ConversionPatternRewriter &rewriter) const { 194 SmallVector<Value> oneToOneOperands = 195 getOneToOneAdaptorOperands(adaptor.getOperands()); 196 rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); 197 } 198 virtual LogicalResult 199 matchAndRewrite(SourceOp op, OpAdaptor adaptor, 200 ConversionPatternRewriter &rewriter) const { 201 if (failed(match(op))) 202 return failure(); 203 rewrite(op, adaptor, rewriter); 204 return success(); 205 } 206 virtual LogicalResult 207 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, 208 ConversionPatternRewriter &rewriter) const { 209 SmallVector<Value> oneToOneOperands = 210 getOneToOneAdaptorOperands(adaptor.getOperands()); 211 return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); 212 } 213 214 private: 215 using ConvertToLLVMPattern::match; 216 using ConvertToLLVMPattern::matchAndRewrite; 217 }; 218 219 /// Generic implementation of one-to-one conversion from "SourceOp" to 220 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent. 221 /// Upholds a convention that multi-result operations get converted into an 222 /// operation returning the LLVM IR structure type, in which case individual 223 /// values must be extracted from using LLVM::ExtractValueOp before being used. 224 template <typename SourceOp, typename TargetOp> 225 class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { 226 public: 227 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 228 using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>; 229 230 /// Converts the type of the result to an LLVM type, pass operands as is, 231 /// preserve attributes. 232 LogicalResult 233 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 234 ConversionPatternRewriter &rewriter) const override { 235 return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), 236 adaptor.getOperands(), op->getAttrs(), 237 *this->getTypeConverter(), rewriter); 238 } 239 }; 240 241 } // namespace mlir 242 243 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H 244