xref: /llvm-project/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h (revision 6900768719ff6d38403f39ceb75e0ec953278f5a)
1 //===- MemRefBuilder.h - Helper for LLVM MemRef equivalents -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides a convenience API for emitting IR that inspects or constructs values
10 // of LLVM dialect structure type that correspond to ranked or unranked memref.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H
15 #define MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H
16 
17 #include "mlir/Conversion/LLVMCommon/StructBuilder.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 namespace mlir {
21 
// Types that are only named in the declarations below, so a forward
// declaration suffices here.

// Builtin memref types whose descriptors the helpers below operate on.
class MemRefType;
class UnrankedMemRefType;

// Type converter passed to the descriptor helpers.
class LLVMTypeConverter;

namespace LLVM {
class LLVMPointerType; // LLVM dialect pointer type used by the accessors.
} // namespace LLVM
29 
30 /// Helper class to produce LLVM dialect operations extracting or inserting
31 /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
32 /// The Value may be null, in which case none of the operations are valid.
33 class MemRefDescriptor : public StructBuilder {
34 public:
35   /// Construct a helper for the given descriptor value.
36   explicit MemRefDescriptor(Value descriptor);
37   /// Builds IR creating an `undef` value of the descriptor type.
38   static MemRefDescriptor undef(OpBuilder &builder, Location loc,
39                                 Type descriptorType);
40   /// Builds IR creating a MemRef descriptor that represents `type` and
41   /// populates it with static shape and stride information extracted from the
42   /// type.
43   static MemRefDescriptor
44   fromStaticShape(OpBuilder &builder, Location loc,
45                   const LLVMTypeConverter &typeConverter, MemRefType type,
46                   Value memory);
47   static MemRefDescriptor
48   fromStaticShape(OpBuilder &builder, Location loc,
49                   const LLVMTypeConverter &typeConverter, MemRefType type,
50                   Value memory, Value alignedMemory);
51 
52   /// Builds IR extracting the allocated pointer from the descriptor.
53   Value allocatedPtr(OpBuilder &builder, Location loc);
54   /// Builds IR inserting the allocated pointer into the descriptor.
55   void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);
56 
57   /// Builds IR extracting the aligned pointer from the descriptor.
58   Value alignedPtr(OpBuilder &builder, Location loc);
59 
60   /// Builds IR inserting the aligned pointer into the descriptor.
61   void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);
62 
63   /// Builds IR extracting the offset from the descriptor.
64   Value offset(OpBuilder &builder, Location loc);
65 
66   /// Builds IR inserting the offset into the descriptor.
67   void setOffset(OpBuilder &builder, Location loc, Value offset);
68   void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);
69 
70   /// Builds IR extracting the pos-th size from the descriptor.
71   Value size(OpBuilder &builder, Location loc, unsigned pos);
72   Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);
73 
74   /// Builds IR inserting the pos-th size into the descriptor
75   void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
76   void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
77                        uint64_t size);
78 
79   /// Builds IR extracting the pos-th size from the descriptor.
80   Value stride(OpBuilder &builder, Location loc, unsigned pos);
81 
82   /// Builds IR inserting the pos-th stride into the descriptor
83   void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride);
84   void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
85                          uint64_t stride);
86 
87   /// Returns the type of array element in this descriptor.
88   Type getIndexType() { return indexType; };
89 
90   /// Returns the (LLVM) pointer type this descriptor contains.
91   LLVM::LLVMPointerType getElementPtrType();
92 
93   /// Builds IR for getting the start address of the buffer represented
94   /// by this memref:
95   /// `memref.alignedPtr + memref.offset * sizeof(type.getElementType())`.
96   /// \note there is no setter for this one since it is derived from alignedPtr
97   /// and offset.
98   Value bufferPtr(OpBuilder &builder, Location loc,
99                   const LLVMTypeConverter &converter, MemRefType type);
100 
101   /// Builds IR populating a MemRef descriptor structure from a list of
102   /// individual values composing that descriptor, in the following order:
103   /// - allocated pointer;
104   /// - aligned pointer;
105   /// - offset;
106   /// - <rank> sizes;
107   /// - <rank> strides;
108   /// where <rank> is the MemRef rank as provided in `type`.
109   static Value pack(OpBuilder &builder, Location loc,
110                     const LLVMTypeConverter &converter, MemRefType type,
111                     ValueRange values);
112 
113   /// Builds IR extracting individual elements of a MemRef descriptor structure
114   /// and returning them as `results` list.
115   static void unpack(OpBuilder &builder, Location loc, Value packed,
116                      MemRefType type, SmallVectorImpl<Value> &results);
117 
118   /// Returns the number of non-aggregate values that would be produced by
119   /// `unpack`.
120   static unsigned getNumUnpackedValues(MemRefType type);
121 
122 private:
123   // Cached index type.
124   Type indexType;
125 };
126 
127 /// Helper class allowing the user to access a range of Values that correspond
128 /// to an unpacked memref descriptor using named accessors. This does not own
129 /// the values.
130 class MemRefDescriptorView {
131 public:
132   /// Constructs the view from a range of values. Infers the rank from the size
133   /// of the range.
134   explicit MemRefDescriptorView(ValueRange range);
135 
136   /// Returns the allocated pointer Value.
137   Value allocatedPtr();
138 
139   /// Returns the aligned pointer Value.
140   Value alignedPtr();
141 
142   /// Returns the offset Value.
143   Value offset();
144 
145   /// Returns the pos-th size Value.
146   Value size(unsigned pos);
147 
148   /// Returns the pos-th stride Value.
149   Value stride(unsigned pos);
150 
151 private:
152   /// Rank of the memref the descriptor is pointing to.
153   int rank;
154   /// Underlying range of Values.
155   ValueRange elements;
156 };
157 
158 class UnrankedMemRefDescriptor : public StructBuilder {
159 public:
160   /// Construct a helper for the given descriptor value.
161   explicit UnrankedMemRefDescriptor(Value descriptor);
162   /// Builds IR creating an `undef` value of the descriptor type.
163   static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
164                                         Type descriptorType);
165 
166   /// Builds IR extracting the rank from the descriptor
167   Value rank(OpBuilder &builder, Location loc) const;
168   /// Builds IR setting the rank in the descriptor
169   void setRank(OpBuilder &builder, Location loc, Value value);
170   /// Builds IR extracting ranked memref descriptor ptr
171   Value memRefDescPtr(OpBuilder &builder, Location loc) const;
172   /// Builds IR setting ranked memref descriptor ptr
173   void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
174 
175   /// Builds IR populating an unranked MemRef descriptor structure from a list
176   /// of individual constituent values in the following order:
177   /// - rank of the memref;
178   /// - pointer to the memref descriptor.
179   static Value pack(OpBuilder &builder, Location loc,
180                     const LLVMTypeConverter &converter, UnrankedMemRefType type,
181                     ValueRange values);
182 
183   /// Builds IR extracting individual elements that compose an unranked memref
184   /// descriptor and returns them as `results` list.
185   static void unpack(OpBuilder &builder, Location loc, Value packed,
186                      SmallVectorImpl<Value> &results);
187 
188   /// Returns the number of non-aggregate values that would be produced by
189   /// `unpack`.
190   static unsigned getNumUnpackedValues() { return 2; }
191 
192   /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
193   /// and appends the corresponding values into `sizes`. `addressSpaces`
194   /// which must have the same length as `values`, is needed to handle layouts
195   /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
196   static void computeSizes(OpBuilder &builder, Location loc,
197                            const LLVMTypeConverter &typeConverter,
198                            ArrayRef<UnrankedMemRefDescriptor> values,
199                            ArrayRef<unsigned> addressSpaces,
200                            SmallVectorImpl<Value> &sizes);
201 
202   /// TODO: The following accessors don't take alignment rules between elements
203   /// of the descriptor struct into account. For some architectures, it might be
204   /// necessary to extend them and to use `llvm::DataLayout` contained in
205   /// `LLVMTypeConverter`.
206 
207   /// Builds IR extracting the allocated pointer from the descriptor.
208   static Value allocatedPtr(OpBuilder &builder, Location loc,
209                             Value memRefDescPtr,
210                             LLVM::LLVMPointerType elemPtrType);
211   /// Builds IR inserting the allocated pointer into the descriptor.
212   static void setAllocatedPtr(OpBuilder &builder, Location loc,
213                               Value memRefDescPtr,
214                               LLVM::LLVMPointerType elemPtrType,
215                               Value allocatedPtr);
216 
217   /// Builds IR extracting the aligned pointer from the descriptor.
218   static Value alignedPtr(OpBuilder &builder, Location loc,
219                           const LLVMTypeConverter &typeConverter,
220                           Value memRefDescPtr,
221                           LLVM::LLVMPointerType elemPtrType);
222   /// Builds IR inserting the aligned pointer into the descriptor.
223   static void setAlignedPtr(OpBuilder &builder, Location loc,
224                             const LLVMTypeConverter &typeConverter,
225                             Value memRefDescPtr,
226                             LLVM::LLVMPointerType elemPtrType,
227                             Value alignedPtr);
228 
229   /// Builds IR for getting the pointer to the offset's location.
230   /// Returns a pointer to a convertType(index), which points to the beggining
231   /// of a struct {index, index[rank], index[rank]}.
232   static Value offsetBasePtr(OpBuilder &builder, Location loc,
233                              const LLVMTypeConverter &typeConverter,
234                              Value memRefDescPtr,
235                              LLVM::LLVMPointerType elemPtrType);
236   /// Builds IR extracting the offset from the descriptor.
237   static Value offset(OpBuilder &builder, Location loc,
238                       const LLVMTypeConverter &typeConverter,
239                       Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType);
240   /// Builds IR inserting the offset into the descriptor.
241   static void setOffset(OpBuilder &builder, Location loc,
242                         const LLVMTypeConverter &typeConverter,
243                         Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType,
244                         Value offset);
245 
246   /// Builds IR extracting the pointer to the first element of the size array.
247   static Value sizeBasePtr(OpBuilder &builder, Location loc,
248                            const LLVMTypeConverter &typeConverter,
249                            Value memRefDescPtr,
250                            LLVM::LLVMPointerType elemPtrType);
251   /// Builds IR extracting the size[index] from the descriptor.
252   static Value size(OpBuilder &builder, Location loc,
253                     const LLVMTypeConverter &typeConverter, Value sizeBasePtr,
254                     Value index);
255   /// Builds IR inserting the size[index] into the descriptor.
256   static void setSize(OpBuilder &builder, Location loc,
257                       const LLVMTypeConverter &typeConverter, Value sizeBasePtr,
258                       Value index, Value size);
259 
260   /// Builds IR extracting the pointer to the first element of the stride array.
261   static Value strideBasePtr(OpBuilder &builder, Location loc,
262                              const LLVMTypeConverter &typeConverter,
263                              Value sizeBasePtr, Value rank);
264   /// Builds IR extracting the stride[index] from the descriptor.
265   static Value stride(OpBuilder &builder, Location loc,
266                       const LLVMTypeConverter &typeConverter,
267                       Value strideBasePtr, Value index, Value stride);
268   /// Builds IR inserting the stride[index] into the descriptor.
269   static void setStride(OpBuilder &builder, Location loc,
270                         const LLVMTypeConverter &typeConverter,
271                         Value strideBasePtr, Value index, Value stride);
272 };
273 
274 } // namespace mlir
275 
276 #endif // MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H
277