xref: /llvm-project/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h (revision ce254598b73b119c9463f5b7f4131559e276e844)
1 //===- AllocLikeConversion.h - Convert allocation ops to LLVM ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
10 #define MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
11 
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 
14 namespace mlir {
15 
16 /// Lowering for memory allocation ops.
17 struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
18   using ConvertToLLVMPattern::createIndexAttrConstant;
19   using ConvertToLLVMPattern::getIndexType;
20   using ConvertToLLVMPattern::getVoidPtrType;
21 
22   explicit AllocationOpLLVMLowering(StringRef opName,
23                                     const LLVMTypeConverter &converter,
24                                     PatternBenefit benefit = 1)
25       : ConvertToLLVMPattern(opName, &converter.getContext(), converter,
26                              benefit) {}
27 
28 protected:
29   /// Computes the aligned value for 'input' as follows:
30   ///   bumped = input + alignement - 1
31   ///   aligned = bumped - bumped % alignment
32   static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
33                              Value input, Value alignment);
34 
getMemRefResultTypeAllocationOpLLVMLowering35   static MemRefType getMemRefResultType(Operation *op) {
36     return cast<MemRefType>(op->getResult(0).getType());
37   }
38 
39   /// Computes the alignment for the given memory allocation op.
40   template <typename OpType>
getAlignmentAllocationOpLLVMLowering41   Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
42                      OpType op) const {
43     MemRefType memRefType = op.getType();
44     Value alignment;
45     if (auto alignmentAttr = op.getAlignment()) {
46       Type indexType = getIndexType();
47       alignment =
48           createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
49     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
50       // In the case where no alignment is specified, we may want to override
51       // `malloc's` behavior. `malloc` typically aligns at the size of the
52       // biggest scalar on a target HW. For non-scalars, use the natural
53       // alignment of the LLVM type given by the LLVM DataLayout.
54       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
55     }
56     return alignment;
57   }
58 
59   /// Computes the alignment for aligned_alloc used to allocate the buffer for
60   /// the memory allocation op.
61   ///
62   /// Aligned_alloc requires the allocation size to be a power of two, and the
63   /// allocation size to be a multiple of the alignment.
64   template <typename OpType>
alignedAllocationGetAlignmentAllocationOpLLVMLowering65   int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter,
66                                         Location loc, OpType op,
67                                         const DataLayout *defaultLayout) const {
68     if (std::optional<uint64_t> alignment = op.getAlignment())
69       return *alignment;
70 
71     // Whenever we don't have alignment set, we will use an alignment
72     // consistent with the element type; since the allocation size has to be a
73     // power of two, we will bump to the next power of two if it isn't.
74     unsigned eltSizeBytes =
75         getMemRefEltSizeInBytes(op.getType(), op, defaultLayout);
76     return std::max(kMinAlignedAllocAlignment,
77                     llvm::PowerOf2Ceil(eltSizeBytes));
78   }
79 
80   /// Allocates a memory buffer using an allocation method that doesn't
81   /// guarantee alignment. Returns the pointer and its aligned value.
82   std::tuple<Value, Value>
83   allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc,
84                               Value sizeBytes, Operation *op,
85                               Value alignment) const;
86 
87   /// Allocates a memory buffer using an aligned allocation method.
88   Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter,
89                                 Location loc, Value sizeBytes, Operation *op,
90                                 const DataLayout *defaultLayout,
91                                 int64_t alignment) const;
92 
93 private:
94   /// Computes the byte size for the MemRef element type.
95   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op,
96                                    const DataLayout *defaultLayout) const;
97 
98   /// Returns true if the memref size in bytes is known to be a multiple of
99   /// factor.
100   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
101                               const DataLayout *defaultLayout) const;
102 
103   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
104   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
105 };
106 
107 /// Lowering for AllocOp and AllocaOp.
108 struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
109   explicit AllocLikeOpLLVMLowering(StringRef opName,
110                                    const LLVMTypeConverter &converter,
111                                    PatternBenefit benefit = 1)
AllocationOpLLVMLoweringAllocLikeOpLLVMLowering112       : AllocationOpLLVMLowering(opName, converter, benefit) {}
113 
114 protected:
115   /// Allocates the underlying buffer. Returns the allocated pointer and the
116   /// aligned pointer.
117   virtual std::tuple<Value, Value>
118   allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
119                  Operation *op) const = 0;
120 
121   /// Sets the flag 'requiresNumElements', specifying the Op requires the number
122   /// of elements instead of the size in bytes.
123   void setRequiresNumElements();
124 
125 private:
126   // An `alloc` is converted into a definition of a memref descriptor value and
127   // a call to `malloc` to allocate the underlying data buffer.  The memref
128   // descriptor is of the LLVM structure type where:
129   //   1. the first element is a pointer to the allocated (typed) data buffer,
130   //   2. the second element is a pointer to the (typed) payload, aligned to the
131   //      specified alignment,
132   //   3. the remaining elements serve to store all the sizes and strides of the
133   //      memref using LLVM-converted `index` type.
134   //
135   // Alignment is performed by allocating `alignment` more bytes than
136   // requested and shifting the aligned pointer relative to the allocated
137   // memory. Note: `alignment - <minimum malloc alignment>` would actually be
138   // sufficient. If alignment is unspecified, the two pointers are equal.
139 
140   // An `alloca` is converted into a definition of a memref descriptor value and
141   // an llvm.alloca to allocate the underlying data buffer.
142   LogicalResult
143   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
144                   ConversionPatternRewriter &rewriter) const override;
145 
146   // Flag for specifying the Op requires the number of elements instead of the
147   // size in bytes.
148   bool requiresNumElements = false;
149 };
150 
151 } // namespace mlir
152 
153 #endif // MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
154