xref: /llvm-project/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp (revision 6900768719ff6d38403f39ceb75e0ec953278f5a)
1b5d847b1SAlex Zinenko //===- MemRefBuilder.cpp - Helper for LLVM MemRef equivalents -------------===//
2b5d847b1SAlex Zinenko //
3b5d847b1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b5d847b1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5b5d847b1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b5d847b1SAlex Zinenko //
7b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===//
8b5d847b1SAlex Zinenko 
9b5d847b1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
10b5d847b1SAlex Zinenko #include "MemRefDescriptor.h"
11b5d847b1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12b5d847b1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
137fb9bbe5SKrzysztof Drewniak #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14b5d847b1SAlex Zinenko #include "mlir/IR/Builders.h"
150fb216fbSRamkumar Ramachandra #include "llvm/Support/MathExtras.h"
16b5d847b1SAlex Zinenko 
17b5d847b1SAlex Zinenko using namespace mlir;
18b5d847b1SAlex Zinenko 
19b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===//
20b5d847b1SAlex Zinenko // MemRefDescriptor implementation
21b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===//
22b5d847b1SAlex Zinenko 
23b5d847b1SAlex Zinenko /// Construct a helper for the given descriptor value.
24b5d847b1SAlex Zinenko MemRefDescriptor::MemRefDescriptor(Value descriptor)
25b5d847b1SAlex Zinenko     : StructBuilder(descriptor) {
26b5d847b1SAlex Zinenko   assert(value != nullptr && "value cannot be null");
275550c821STres Popp   indexType = cast<LLVM::LLVMStructType>(value.getType())
28b5d847b1SAlex Zinenko                   .getBody()[kOffsetPosInMemRefDescriptor];
29b5d847b1SAlex Zinenko }
30b5d847b1SAlex Zinenko 
31b5d847b1SAlex Zinenko /// Builds IR creating an `undef` value of the descriptor type.
32b5d847b1SAlex Zinenko MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
33b5d847b1SAlex Zinenko                                          Type descriptorType) {
34b5d847b1SAlex Zinenko 
35b5d847b1SAlex Zinenko   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
36b5d847b1SAlex Zinenko   return MemRefDescriptor(descriptor);
37b5d847b1SAlex Zinenko }
38b5d847b1SAlex Zinenko 
39b5d847b1SAlex Zinenko /// Builds IR creating a MemRef descriptor that represents `type` and
40b5d847b1SAlex Zinenko /// populates it with static shape and stride information extracted from the
41b5d847b1SAlex Zinenko /// type.
42b5d847b1SAlex Zinenko MemRefDescriptor
43b5d847b1SAlex Zinenko MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
44ce254598SMatthias Springer                                   const LLVMTypeConverter &typeConverter,
45b5d847b1SAlex Zinenko                                   MemRefType type, Value memory) {
46200266a0SQuentin Colombet   return fromStaticShape(builder, loc, typeConverter, type, memory, memory);
47200266a0SQuentin Colombet }
48200266a0SQuentin Colombet 
49200266a0SQuentin Colombet MemRefDescriptor MemRefDescriptor::fromStaticShape(
50ce254598SMatthias Springer     OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
51200266a0SQuentin Colombet     MemRefType type, Value memory, Value alignedMemory) {
52b5d847b1SAlex Zinenko   assert(type.hasStaticShape() && "unexpected dynamic shape");
53b5d847b1SAlex Zinenko 
54b5d847b1SAlex Zinenko   // Extract all strides and offsets and verify they are static.
556aaa8f25SMatthias Springer   auto [strides, offset] = type.getStridesAndOffset();
56b28a296cSChristian Ulmann   assert(!ShapedType::isDynamic(offset) && "expected static offset");
57399638f9SAliia Khasanova   assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
58380a1b20SKazu Hirata          "expected static strides");
59b5d847b1SAlex Zinenko 
60b5d847b1SAlex Zinenko   auto convertedType = typeConverter.convertType(type);
61b5d847b1SAlex Zinenko   assert(convertedType && "unexpected failure in memref type conversion");
62b5d847b1SAlex Zinenko 
63b5d847b1SAlex Zinenko   auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
64b5d847b1SAlex Zinenko   descr.setAllocatedPtr(builder, loc, memory);
65200266a0SQuentin Colombet   descr.setAlignedPtr(builder, loc, alignedMemory);
66b5d847b1SAlex Zinenko   descr.setConstantOffset(builder, loc, offset);
67b5d847b1SAlex Zinenko 
68b5d847b1SAlex Zinenko   // Fill in sizes and strides
69b5d847b1SAlex Zinenko   for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
70b5d847b1SAlex Zinenko     descr.setConstantSize(builder, loc, i, type.getDimSize(i));
71b5d847b1SAlex Zinenko     descr.setConstantStride(builder, loc, i, strides[i]);
72b5d847b1SAlex Zinenko   }
73b5d847b1SAlex Zinenko   return descr;
74b5d847b1SAlex Zinenko }
75b5d847b1SAlex Zinenko 
76b5d847b1SAlex Zinenko /// Builds IR extracting the allocated pointer from the descriptor.
77b5d847b1SAlex Zinenko Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
78b5d847b1SAlex Zinenko   return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
79b5d847b1SAlex Zinenko }
80b5d847b1SAlex Zinenko 
81b5d847b1SAlex Zinenko /// Builds IR inserting the allocated pointer into the descriptor.
82b5d847b1SAlex Zinenko void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
83b5d847b1SAlex Zinenko                                        Value ptr) {
84b5d847b1SAlex Zinenko   setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
85b5d847b1SAlex Zinenko }
86b5d847b1SAlex Zinenko 
87b5d847b1SAlex Zinenko /// Builds IR extracting the aligned pointer from the descriptor.
88b5d847b1SAlex Zinenko Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
89b5d847b1SAlex Zinenko   return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
90b5d847b1SAlex Zinenko }
91b5d847b1SAlex Zinenko 
92b5d847b1SAlex Zinenko /// Builds IR inserting the aligned pointer into the descriptor.
93b5d847b1SAlex Zinenko void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
94b5d847b1SAlex Zinenko                                      Value ptr) {
95b5d847b1SAlex Zinenko   setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
96b5d847b1SAlex Zinenko }
97b5d847b1SAlex Zinenko 
98b5d847b1SAlex Zinenko // Creates a constant Op producing a value of `resultType` from an index-typed
99b5d847b1SAlex Zinenko // integer attribute.
100b5d847b1SAlex Zinenko static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
101b5d847b1SAlex Zinenko                                      Type resultType, int64_t value) {
1020af643f3SJeff Niu   return builder.create<LLVM::ConstantOp>(loc, resultType,
1030af643f3SJeff Niu                                           builder.getIndexAttr(value));
104b5d847b1SAlex Zinenko }
105b5d847b1SAlex Zinenko 
106b5d847b1SAlex Zinenko /// Builds IR extracting the offset from the descriptor.
107b5d847b1SAlex Zinenko Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
1085c5af910SJeff Niu   return builder.create<LLVM::ExtractValueOp>(loc, value,
1095c5af910SJeff Niu                                               kOffsetPosInMemRefDescriptor);
110b5d847b1SAlex Zinenko }
111b5d847b1SAlex Zinenko 
112b5d847b1SAlex Zinenko /// Builds IR inserting the offset into the descriptor.
113b5d847b1SAlex Zinenko void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
114b5d847b1SAlex Zinenko                                  Value offset) {
1155c5af910SJeff Niu   value = builder.create<LLVM::InsertValueOp>(loc, value, offset,
1165c5af910SJeff Niu                                               kOffsetPosInMemRefDescriptor);
117b5d847b1SAlex Zinenko }
118b5d847b1SAlex Zinenko 
119b5d847b1SAlex Zinenko /// Builds IR inserting the offset into the descriptor.
120b5d847b1SAlex Zinenko void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
121b5d847b1SAlex Zinenko                                          uint64_t offset) {
122b5d847b1SAlex Zinenko   setOffset(builder, loc,
123b5d847b1SAlex Zinenko             createIndexAttrConstant(builder, loc, indexType, offset));
124b5d847b1SAlex Zinenko }
125b5d847b1SAlex Zinenko 
126b5d847b1SAlex Zinenko /// Builds IR extracting the pos-th size from the descriptor.
127b5d847b1SAlex Zinenko Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
128b5d847b1SAlex Zinenko   return builder.create<LLVM::ExtractValueOp>(
12982973067SUday Bondhugula       loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
130b5d847b1SAlex Zinenko }
131b5d847b1SAlex Zinenko 
132b5d847b1SAlex Zinenko Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
133b5d847b1SAlex Zinenko                              int64_t rank) {
134b5d847b1SAlex Zinenko   auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank);
13550ea17b8SMarkus Böck 
136b28a296cSChristian Ulmann   auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
137b5d847b1SAlex Zinenko 
138b5d847b1SAlex Zinenko   // Copy size values to stack-allocated memory.
139b5d847b1SAlex Zinenko   auto one = createIndexAttrConstant(builder, loc, indexType, 1);
140b5d847b1SAlex Zinenko   auto sizes = builder.create<LLVM::ExtractValueOp>(
141984b800aSserge-sans-paille       loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor}));
142b28a296cSChristian Ulmann   auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one,
14350ea17b8SMarkus Böck                                                  /*alignment=*/0);
144b5d847b1SAlex Zinenko   builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
145b5d847b1SAlex Zinenko 
146b5d847b1SAlex Zinenko   // Load an return size value of interest.
147b28a296cSChristian Ulmann   auto resultPtr = builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, sizesPtr,
148b28a296cSChristian Ulmann                                                ArrayRef<LLVM::GEPArg>{0, pos});
14950ea17b8SMarkus Böck   return builder.create<LLVM::LoadOp>(loc, indexType, resultPtr);
150b5d847b1SAlex Zinenko }
151b5d847b1SAlex Zinenko 
152b5d847b1SAlex Zinenko /// Builds IR inserting the pos-th size into the descriptor
153b5d847b1SAlex Zinenko void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
154b5d847b1SAlex Zinenko                                Value size) {
155b5d847b1SAlex Zinenko   value = builder.create<LLVM::InsertValueOp>(
15682973067SUday Bondhugula       loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
157b5d847b1SAlex Zinenko }
158b5d847b1SAlex Zinenko 
159b5d847b1SAlex Zinenko void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
160b5d847b1SAlex Zinenko                                        unsigned pos, uint64_t size) {
161b5d847b1SAlex Zinenko   setSize(builder, loc, pos,
162b5d847b1SAlex Zinenko           createIndexAttrConstant(builder, loc, indexType, size));
163b5d847b1SAlex Zinenko }
164b5d847b1SAlex Zinenko 
165b5d847b1SAlex Zinenko /// Builds IR extracting the pos-th stride from the descriptor.
166b5d847b1SAlex Zinenko Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
167b5d847b1SAlex Zinenko   return builder.create<LLVM::ExtractValueOp>(
16882973067SUday Bondhugula       loc, value, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
169b5d847b1SAlex Zinenko }
170b5d847b1SAlex Zinenko 
171b5d847b1SAlex Zinenko /// Builds IR inserting the pos-th stride into the descriptor
172b5d847b1SAlex Zinenko void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
173b5d847b1SAlex Zinenko                                  Value stride) {
174b5d847b1SAlex Zinenko   value = builder.create<LLVM::InsertValueOp>(
1755c5af910SJeff Niu       loc, value, stride,
17682973067SUday Bondhugula       ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
177b5d847b1SAlex Zinenko }
178b5d847b1SAlex Zinenko 
179b5d847b1SAlex Zinenko void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
180b5d847b1SAlex Zinenko                                          unsigned pos, uint64_t stride) {
181b5d847b1SAlex Zinenko   setStride(builder, loc, pos,
182b5d847b1SAlex Zinenko             createIndexAttrConstant(builder, loc, indexType, stride));
183b5d847b1SAlex Zinenko }
184b5d847b1SAlex Zinenko 
185b5d847b1SAlex Zinenko LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
1865550c821STres Popp   return cast<LLVM::LLVMPointerType>(
1875550c821STres Popp       cast<LLVM::LLVMStructType>(value.getType())
1885550c821STres Popp           .getBody()[kAlignedPtrPosInMemRefDescriptor]);
189b5d847b1SAlex Zinenko }
190b5d847b1SAlex Zinenko 
191e02d4142SQuentin Colombet Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
192ce254598SMatthias Springer                                   const LLVMTypeConverter &converter,
193e02d4142SQuentin Colombet                                   MemRefType type) {
194e02d4142SQuentin Colombet   // When we convert to LLVM, the input memref must have been normalized
195e02d4142SQuentin Colombet   // beforehand. Hence, this call is guaranteed to work.
1966aaa8f25SMatthias Springer   auto [strides, offsetCst] = type.getStridesAndOffset();
197e02d4142SQuentin Colombet 
198e02d4142SQuentin Colombet   Value ptr = alignedPtr(builder, loc);
199bba9209fSQuentin Colombet   // For zero offsets, we already have the base pointer.
200bba9209fSQuentin Colombet   if (offsetCst == 0)
201bba9209fSQuentin Colombet     return ptr;
202bba9209fSQuentin Colombet 
203bba9209fSQuentin Colombet   // Otherwise add the offset to the aligned base.
204e02d4142SQuentin Colombet   Type indexType = converter.getIndexType();
205e02d4142SQuentin Colombet   Value offsetVal =
206e02d4142SQuentin Colombet       ShapedType::isDynamic(offsetCst)
207e02d4142SQuentin Colombet           ? offset(builder, loc)
208e02d4142SQuentin Colombet           : createIndexAttrConstant(builder, loc, indexType, offsetCst);
209e02d4142SQuentin Colombet   Type elementType = converter.convertType(type.getElementType());
210e02d4142SQuentin Colombet   ptr = builder.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
211e02d4142SQuentin Colombet                                     offsetVal);
212e02d4142SQuentin Colombet   return ptr;
213e02d4142SQuentin Colombet }
214e02d4142SQuentin Colombet 
215b5d847b1SAlex Zinenko /// Creates a MemRef descriptor structure from a list of individual values
216b5d847b1SAlex Zinenko /// composing that descriptor, in the following order:
217b5d847b1SAlex Zinenko /// - allocated pointer;
218b5d847b1SAlex Zinenko /// - aligned pointer;
219b5d847b1SAlex Zinenko /// - offset;
220b5d847b1SAlex Zinenko /// - <rank> sizes;
221*69007687SMatthias Springer /// - <rank> strides;
222b5d847b1SAlex Zinenko /// where <rank> is the MemRef rank as provided in `type`.
223b5d847b1SAlex Zinenko Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
224ce254598SMatthias Springer                              const LLVMTypeConverter &converter,
225ce254598SMatthias Springer                              MemRefType type, ValueRange values) {
226b5d847b1SAlex Zinenko   Type llvmType = converter.convertType(type);
227b5d847b1SAlex Zinenko   auto d = MemRefDescriptor::undef(builder, loc, llvmType);
228b5d847b1SAlex Zinenko 
229b5d847b1SAlex Zinenko   d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
230b5d847b1SAlex Zinenko   d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
231b5d847b1SAlex Zinenko   d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
232b5d847b1SAlex Zinenko 
233b5d847b1SAlex Zinenko   int64_t rank = type.getRank();
234b5d847b1SAlex Zinenko   for (unsigned i = 0; i < rank; ++i) {
235b5d847b1SAlex Zinenko     d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
236b5d847b1SAlex Zinenko     d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
237b5d847b1SAlex Zinenko   }
238b5d847b1SAlex Zinenko 
239b5d847b1SAlex Zinenko   return d;
240b5d847b1SAlex Zinenko }
241b5d847b1SAlex Zinenko 
242b5d847b1SAlex Zinenko /// Builds IR extracting individual elements of a MemRef descriptor structure
243b5d847b1SAlex Zinenko /// and returning them as `results` list.
244b5d847b1SAlex Zinenko void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
245b5d847b1SAlex Zinenko                               MemRefType type,
246b5d847b1SAlex Zinenko                               SmallVectorImpl<Value> &results) {
247b5d847b1SAlex Zinenko   int64_t rank = type.getRank();
248b5d847b1SAlex Zinenko   results.reserve(results.size() + getNumUnpackedValues(type));
249b5d847b1SAlex Zinenko 
250b5d847b1SAlex Zinenko   MemRefDescriptor d(packed);
251b5d847b1SAlex Zinenko   results.push_back(d.allocatedPtr(builder, loc));
252b5d847b1SAlex Zinenko   results.push_back(d.alignedPtr(builder, loc));
253b5d847b1SAlex Zinenko   results.push_back(d.offset(builder, loc));
254b5d847b1SAlex Zinenko   for (int64_t i = 0; i < rank; ++i)
255b5d847b1SAlex Zinenko     results.push_back(d.size(builder, loc, i));
256b5d847b1SAlex Zinenko   for (int64_t i = 0; i < rank; ++i)
257b5d847b1SAlex Zinenko     results.push_back(d.stride(builder, loc, i));
258b5d847b1SAlex Zinenko }
259b5d847b1SAlex Zinenko 
260b5d847b1SAlex Zinenko /// Returns the number of non-aggregate values that would be produced by
261b5d847b1SAlex Zinenko /// `unpack`.
262b5d847b1SAlex Zinenko unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
263*69007687SMatthias Springer   // Two pointers, offset, <rank> sizes, <rank> strides.
264b5d847b1SAlex Zinenko   return 3 + 2 * type.getRank();
265b5d847b1SAlex Zinenko }
266b5d847b1SAlex Zinenko 
267b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===//
268b5d847b1SAlex Zinenko // MemRefDescriptorView implementation.
269b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===//
270b5d847b1SAlex Zinenko 
271b5d847b1SAlex Zinenko MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
272b5d847b1SAlex Zinenko     : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
273b5d847b1SAlex Zinenko 
274b5d847b1SAlex Zinenko Value MemRefDescriptorView::allocatedPtr() {
275b5d847b1SAlex Zinenko   return elements[kAllocatedPtrPosInMemRefDescriptor];
276b5d847b1SAlex Zinenko }
277b5d847b1SAlex Zinenko 
278b5d847b1SAlex Zinenko Value MemRefDescriptorView::alignedPtr() {
279b5d847b1SAlex Zinenko   return elements[kAlignedPtrPosInMemRefDescriptor];
280b5d847b1SAlex Zinenko }
281b5d847b1SAlex Zinenko 
282b5d847b1SAlex Zinenko Value MemRefDescriptorView::offset() {
283b5d847b1SAlex Zinenko   return elements[kOffsetPosInMemRefDescriptor];
284b5d847b1SAlex Zinenko }
285b5d847b1SAlex Zinenko 
286b5d847b1SAlex Zinenko Value MemRefDescriptorView::size(unsigned pos) {
287b5d847b1SAlex Zinenko   return elements[kSizePosInMemRefDescriptor + pos];
288b5d847b1SAlex Zinenko }
289b5d847b1SAlex Zinenko 
290b5d847b1SAlex Zinenko Value MemRefDescriptorView::stride(unsigned pos) {
291b5d847b1SAlex Zinenko   return elements[kSizePosInMemRefDescriptor + rank + pos];
292b5d847b1SAlex Zinenko }
293b5d847b1SAlex Zinenko 
294b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===//
295b5d847b1SAlex Zinenko // UnrankedMemRefDescriptor implementation
296b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===//
297b5d847b1SAlex Zinenko 
298b5d847b1SAlex Zinenko /// Construct a helper for the given descriptor value.
299b5d847b1SAlex Zinenko UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
300b5d847b1SAlex Zinenko     : StructBuilder(descriptor) {}
301b5d847b1SAlex Zinenko 
302b5d847b1SAlex Zinenko /// Builds IR creating an `undef` value of the descriptor type.
303b5d847b1SAlex Zinenko UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
304b5d847b1SAlex Zinenko                                                          Location loc,
305b5d847b1SAlex Zinenko                                                          Type descriptorType) {
306b5d847b1SAlex Zinenko   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
307b5d847b1SAlex Zinenko   return UnrankedMemRefDescriptor(descriptor);
308b5d847b1SAlex Zinenko }
309d0f19ce7SKrzysztof Drewniak Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
310b5d847b1SAlex Zinenko   return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
311b5d847b1SAlex Zinenko }
312b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
313b5d847b1SAlex Zinenko                                        Value v) {
314b5d847b1SAlex Zinenko   setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
315b5d847b1SAlex Zinenko }
316b5d847b1SAlex Zinenko Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
317d0f19ce7SKrzysztof Drewniak                                               Location loc) const {
318b5d847b1SAlex Zinenko   return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
319b5d847b1SAlex Zinenko }
320b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
321b5d847b1SAlex Zinenko                                                 Location loc, Value v) {
322b5d847b1SAlex Zinenko   setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
323b5d847b1SAlex Zinenko }
324b5d847b1SAlex Zinenko 
325b5d847b1SAlex Zinenko /// Builds IR populating an unranked MemRef descriptor structure from a list
326b5d847b1SAlex Zinenko /// of individual constituent values in the following order:
327b5d847b1SAlex Zinenko /// - rank of the memref;
328b5d847b1SAlex Zinenko /// - pointer to the memref descriptor.
329b5d847b1SAlex Zinenko Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
330ce254598SMatthias Springer                                      const LLVMTypeConverter &converter,
331b5d847b1SAlex Zinenko                                      UnrankedMemRefType type,
332b5d847b1SAlex Zinenko                                      ValueRange values) {
333b5d847b1SAlex Zinenko   Type llvmType = converter.convertType(type);
334b5d847b1SAlex Zinenko   auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
335b5d847b1SAlex Zinenko 
336b5d847b1SAlex Zinenko   d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
337b5d847b1SAlex Zinenko   d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
338b5d847b1SAlex Zinenko   return d;
339b5d847b1SAlex Zinenko }
340b5d847b1SAlex Zinenko 
341b5d847b1SAlex Zinenko /// Builds IR extracting individual elements that compose an unranked memref
342b5d847b1SAlex Zinenko /// descriptor and returns them as `results` list.
343b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
344b5d847b1SAlex Zinenko                                       Value packed,
345b5d847b1SAlex Zinenko                                       SmallVectorImpl<Value> &results) {
346b5d847b1SAlex Zinenko   UnrankedMemRefDescriptor d(packed);
347b5d847b1SAlex Zinenko   results.reserve(results.size() + 2);
348b5d847b1SAlex Zinenko   results.push_back(d.rank(builder, loc));
349b5d847b1SAlex Zinenko   results.push_back(d.memRefDescPtr(builder, loc));
350b5d847b1SAlex Zinenko }
351b5d847b1SAlex Zinenko 
352b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::computeSizes(
353ce254598SMatthias Springer     OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
354d0f19ce7SKrzysztof Drewniak     ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
355d0f19ce7SKrzysztof Drewniak     SmallVectorImpl<Value> &sizes) {
356b5d847b1SAlex Zinenko   if (values.empty())
357b5d847b1SAlex Zinenko     return;
358d0f19ce7SKrzysztof Drewniak   assert(values.size() == addressSpaces.size() &&
359d0f19ce7SKrzysztof Drewniak          "must provide address space for each descriptor");
360b5d847b1SAlex Zinenko   // Cache the index type.
361b5d847b1SAlex Zinenko   Type indexType = typeConverter.getIndexType();
362b5d847b1SAlex Zinenko 
363b5d847b1SAlex Zinenko   // Initialize shared constants.
364b5d847b1SAlex Zinenko   Value one = createIndexAttrConstant(builder, loc, indexType, 1);
365b5d847b1SAlex Zinenko   Value two = createIndexAttrConstant(builder, loc, indexType, 2);
3660fb216fbSRamkumar Ramachandra   Value indexSize = createIndexAttrConstant(
3670fb216fbSRamkumar Ramachandra       builder, loc, indexType,
368e843f029SRamkumar Ramachandra       llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
369b5d847b1SAlex Zinenko 
370b5d847b1SAlex Zinenko   sizes.reserve(sizes.size() + values.size());
371d0f19ce7SKrzysztof Drewniak   for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
372b5d847b1SAlex Zinenko     // Emit IR computing the memory necessary to store the descriptor. This
373b5d847b1SAlex Zinenko     // assumes the descriptor to be
374b5d847b1SAlex Zinenko     //   { type*, type*, index, index[rank], index[rank] }
375b5d847b1SAlex Zinenko     // and densely packed, so the total size is
376b5d847b1SAlex Zinenko     //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
377b5d847b1SAlex Zinenko     // TODO: consider including the actual size (including eventual padding due
378b5d847b1SAlex Zinenko     // to data layout) into the unranked descriptor.
379d0f19ce7SKrzysztof Drewniak     Value pointerSize = createIndexAttrConstant(
380d0f19ce7SKrzysztof Drewniak         builder, loc, indexType,
381e843f029SRamkumar Ramachandra         llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
382b5d847b1SAlex Zinenko     Value doublePointerSize =
383b5d847b1SAlex Zinenko         builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
384b5d847b1SAlex Zinenko 
385b5d847b1SAlex Zinenko     // (1 + 2 * rank) * sizeof(index)
386b5d847b1SAlex Zinenko     Value rank = desc.rank(builder, loc);
387b5d847b1SAlex Zinenko     Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
388b5d847b1SAlex Zinenko     Value doubleRankIncremented =
389b5d847b1SAlex Zinenko         builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
390b5d847b1SAlex Zinenko     Value rankIndexSize = builder.create<LLVM::MulOp>(
391b5d847b1SAlex Zinenko         loc, indexType, doubleRankIncremented, indexSize);
392b5d847b1SAlex Zinenko 
393b5d847b1SAlex Zinenko     // Total allocation size.
394b5d847b1SAlex Zinenko     Value allocationSize = builder.create<LLVM::AddOp>(
395b5d847b1SAlex Zinenko         loc, indexType, doublePointerSize, rankIndexSize);
396b5d847b1SAlex Zinenko     sizes.push_back(allocationSize);
397b5d847b1SAlex Zinenko   }
398b5d847b1SAlex Zinenko }
399b5d847b1SAlex Zinenko 
40050ea17b8SMarkus Böck Value UnrankedMemRefDescriptor::allocatedPtr(
40150ea17b8SMarkus Böck     OpBuilder &builder, Location loc, Value memRefDescPtr,
40250ea17b8SMarkus Böck     LLVM::LLVMPointerType elemPtrType) {
403b28a296cSChristian Ulmann   return builder.create<LLVM::LoadOp>(loc, elemPtrType, memRefDescPtr);
404b5d847b1SAlex Zinenko }
405b5d847b1SAlex Zinenko 
40650ea17b8SMarkus Böck void UnrankedMemRefDescriptor::setAllocatedPtr(
40750ea17b8SMarkus Böck     OpBuilder &builder, Location loc, Value memRefDescPtr,
40850ea17b8SMarkus Böck     LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) {
409b28a296cSChristian Ulmann   builder.create<LLVM::StoreOp>(loc, allocatedPtr, memRefDescPtr);
410b5d847b1SAlex Zinenko }
411b5d847b1SAlex Zinenko 
41250ea17b8SMarkus Böck static std::pair<Value, Type>
41350ea17b8SMarkus Böck castToElemPtrPtr(OpBuilder &builder, Location loc, Value memRefDescPtr,
41450ea17b8SMarkus Böck                  LLVM::LLVMPointerType elemPtrType) {
415b28a296cSChristian Ulmann   auto elemPtrPtrType = LLVM::LLVMPointerType::get(builder.getContext());
416b28a296cSChristian Ulmann   return {memRefDescPtr, elemPtrPtrType};
41750ea17b8SMarkus Böck }
41850ea17b8SMarkus Böck 
419ce254598SMatthias Springer Value UnrankedMemRefDescriptor::alignedPtr(
420ce254598SMatthias Springer     OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
421ce254598SMatthias Springer     Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
42250ea17b8SMarkus Böck   auto [elementPtrPtr, elemPtrPtrType] =
42350ea17b8SMarkus Böck       castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
424b5d847b1SAlex Zinenko 
42550ea17b8SMarkus Böck   Value alignedGep =
42650ea17b8SMarkus Böck       builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
42750ea17b8SMarkus Böck                                   elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
42850ea17b8SMarkus Böck   return builder.create<LLVM::LoadOp>(loc, elemPtrType, alignedGep);
429b5d847b1SAlex Zinenko }
430b5d847b1SAlex Zinenko 
431ce254598SMatthias Springer void UnrankedMemRefDescriptor::setAlignedPtr(
432ce254598SMatthias Springer     OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
433ce254598SMatthias Springer     Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr) {
43450ea17b8SMarkus Böck   auto [elementPtrPtr, elemPtrPtrType] =
43550ea17b8SMarkus Böck       castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
436b5d847b1SAlex Zinenko 
43750ea17b8SMarkus Böck   Value alignedGep =
43850ea17b8SMarkus Böck       builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
43950ea17b8SMarkus Böck                                   elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
440b5d847b1SAlex Zinenko   builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
441b5d847b1SAlex Zinenko }
442b5d847b1SAlex Zinenko 
4437fb9bbe5SKrzysztof Drewniak Value UnrankedMemRefDescriptor::offsetBasePtr(
444ce254598SMatthias Springer     OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
4457fb9bbe5SKrzysztof Drewniak     Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
44650ea17b8SMarkus Böck   auto [elementPtrPtr, elemPtrPtrType] =
44750ea17b8SMarkus Böck       castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
448b5d847b1SAlex Zinenko 
449b28a296cSChristian Ulmann   return builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
45050ea17b8SMarkus Böck                                      elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
4517fb9bbe5SKrzysztof Drewniak }
45250ea17b8SMarkus Böck 
4537fb9bbe5SKrzysztof Drewniak Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
454ce254598SMatthias Springer                                        const LLVMTypeConverter &typeConverter,
4557fb9bbe5SKrzysztof Drewniak                                        Value memRefDescPtr,
4567fb9bbe5SKrzysztof Drewniak                                        LLVM::LLVMPointerType elemPtrType) {
4577fb9bbe5SKrzysztof Drewniak   Value offsetPtr =
4587fb9bbe5SKrzysztof Drewniak       offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
4596a10381bSStephan Herhut   return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(),
4606a10381bSStephan Herhut                                       offsetPtr);
461b5d847b1SAlex Zinenko }
462b5d847b1SAlex Zinenko 
463b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
464ce254598SMatthias Springer                                          const LLVMTypeConverter &typeConverter,
465b5d847b1SAlex Zinenko                                          Value memRefDescPtr,
46650ea17b8SMarkus Böck                                          LLVM::LLVMPointerType elemPtrType,
46750ea17b8SMarkus Böck                                          Value offset) {
4687fb9bbe5SKrzysztof Drewniak   Value offsetPtr =
4697fb9bbe5SKrzysztof Drewniak       offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
4707fb9bbe5SKrzysztof Drewniak   builder.create<LLVM::StoreOp>(loc, offset, offsetPtr);
471b5d847b1SAlex Zinenko }
472b5d847b1SAlex Zinenko 
473ce254598SMatthias Springer Value UnrankedMemRefDescriptor::sizeBasePtr(
474ce254598SMatthias Springer     OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
475ce254598SMatthias Springer     Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
476b5d847b1SAlex Zinenko   Type indexTy = typeConverter.getIndexType();
47750ea17b8SMarkus Böck   Type structTy = LLVM::LLVMStructType::getLiteral(
47850ea17b8SMarkus Böck       indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy});
479b28a296cSChristian Ulmann   auto resultType = LLVM::LLVMPointerType::get(builder.getContext());
480b28a296cSChristian Ulmann   return builder.create<LLVM::GEPOp>(loc, resultType, structTy, memRefDescPtr,
48150ea17b8SMarkus Böck                                      ArrayRef<LLVM::GEPArg>{0, 3});
482b5d847b1SAlex Zinenko }
483b5d847b1SAlex Zinenko 
484b5d847b1SAlex Zinenko Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
485ce254598SMatthias Springer                                      const LLVMTypeConverter &typeConverter,
486b5d847b1SAlex Zinenko                                      Value sizeBasePtr, Value index) {
48750ea17b8SMarkus Böck 
48850ea17b8SMarkus Böck   Type indexTy = typeConverter.getIndexType();
48997a238e8SChristian Ulmann   auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
49050ea17b8SMarkus Böck 
491bd7eff1fSMarkus Böck   Value sizeStoreGep =
49297a238e8SChristian Ulmann       builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index);
49350ea17b8SMarkus Böck   return builder.create<LLVM::LoadOp>(loc, indexTy, sizeStoreGep);
494b5d847b1SAlex Zinenko }
495b5d847b1SAlex Zinenko 
496b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
497ce254598SMatthias Springer                                        const LLVMTypeConverter &typeConverter,
498b5d847b1SAlex Zinenko                                        Value sizeBasePtr, Value index,
499b5d847b1SAlex Zinenko                                        Value size) {
50050ea17b8SMarkus Böck   Type indexTy = typeConverter.getIndexType();
50197a238e8SChristian Ulmann   auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
50250ea17b8SMarkus Böck 
503bd7eff1fSMarkus Böck   Value sizeStoreGep =
50497a238e8SChristian Ulmann       builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index);
505b5d847b1SAlex Zinenko   builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
506b5d847b1SAlex Zinenko }
507b5d847b1SAlex Zinenko 
508ce254598SMatthias Springer Value UnrankedMemRefDescriptor::strideBasePtr(
509ce254598SMatthias Springer     OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
510b5d847b1SAlex Zinenko     Value sizeBasePtr, Value rank) {
51150ea17b8SMarkus Böck   Type indexTy = typeConverter.getIndexType();
51297a238e8SChristian Ulmann   auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
51350ea17b8SMarkus Böck 
51497a238e8SChristian Ulmann   return builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, rank);
515b5d847b1SAlex Zinenko }
516b5d847b1SAlex Zinenko 
517b5d847b1SAlex Zinenko Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
518ce254598SMatthias Springer                                        const LLVMTypeConverter &typeConverter,
519b5d847b1SAlex Zinenko                                        Value strideBasePtr, Value index,
520b5d847b1SAlex Zinenko                                        Value stride) {
52150ea17b8SMarkus Böck   Type indexTy = typeConverter.getIndexType();
52297a238e8SChristian Ulmann   auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
52350ea17b8SMarkus Böck 
52497a238e8SChristian Ulmann   Value strideStoreGep =
52597a238e8SChristian Ulmann       builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index);
52650ea17b8SMarkus Böck   return builder.create<LLVM::LoadOp>(loc, indexTy, strideStoreGep);
527b5d847b1SAlex Zinenko }
528b5d847b1SAlex Zinenko 
529b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
530ce254598SMatthias Springer                                          const LLVMTypeConverter &typeConverter,
531b5d847b1SAlex Zinenko                                          Value strideBasePtr, Value index,
532b5d847b1SAlex Zinenko                                          Value stride) {
53350ea17b8SMarkus Böck   Type indexTy = typeConverter.getIndexType();
53497a238e8SChristian Ulmann   auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
53550ea17b8SMarkus Böck 
53697a238e8SChristian Ulmann   Value strideStoreGep =
53797a238e8SChristian Ulmann       builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index);
538b5d847b1SAlex Zinenko   builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
539b5d847b1SAlex Zinenko }
540