1 //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" 10 #include "mlir/Analysis/DataLayoutAnalysis.h" 11 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/IR/SymbolTable.h" 14 15 using namespace mlir; 16 17 static FailureOr<LLVM::LLVMFuncOp> 18 getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, 19 Type indexType) { 20 bool useGenericFn = typeConverter->getOptions().useGenericFunctions; 21 if (useGenericFn) 22 return LLVM::lookupOrCreateGenericAllocFn(module, indexType); 23 24 return LLVM::lookupOrCreateMallocFn(module, indexType); 25 } 26 27 static FailureOr<LLVM::LLVMFuncOp> 28 getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, 29 Type indexType) { 30 bool useGenericFn = typeConverter->getOptions().useGenericFunctions; 31 32 if (useGenericFn) 33 return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType); 34 35 return LLVM::lookupOrCreateAlignedAllocFn(module, indexType); 36 } 37 38 Value AllocationOpLLVMLowering::createAligned( 39 ConversionPatternRewriter &rewriter, Location loc, Value input, 40 Value alignment) { 41 Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1); 42 Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one); 43 Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump); 44 Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment); 45 return rewriter.create<LLVM::SubOp>(loc, bumped, mod); 46 } 47 48 static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, 49 Location loc, Value allocatedPtr, 50 MemRefType memRefType, Type elementPtrType, 51 const LLVMTypeConverter &typeConverter) { 52 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType()); 53 FailureOr<unsigned> maybeMemrefAddrSpace = 54 typeConverter.getMemRefAddressSpace(memRefType); 55 if (failed(maybeMemrefAddrSpace)) 56 return Value(); 57 unsigned memrefAddrSpace = *maybeMemrefAddrSpace; 58 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace) 59 allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>( 60 loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace), 61 allocatedPtr); 62 return allocatedPtr; 63 } 64 65 std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign( 66 ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, 67 Operation *op, Value alignment) const { 68 if (alignment) { 69 // Adjust the allocation size to consider alignment. 70 sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment); 71 } 72 73 MemRefType memRefType = getMemRefResultType(op); 74 // Allocate the underlying buffer. 75 Type elementPtrType = this->getElementPtrType(memRefType); 76 if (!elementPtrType) { 77 emitError(loc, "conversion of memref memory space ") 78 << memRefType.getMemorySpace() 79 << " to integer address space " 80 "failed. Consider adding memory space conversions."; 81 } 82 FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn( 83 getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(), 84 getIndexType()); 85 if (failed(allocFuncOp)) 86 return std::make_tuple(Value(), Value()); 87 auto results = 88 rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes); 89 90 Value allocatedPtr = 91 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, 92 elementPtrType, *getTypeConverter()); 93 if (!allocatedPtr) 94 return std::make_tuple(Value(), Value()); 95 Value alignedPtr = allocatedPtr; 96 if (alignment) { 97 // Compute the aligned pointer. 98 Value allocatedInt = 99 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr); 100 Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment); 101 alignedPtr = 102 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt); 103 } 104 105 return std::make_tuple(allocatedPtr, alignedPtr); 106 } 107 108 unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes( 109 MemRefType memRefType, Operation *op, 110 const DataLayout *defaultLayout) const { 111 const DataLayout *layout = defaultLayout; 112 if (const DataLayoutAnalysis *analysis = 113 getTypeConverter()->getDataLayoutAnalysis()) { 114 layout = &analysis->getAbove(op); 115 } 116 Type elementType = memRefType.getElementType(); 117 if (auto memRefElementType = dyn_cast<MemRefType>(elementType)) 118 return getTypeConverter()->getMemRefDescriptorSize(memRefElementType, 119 *layout); 120 if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType)) 121 return getTypeConverter()->getUnrankedMemRefDescriptorSize( 122 memRefElementType, *layout); 123 return layout->getTypeSize(elementType); 124 } 125 126 bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf( 127 MemRefType type, uint64_t factor, Operation *op, 128 const DataLayout *defaultLayout) const { 129 uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout); 130 for (unsigned i = 0, e = type.getRank(); i < e; i++) { 131 if (type.isDynamicDim(i)) 132 continue; 133 sizeDivisor = sizeDivisor * type.getDimSize(i); 134 } 135 return sizeDivisor % factor == 0; 136 } 137 138 Value AllocationOpLLVMLowering::allocateBufferAutoAlign( 139 ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, 140 Operation *op, const DataLayout *defaultLayout, int64_t alignment) const { 141 Value allocAlignment = 142 createIndexAttrConstant(rewriter, loc, getIndexType(), alignment); 143 144 MemRefType memRefType = getMemRefResultType(op); 145 // Function aligned_alloc requires size to be a multiple of alignment; we pad 146 // the size to the next multiple if necessary. 147 if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout)) 148 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); 149 150 Type elementPtrType = this->getElementPtrType(memRefType); 151 FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn( 152 getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(), 153 getIndexType()); 154 if (failed(allocFuncOp)) 155 return Value(); 156 auto results = rewriter.create<LLVM::CallOp>( 157 loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes})); 158 159 return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, 160 elementPtrType, *getTypeConverter()); 161 } 162 163 void AllocLikeOpLLVMLowering::setRequiresNumElements() { 164 requiresNumElements = true; 165 } 166 167 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite( 168 Operation *op, ArrayRef<Value> operands, 169 ConversionPatternRewriter &rewriter) const { 170 MemRefType memRefType = getMemRefResultType(op); 171 if (!isConvertibleAndHasIdentityMaps(memRefType)) 172 return rewriter.notifyMatchFailure(op, "incompatible memref type"); 173 auto loc = op->getLoc(); 174 175 // Get actual sizes of the memref as values: static sizes are constant 176 // values and dynamic sizes are passed to 'alloc' as operands. In case of 177 // zero-dimensional memref, assume a scalar (size 1). 178 SmallVector<Value, 4> sizes; 179 SmallVector<Value, 4> strides; 180 Value size; 181 182 this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes, 183 strides, size, !requiresNumElements); 184 185 // Allocate the underlying buffer. 186 auto [allocatedPtr, alignedPtr] = 187 this->allocateBuffer(rewriter, loc, size, op); 188 189 if (!allocatedPtr || !alignedPtr) 190 return rewriter.notifyMatchFailure(loc, 191 "underlying buffer allocation failed"); 192 193 // Create the MemRef descriptor. 194 auto memRefDescriptor = this->createMemRefDescriptor( 195 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter); 196 197 // Return the final value of the descriptor. 198 rewriter.replaceOp(op, {memRefDescriptor}); 199 return success(); 200 } 201