1 //===- MemRefBuilder.h - Helper for LLVM MemRef equivalents -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Provides a convenience API for emitting IR that inspects or constructs values 10 // of LLVM dialect structure type that correspond to ranked or unranked memref. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H 15 #define MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H 16 17 #include "mlir/Conversion/LLVMCommon/StructBuilder.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 namespace mlir { 21 22 class LLVMTypeConverter; 23 class MemRefType; 24 class UnrankedMemRefType; 25 26 namespace LLVM { 27 class LLVMPointerType; 28 } // namespace LLVM 29 30 /// Helper class to produce LLVM dialect operations extracting or inserting 31 /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. 32 /// The Value may be null, in which case none of the operations are valid. 33 class MemRefDescriptor : public StructBuilder { 34 public: 35 /// Construct a helper for the given descriptor value. 36 explicit MemRefDescriptor(Value descriptor); 37 /// Builds IR creating an `undef` value of the descriptor type. 38 static MemRefDescriptor undef(OpBuilder &builder, Location loc, 39 Type descriptorType); 40 /// Builds IR creating a MemRef descriptor that represents `type` and 41 /// populates it with static shape and stride information extracted from the 42 /// type. 43 static MemRefDescriptor 44 fromStaticShape(OpBuilder &builder, Location loc, 45 const LLVMTypeConverter &typeConverter, MemRefType type, 46 Value memory); 47 static MemRefDescriptor 48 fromStaticShape(OpBuilder &builder, Location loc, 49 const LLVMTypeConverter &typeConverter, MemRefType type, 50 Value memory, Value alignedMemory); 51 52 /// Builds IR extracting the allocated pointer from the descriptor. 53 Value allocatedPtr(OpBuilder &builder, Location loc); 54 /// Builds IR inserting the allocated pointer into the descriptor. 55 void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr); 56 57 /// Builds IR extracting the aligned pointer from the descriptor. 58 Value alignedPtr(OpBuilder &builder, Location loc); 59 60 /// Builds IR inserting the aligned pointer into the descriptor. 61 void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr); 62 63 /// Builds IR extracting the offset from the descriptor. 64 Value offset(OpBuilder &builder, Location loc); 65 66 /// Builds IR inserting the offset into the descriptor. 67 void setOffset(OpBuilder &builder, Location loc, Value offset); 68 void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset); 69 70 /// Builds IR extracting the pos-th size from the descriptor. 71 Value size(OpBuilder &builder, Location loc, unsigned pos); 72 Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank); 73 74 /// Builds IR inserting the pos-th size into the descriptor 75 void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size); 76 void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, 77 uint64_t size); 78 79 /// Builds IR extracting the pos-th size from the descriptor. 80 Value stride(OpBuilder &builder, Location loc, unsigned pos); 81 82 /// Builds IR inserting the pos-th stride into the descriptor 83 void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride); 84 void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, 85 uint64_t stride); 86 87 /// Returns the type of array element in this descriptor. 88 Type getIndexType() { return indexType; }; 89 90 /// Returns the (LLVM) pointer type this descriptor contains. 91 LLVM::LLVMPointerType getElementPtrType(); 92 93 /// Builds IR for getting the start address of the buffer represented 94 /// by this memref: 95 /// `memref.alignedPtr + memref.offset * sizeof(type.getElementType())`. 96 /// \note there is no setter for this one since it is derived from alignedPtr 97 /// and offset. 98 Value bufferPtr(OpBuilder &builder, Location loc, 99 const LLVMTypeConverter &converter, MemRefType type); 100 101 /// Builds IR populating a MemRef descriptor structure from a list of 102 /// individual values composing that descriptor, in the following order: 103 /// - allocated pointer; 104 /// - aligned pointer; 105 /// - offset; 106 /// - <rank> sizes; 107 /// - <rank> strides; 108 /// where <rank> is the MemRef rank as provided in `type`. 109 static Value pack(OpBuilder &builder, Location loc, 110 const LLVMTypeConverter &converter, MemRefType type, 111 ValueRange values); 112 113 /// Builds IR extracting individual elements of a MemRef descriptor structure 114 /// and returning them as `results` list. 115 static void unpack(OpBuilder &builder, Location loc, Value packed, 116 MemRefType type, SmallVectorImpl<Value> &results); 117 118 /// Returns the number of non-aggregate values that would be produced by 119 /// `unpack`. 120 static unsigned getNumUnpackedValues(MemRefType type); 121 122 private: 123 // Cached index type. 124 Type indexType; 125 }; 126 127 /// Helper class allowing the user to access a range of Values that correspond 128 /// to an unpacked memref descriptor using named accessors. This does not own 129 /// the values. 130 class MemRefDescriptorView { 131 public: 132 /// Constructs the view from a range of values. Infers the rank from the size 133 /// of the range. 134 explicit MemRefDescriptorView(ValueRange range); 135 136 /// Returns the allocated pointer Value. 137 Value allocatedPtr(); 138 139 /// Returns the aligned pointer Value. 140 Value alignedPtr(); 141 142 /// Returns the offset Value. 143 Value offset(); 144 145 /// Returns the pos-th size Value. 146 Value size(unsigned pos); 147 148 /// Returns the pos-th stride Value. 149 Value stride(unsigned pos); 150 151 private: 152 /// Rank of the memref the descriptor is pointing to. 153 int rank; 154 /// Underlying range of Values. 155 ValueRange elements; 156 }; 157 158 class UnrankedMemRefDescriptor : public StructBuilder { 159 public: 160 /// Construct a helper for the given descriptor value. 161 explicit UnrankedMemRefDescriptor(Value descriptor); 162 /// Builds IR creating an `undef` value of the descriptor type. 163 static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, 164 Type descriptorType); 165 166 /// Builds IR extracting the rank from the descriptor 167 Value rank(OpBuilder &builder, Location loc) const; 168 /// Builds IR setting the rank in the descriptor 169 void setRank(OpBuilder &builder, Location loc, Value value); 170 /// Builds IR extracting ranked memref descriptor ptr 171 Value memRefDescPtr(OpBuilder &builder, Location loc) const; 172 /// Builds IR setting ranked memref descriptor ptr 173 void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value); 174 175 /// Builds IR populating an unranked MemRef descriptor structure from a list 176 /// of individual constituent values in the following order: 177 /// - rank of the memref; 178 /// - pointer to the memref descriptor. 179 static Value pack(OpBuilder &builder, Location loc, 180 const LLVMTypeConverter &converter, UnrankedMemRefType type, 181 ValueRange values); 182 183 /// Builds IR extracting individual elements that compose an unranked memref 184 /// descriptor and returns them as `results` list. 185 static void unpack(OpBuilder &builder, Location loc, Value packed, 186 SmallVectorImpl<Value> &results); 187 188 /// Returns the number of non-aggregate values that would be produced by 189 /// `unpack`. 190 static unsigned getNumUnpackedValues() { return 2; } 191 192 /// Builds IR computing the sizes in bytes (suitable for opaque allocation) 193 /// and appends the corresponding values into `sizes`. `addressSpaces` 194 /// which must have the same length as `values`, is needed to handle layouts 195 /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)). 196 static void computeSizes(OpBuilder &builder, Location loc, 197 const LLVMTypeConverter &typeConverter, 198 ArrayRef<UnrankedMemRefDescriptor> values, 199 ArrayRef<unsigned> addressSpaces, 200 SmallVectorImpl<Value> &sizes); 201 202 /// TODO: The following accessors don't take alignment rules between elements 203 /// of the descriptor struct into account. For some architectures, it might be 204 /// necessary to extend them and to use `llvm::DataLayout` contained in 205 /// `LLVMTypeConverter`. 206 207 /// Builds IR extracting the allocated pointer from the descriptor. 208 static Value allocatedPtr(OpBuilder &builder, Location loc, 209 Value memRefDescPtr, 210 LLVM::LLVMPointerType elemPtrType); 211 /// Builds IR inserting the allocated pointer into the descriptor. 212 static void setAllocatedPtr(OpBuilder &builder, Location loc, 213 Value memRefDescPtr, 214 LLVM::LLVMPointerType elemPtrType, 215 Value allocatedPtr); 216 217 /// Builds IR extracting the aligned pointer from the descriptor. 218 static Value alignedPtr(OpBuilder &builder, Location loc, 219 const LLVMTypeConverter &typeConverter, 220 Value memRefDescPtr, 221 LLVM::LLVMPointerType elemPtrType); 222 /// Builds IR inserting the aligned pointer into the descriptor. 223 static void setAlignedPtr(OpBuilder &builder, Location loc, 224 const LLVMTypeConverter &typeConverter, 225 Value memRefDescPtr, 226 LLVM::LLVMPointerType elemPtrType, 227 Value alignedPtr); 228 229 /// Builds IR for getting the pointer to the offset's location. 230 /// Returns a pointer to a convertType(index), which points to the beggining 231 /// of a struct {index, index[rank], index[rank]}. 232 static Value offsetBasePtr(OpBuilder &builder, Location loc, 233 const LLVMTypeConverter &typeConverter, 234 Value memRefDescPtr, 235 LLVM::LLVMPointerType elemPtrType); 236 /// Builds IR extracting the offset from the descriptor. 237 static Value offset(OpBuilder &builder, Location loc, 238 const LLVMTypeConverter &typeConverter, 239 Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType); 240 /// Builds IR inserting the offset into the descriptor. 241 static void setOffset(OpBuilder &builder, Location loc, 242 const LLVMTypeConverter &typeConverter, 243 Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, 244 Value offset); 245 246 /// Builds IR extracting the pointer to the first element of the size array. 247 static Value sizeBasePtr(OpBuilder &builder, Location loc, 248 const LLVMTypeConverter &typeConverter, 249 Value memRefDescPtr, 250 LLVM::LLVMPointerType elemPtrType); 251 /// Builds IR extracting the size[index] from the descriptor. 252 static Value size(OpBuilder &builder, Location loc, 253 const LLVMTypeConverter &typeConverter, Value sizeBasePtr, 254 Value index); 255 /// Builds IR inserting the size[index] into the descriptor. 256 static void setSize(OpBuilder &builder, Location loc, 257 const LLVMTypeConverter &typeConverter, Value sizeBasePtr, 258 Value index, Value size); 259 260 /// Builds IR extracting the pointer to the first element of the stride array. 261 static Value strideBasePtr(OpBuilder &builder, Location loc, 262 const LLVMTypeConverter &typeConverter, 263 Value sizeBasePtr, Value rank); 264 /// Builds IR extracting the stride[index] from the descriptor. 265 static Value stride(OpBuilder &builder, Location loc, 266 const LLVMTypeConverter &typeConverter, 267 Value strideBasePtr, Value index, Value stride); 268 /// Builds IR inserting the stride[index] into the descriptor. 269 static void setStride(OpBuilder &builder, Location loc, 270 const LLVMTypeConverter &typeConverter, 271 Value strideBasePtr, Value index, Value stride); 272 }; 273 274 } // namespace mlir 275 276 #endif // MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H 277