xref: /llvm-project/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (revision 9df63b2651b2435c02a7d825953ca2ddc65c778e)
1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10 #define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11 
12 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 namespace mlir {
18 class CallOpInterface;
19 
20 namespace LLVM {
21 namespace detail {
22 /// Handle generically setting flags as native properties on LLVM operations.
23 void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
24 
25 /// Replaces the given operation "op" with a new operation of type "targetOp"
26 /// and given operands.
27 LogicalResult oneToOneRewrite(
28     Operation *op, StringRef targetOp, ValueRange operands,
29     ArrayRef<NamedAttribute> targetAttrs,
30     const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
31     IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
32 
33 } // namespace detail
34 } // namespace LLVM
35 
36 /// Base class for operation conversions targeting the LLVM IR dialect. It
37 /// provides the conversion patterns with access to the LLVMTypeConverter and
38 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
39 /// LowerToLLVMOptions by reference meaning the references have to remain alive
40 /// during the entire pattern lifetime.
41 class ConvertToLLVMPattern : public ConversionPattern {
42 public:
43   ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
44                        const LLVMTypeConverter &typeConverter,
45                        PatternBenefit benefit = 1);
46 
47 protected:
48   /// Returns the LLVM dialect.
49   LLVM::LLVMDialect &getDialect() const;
50 
51   const LLVMTypeConverter *getTypeConverter() const;
52 
53   /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
54   /// defined by the used type converter.
55   Type getIndexType() const;
56 
57   /// Gets the MLIR type wrapping the LLVM integer type whose bit width
58   /// corresponds to that of a LLVM pointer type.
59   Type getIntPtrType(unsigned addressSpace = 0) const;
60 
61   /// Gets the MLIR type wrapping the LLVM void type.
62   Type getVoidType() const;
63 
64   /// Get the MLIR type wrapping the LLVM i8* type.
65   Type getVoidPtrType() const;
66 
67   /// Create a constant Op producing a value of `resultType` from an index-typed
68   /// integer attribute.
69   static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
70                                        Type resultType, int64_t value);
71 
72   // This is a strided getElementPtr variant that linearizes subscripts as:
73   //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
74   Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
75                              ValueRange indices,
76                              ConversionPatternRewriter &rewriter) const;
77 
78   /// Returns if the given memref has identity maps and the element type is
79   /// convertible to LLVM.
80   bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
81 
82   /// Returns the type of a pointer to an element of the memref.
83   Type getElementPtrType(MemRefType type) const;
84 
85   /// Computes sizes, strides and buffer size of `memRefType` with identity
86   /// layout. Emits constant ops for the static sizes of `memRefType`, and uses
87   /// `dynamicSizes` for the others. Emits instructions to compute strides and
88   /// buffer size from these sizes.
89   ///
90   /// For example, memref<4x?xf32> with `sizeInBytes = true` emits:
91   /// `sizes[0]`   = llvm.mlir.constant(4 : index) : i64
92   /// `sizes[1]`   = `dynamicSizes[0]`
93   /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
94   /// `strides[0]` = `sizes[0]`
95   /// %size        = llvm.mul `sizes[0]`, `sizes[1]` : i64
96   /// %nullptr     = llvm.mlir.zero : !llvm.ptr
97   /// %gep         = llvm.getelementptr %nullptr[%size]
98   ///                  : (!llvm.ptr, i64) -> !llvm.ptr, f32
99   /// `sizeBytes`  = llvm.ptrtoint %gep : !llvm.ptr to i64
100   ///
101   /// If `sizeInBytes = false`, memref<4x?xf32> emits:
102   /// `sizes[0]`   = llvm.mlir.constant(4 : index) : i64
103   /// `sizes[1]`   = `dynamicSizes[0]`
104   /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
105   /// `strides[0]` = `sizes[0]`
106   /// %size        = llvm.mul `sizes[0]`, `sizes[1]` : i64
107   void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
108                                 ValueRange dynamicSizes,
109                                 ConversionPatternRewriter &rewriter,
110                                 SmallVectorImpl<Value> &sizes,
111                                 SmallVectorImpl<Value> &strides, Value &size,
112                                 bool sizeInBytes = true) const;
113 
114   /// Computes the size of type in bytes.
115   Value getSizeInBytes(Location loc, Type type,
116                        ConversionPatternRewriter &rewriter) const;
117 
118   /// Computes total number of elements for the given MemRef and dynamicSizes.
119   Value getNumElements(Location loc, MemRefType memRefType,
120                        ValueRange dynamicSizes,
121                        ConversionPatternRewriter &rewriter) const;
122 
123   /// Creates and populates a canonical memref descriptor struct.
124   MemRefDescriptor
125   createMemRefDescriptor(Location loc, MemRefType memRefType,
126                          Value allocatedPtr, Value alignedPtr,
127                          ArrayRef<Value> sizes, ArrayRef<Value> strides,
128                          ConversionPatternRewriter &rewriter) const;
129 
130   /// Copies the memory descriptor for any operands that were unranked
131   /// descriptors originally to heap-allocated memory (if toDynamic is true) or
132   /// to stack-allocated memory (otherwise). Also frees the previously used
133   /// memory (that is assumed to be heap-allocated) if toDynamic is false.
134   LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
135                                         TypeRange origTypes,
136                                         SmallVectorImpl<Value> &operands,
137                                         bool toDynamic) const;
138 };
139 
140 /// Utility class for operation conversions targeting the LLVM dialect that
141 /// match exactly one source operation.
142 template <typename SourceOp>
143 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144 public:
145   using OpAdaptor = typename SourceOp::Adaptor;
146   using OneToNOpAdaptor =
147       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
148 
149   explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
150                                   PatternBenefit benefit = 1)
151       : ConvertToLLVMPattern(SourceOp::getOperationName(),
152                              &typeConverter.getContext(), typeConverter,
153                              benefit) {}
154 
155   /// Wrappers around the RewritePattern methods that pass the derived op type.
156   void rewrite(Operation *op, ArrayRef<Value> operands,
157                ConversionPatternRewriter &rewriter) const final {
158     auto sourceOp = cast<SourceOp>(op);
159     rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160   }
161   void rewrite(Operation *op, ArrayRef<ValueRange> operands,
162                ConversionPatternRewriter &rewriter) const final {
163     auto sourceOp = cast<SourceOp>(op);
164     rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
165   }
166   LogicalResult match(Operation *op) const final {
167     return match(cast<SourceOp>(op));
168   }
169   LogicalResult
170   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
171                   ConversionPatternRewriter &rewriter) const final {
172     auto sourceOp = cast<SourceOp>(op);
173     return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
174   }
175   LogicalResult
176   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
177                   ConversionPatternRewriter &rewriter) const final {
178     auto sourceOp = cast<SourceOp>(op);
179     return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
180                            rewriter);
181   }
182 
183   /// Rewrite and Match methods that operate on the SourceOp type. These must be
184   /// overridden by the derived pattern class.
185   virtual LogicalResult match(SourceOp op) const {
186     llvm_unreachable("must override match or matchAndRewrite");
187   }
188   virtual void rewrite(SourceOp op, OpAdaptor adaptor,
189                        ConversionPatternRewriter &rewriter) const {
190     llvm_unreachable("must override rewrite or matchAndRewrite");
191   }
192   virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193                        ConversionPatternRewriter &rewriter) const {
194     SmallVector<Value> oneToOneOperands =
195         getOneToOneAdaptorOperands(adaptor.getOperands());
196     rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197   }
198   virtual LogicalResult
199   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
200                   ConversionPatternRewriter &rewriter) const {
201     if (failed(match(op)))
202       return failure();
203     rewrite(op, adaptor, rewriter);
204     return success();
205   }
206   virtual LogicalResult
207   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
208                   ConversionPatternRewriter &rewriter) const {
209     SmallVector<Value> oneToOneOperands =
210         getOneToOneAdaptorOperands(adaptor.getOperands());
211     return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
212   }
213 
214 private:
215   using ConvertToLLVMPattern::match;
216   using ConvertToLLVMPattern::matchAndRewrite;
217 };
218 
219 /// Generic implementation of one-to-one conversion from "SourceOp" to
220 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
221 /// Upholds a convention that multi-result operations get converted into an
222 /// operation returning the LLVM IR structure type, in which case individual
223 /// values must be extracted from using LLVM::ExtractValueOp before being used.
224 template <typename SourceOp, typename TargetOp>
225 class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
226 public:
227   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
228   using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;
229 
230   /// Converts the type of the result to an LLVM type, pass operands as is,
231   /// preserve attributes.
232   LogicalResult
233   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
234                   ConversionPatternRewriter &rewriter) const override {
235     return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
236                                          adaptor.getOperands(), op->getAttrs(),
237                                          *this->getTypeConverter(), rewriter);
238   }
239 };
240 
241 } // namespace mlir
242 
243 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
244