xref: /llvm-project/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
1 //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
10 #include "mlir/Analysis/DataLayoutAnalysis.h"
11 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/IR/SymbolTable.h"
14 
15 using namespace mlir;
16 
17 static FailureOr<LLVM::LLVMFuncOp>
18 getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
19                      Type indexType) {
20   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
21   if (useGenericFn)
22     return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
23 
24   return LLVM::lookupOrCreateMallocFn(module, indexType);
25 }
26 
27 static FailureOr<LLVM::LLVMFuncOp>
28 getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
29                   Type indexType) {
30   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
31 
32   if (useGenericFn)
33     return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
34 
35   return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
36 }
37 
38 Value AllocationOpLLVMLowering::createAligned(
39     ConversionPatternRewriter &rewriter, Location loc, Value input,
40     Value alignment) {
41   Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
42   Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
43   Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
44   Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
45   return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
46 }
47 
48 static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
49                                  Location loc, Value allocatedPtr,
50                                  MemRefType memRefType, Type elementPtrType,
51                                  const LLVMTypeConverter &typeConverter) {
52   auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
53   FailureOr<unsigned> maybeMemrefAddrSpace =
54       typeConverter.getMemRefAddressSpace(memRefType);
55   if (failed(maybeMemrefAddrSpace))
56     return Value();
57   unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
58   if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
59     allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
60         loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
61         allocatedPtr);
62   return allocatedPtr;
63 }
64 
65 std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
66     ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
67     Operation *op, Value alignment) const {
68   if (alignment) {
69     // Adjust the allocation size to consider alignment.
70     sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
71   }
72 
73   MemRefType memRefType = getMemRefResultType(op);
74   // Allocate the underlying buffer.
75   Type elementPtrType = this->getElementPtrType(memRefType);
76   if (!elementPtrType) {
77     emitError(loc, "conversion of memref memory space ")
78         << memRefType.getMemorySpace()
79         << " to integer address space "
80            "failed. Consider adding memory space conversions.";
81   }
82   FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
83       getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
84       getIndexType());
85   if (failed(allocFuncOp))
86     return std::make_tuple(Value(), Value());
87   auto results =
88       rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
89 
90   Value allocatedPtr =
91       castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
92                           elementPtrType, *getTypeConverter());
93   if (!allocatedPtr)
94     return std::make_tuple(Value(), Value());
95   Value alignedPtr = allocatedPtr;
96   if (alignment) {
97     // Compute the aligned pointer.
98     Value allocatedInt =
99         rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
100     Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
101     alignedPtr =
102         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
103   }
104 
105   return std::make_tuple(allocatedPtr, alignedPtr);
106 }
107 
108 unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
109     MemRefType memRefType, Operation *op,
110     const DataLayout *defaultLayout) const {
111   const DataLayout *layout = defaultLayout;
112   if (const DataLayoutAnalysis *analysis =
113           getTypeConverter()->getDataLayoutAnalysis()) {
114     layout = &analysis->getAbove(op);
115   }
116   Type elementType = memRefType.getElementType();
117   if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
118     return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
119                                                        *layout);
120   if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
121     return getTypeConverter()->getUnrankedMemRefDescriptorSize(
122         memRefElementType, *layout);
123   return layout->getTypeSize(elementType);
124 }
125 
126 bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
127     MemRefType type, uint64_t factor, Operation *op,
128     const DataLayout *defaultLayout) const {
129   uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
130   for (unsigned i = 0, e = type.getRank(); i < e; i++) {
131     if (type.isDynamicDim(i))
132       continue;
133     sizeDivisor = sizeDivisor * type.getDimSize(i);
134   }
135   return sizeDivisor % factor == 0;
136 }
137 
138 Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
139     ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
140     Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
141   Value allocAlignment =
142       createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
143 
144   MemRefType memRefType = getMemRefResultType(op);
145   // Function aligned_alloc requires size to be a multiple of alignment; we pad
146   // the size to the next multiple if necessary.
147   if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
148     sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
149 
150   Type elementPtrType = this->getElementPtrType(memRefType);
151   FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
152       getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
153       getIndexType());
154   if (failed(allocFuncOp))
155     return Value();
156   auto results = rewriter.create<LLVM::CallOp>(
157       loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
158 
159   return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
160                              elementPtrType, *getTypeConverter());
161 }
162 
163 void AllocLikeOpLLVMLowering::setRequiresNumElements() {
164   requiresNumElements = true;
165 }
166 
167 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
168     Operation *op, ArrayRef<Value> operands,
169     ConversionPatternRewriter &rewriter) const {
170   MemRefType memRefType = getMemRefResultType(op);
171   if (!isConvertibleAndHasIdentityMaps(memRefType))
172     return rewriter.notifyMatchFailure(op, "incompatible memref type");
173   auto loc = op->getLoc();
174 
175   // Get actual sizes of the memref as values: static sizes are constant
176   // values and dynamic sizes are passed to 'alloc' as operands.  In case of
177   // zero-dimensional memref, assume a scalar (size 1).
178   SmallVector<Value, 4> sizes;
179   SmallVector<Value, 4> strides;
180   Value size;
181 
182   this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
183                                  strides, size, !requiresNumElements);
184 
185   // Allocate the underlying buffer.
186   auto [allocatedPtr, alignedPtr] =
187       this->allocateBuffer(rewriter, loc, size, op);
188 
189   if (!allocatedPtr || !alignedPtr)
190     return rewriter.notifyMatchFailure(loc,
191                                        "underlying buffer allocation failed");
192 
193   // Create the MemRef descriptor.
194   auto memRefDescriptor = this->createMemRefDescriptor(
195       loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
196 
197   // Return the final value of the descriptor.
198   rewriter.replaceOp(op, {memRefDescriptor});
199   return success();
200 }
201