xref: /llvm-project/mlir/lib/Conversion/LLVMCommon/Pattern.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/Pattern.h"
10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinAttributes.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
21 
22 ConvertToLLVMPattern::ConvertToLLVMPattern(
23     StringRef rootOpName, MLIRContext *context,
24     const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
25     : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26 
27 const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
28   return static_cast<const LLVMTypeConverter *>(
29       ConversionPattern::getTypeConverter());
30 }
31 
32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33   return *getTypeConverter()->getDialect();
34 }
35 
36 Type ConvertToLLVMPattern::getIndexType() const {
37   return getTypeConverter()->getIndexType();
38 }
39 
40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
41   return IntegerType::get(&getTypeConverter()->getContext(),
42                           getTypeConverter()->getPointerBitwidth(addressSpace));
43 }
44 
45 Type ConvertToLLVMPattern::getVoidType() const {
46   return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
47 }
48 
49 Type ConvertToLLVMPattern::getVoidPtrType() const {
50   return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext());
51 }
52 
53 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
54                                                     Location loc,
55                                                     Type resultType,
56                                                     int64_t value) {
57   return builder.create<LLVM::ConstantOp>(loc, resultType,
58                                           builder.getIndexAttr(value));
59 }
60 
61 Value ConvertToLLVMPattern::getStridedElementPtr(
62     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
63     ConversionPatternRewriter &rewriter) const {
64 
65   auto [strides, offset] = type.getStridesAndOffset();
66 
67   MemRefDescriptor memRefDescriptor(memRefDesc);
68   // Use a canonical representation of the start address so that later
69   // optimizations have a longer sequence of instructions to CSE.
70   // If we don't do that we would sprinkle the memref.offset in various
71   // position of the different address computations.
72   Value base =
73       memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
74 
75   Type indexType = getIndexType();
76   Value index;
77   for (int i = 0, e = indices.size(); i < e; ++i) {
78     Value increment = indices[i];
79     if (strides[i] != 1) { // Skip if stride is 1.
80       Value stride =
81           ShapedType::isDynamic(strides[i])
82               ? memRefDescriptor.stride(rewriter, loc, i)
83               : createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
84       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
85     }
86     index =
87         index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
88   }
89 
90   Type elementPtrType = memRefDescriptor.getElementPtrType();
91   return index ? rewriter.create<LLVM::GEPOp>(
92                      loc, elementPtrType,
93                      getTypeConverter()->convertType(type.getElementType()),
94                      base, index)
95                : base;
96 }
97 
98 // Check if the MemRefType `type` is supported by the lowering. We currently
99 // only support memrefs with identity maps.
100 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
101     MemRefType type) const {
102   if (!typeConverter->convertType(type.getElementType()))
103     return false;
104   return type.getLayout().isIdentity();
105 }
106 
107 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
108   auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
109   if (failed(addressSpace))
110     return {};
111   return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
112 }
113 
114 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
115     Location loc, MemRefType memRefType, ValueRange dynamicSizes,
116     ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
117     SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
118   assert(isConvertibleAndHasIdentityMaps(memRefType) &&
119          "layout maps must have been normalized away");
120   assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
121              static_cast<ssize_t>(dynamicSizes.size()) &&
122          "dynamicSizes size doesn't match dynamic sizes count in memref shape");
123 
124   sizes.reserve(memRefType.getRank());
125   unsigned dynamicIndex = 0;
126   Type indexType = getIndexType();
127   for (int64_t size : memRefType.getShape()) {
128     sizes.push_back(
129         size == ShapedType::kDynamic
130             ? dynamicSizes[dynamicIndex++]
131             : createIndexAttrConstant(rewriter, loc, indexType, size));
132   }
133 
134   // Strides: iterate sizes in reverse order and multiply.
135   int64_t stride = 1;
136   Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
137   strides.resize(memRefType.getRank());
138   for (auto i = memRefType.getRank(); i-- > 0;) {
139     strides[i] = runningStride;
140 
141     int64_t staticSize = memRefType.getShape()[i];
142     bool useSizeAsStride = stride == 1;
143     if (staticSize == ShapedType::kDynamic)
144       stride = ShapedType::kDynamic;
145     if (stride != ShapedType::kDynamic)
146       stride *= staticSize;
147 
148     if (useSizeAsStride)
149       runningStride = sizes[i];
150     else if (stride == ShapedType::kDynamic)
151       runningStride =
152           rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
153     else
154       runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
155   }
156   if (sizeInBytes) {
157     // Buffer size in bytes.
158     Type elementType = typeConverter->convertType(memRefType.getElementType());
159     auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
160     Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
161     Value gepPtr = rewriter.create<LLVM::GEPOp>(
162         loc, elementPtrType, elementType, nullPtr, runningStride);
163     size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
164   } else {
165     size = runningStride;
166   }
167 }
168 
169 Value ConvertToLLVMPattern::getSizeInBytes(
170     Location loc, Type type, ConversionPatternRewriter &rewriter) const {
171   // Compute the size of an individual element. This emits the MLIR equivalent
172   // of the following sizeof(...) implementation in LLVM IR:
173   //   %0 = getelementptr %elementType* null, %indexType 1
174   //   %1 = ptrtoint %elementType* %0 to %indexType
175   // which is a common pattern of getting the size of a type in bytes.
176   Type llvmType = typeConverter->convertType(type);
177   auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
178   auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
179   auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
180                                           nullPtr, ArrayRef<LLVM::GEPArg>{1});
181   return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
182 }
183 
184 Value ConvertToLLVMPattern::getNumElements(
185     Location loc, MemRefType memRefType, ValueRange dynamicSizes,
186     ConversionPatternRewriter &rewriter) const {
187   assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
188              static_cast<ssize_t>(dynamicSizes.size()) &&
189          "dynamicSizes size doesn't match dynamic sizes count in memref shape");
190 
191   Type indexType = getIndexType();
192   Value numElements = memRefType.getRank() == 0
193                           ? createIndexAttrConstant(rewriter, loc, indexType, 1)
194                           : nullptr;
195   unsigned dynamicIndex = 0;
196 
197   // Compute the total number of memref elements.
198   for (int64_t staticSize : memRefType.getShape()) {
199     if (numElements) {
200       Value size =
201           staticSize == ShapedType::kDynamic
202               ? dynamicSizes[dynamicIndex++]
203               : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
204       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
205     } else {
206       numElements =
207           staticSize == ShapedType::kDynamic
208               ? dynamicSizes[dynamicIndex++]
209               : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
210     }
211   }
212   return numElements;
213 }
214 
215 /// Creates and populates the memref descriptor struct given all its fields.
216 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
217     Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
218     ArrayRef<Value> sizes, ArrayRef<Value> strides,
219     ConversionPatternRewriter &rewriter) const {
220   auto structType = typeConverter->convertType(memRefType);
221   auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
222 
223   // Field 1: Allocated pointer, used for malloc/free.
224   memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
225 
226   // Field 2: Actual aligned pointer to payload.
227   memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
228 
229   // Field 3: Offset in aligned pointer.
230   Type indexType = getIndexType();
231   memRefDescriptor.setOffset(
232       rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
233 
234   // Fields 4: Sizes.
235   for (const auto &en : llvm::enumerate(sizes))
236     memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
237 
238   // Field 5: Strides.
239   for (const auto &en : llvm::enumerate(strides))
240     memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
241 
242   return memRefDescriptor;
243 }
244 
245 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
246     OpBuilder &builder, Location loc, TypeRange origTypes,
247     SmallVectorImpl<Value> &operands, bool toDynamic) const {
248   assert(origTypes.size() == operands.size() &&
249          "expected as may original types as operands");
250 
251   // Find operands of unranked memref type and store them.
252   SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
253   SmallVector<unsigned> unrankedAddressSpaces;
254   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
255     if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
256       unrankedMemrefs.emplace_back(operands[i]);
257       FailureOr<unsigned> addressSpace =
258           getTypeConverter()->getMemRefAddressSpace(memRefType);
259       if (failed(addressSpace))
260         return failure();
261       unrankedAddressSpaces.emplace_back(*addressSpace);
262     }
263   }
264 
265   if (unrankedMemrefs.empty())
266     return success();
267 
268   // Compute allocation sizes.
269   SmallVector<Value> sizes;
270   UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
271                                          unrankedMemrefs, unrankedAddressSpaces,
272                                          sizes);
273 
274   // Get frequently used types.
275   Type indexType = getTypeConverter()->getIndexType();
276 
277   // Find the malloc and free, or declare them if necessary.
278   auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
279   FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
280   if (toDynamic) {
281     mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
282     if (failed(mallocFunc))
283       return failure();
284   }
285   if (!toDynamic) {
286     freeFunc = LLVM::lookupOrCreateFreeFn(module);
287     if (failed(freeFunc))
288       return failure();
289   }
290 
291   unsigned unrankedMemrefPos = 0;
292   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
293     Type type = origTypes[i];
294     if (!isa<UnrankedMemRefType>(type))
295       continue;
296     Value allocationSize = sizes[unrankedMemrefPos++];
297     UnrankedMemRefDescriptor desc(operands[i]);
298 
299     // Allocate memory, copy, and free the source if necessary.
300     Value memory =
301         toDynamic
302             ? builder
303                   .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
304                   .getResult()
305             : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
306                                              IntegerType::get(getContext(), 8),
307                                              allocationSize,
308                                              /*alignment=*/0);
309     Value source = desc.memRefDescPtr(builder, loc);
310     builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
311     if (!toDynamic)
312       builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
313 
314     // Create a new descriptor. The same descriptor can be returned multiple
315     // times, attempting to modify its pointer can lead to memory leaks
316     // (allocated twice and overwritten) or double frees (the caller does not
317     // know if the descriptor points to the same memory).
318     Type descriptorType = getTypeConverter()->convertType(type);
319     if (!descriptorType)
320       return failure();
321     auto updatedDesc =
322         UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
323     Value rank = desc.rank(builder, loc);
324     updatedDesc.setRank(builder, loc, rank);
325     updatedDesc.setMemRefDescPtr(builder, loc, memory);
326 
327     operands[i] = updatedDesc;
328   }
329 
330   return success();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // Detail methods
335 //===----------------------------------------------------------------------===//
336 
337 void LLVM::detail::setNativeProperties(Operation *op,
338                                        IntegerOverflowFlags overflowFlags) {
339   if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
340     iface.setOverflowFlags(overflowFlags);
341 }
342 
343 /// Replaces the given operation "op" with a new operation of type "targetOp"
344 /// and given operands.
345 LogicalResult LLVM::detail::oneToOneRewrite(
346     Operation *op, StringRef targetOp, ValueRange operands,
347     ArrayRef<NamedAttribute> targetAttrs,
348     const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
349     IntegerOverflowFlags overflowFlags) {
350   unsigned numResults = op->getNumResults();
351 
352   SmallVector<Type> resultTypes;
353   if (numResults != 0) {
354     resultTypes.push_back(
355         typeConverter.packOperationResults(op->getResultTypes()));
356     if (!resultTypes.back())
357       return failure();
358   }
359 
360   // Create the operation through state since we don't know its C++ type.
361   Operation *newOp =
362       rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
363                       resultTypes, targetAttrs);
364 
365   setNativeProperties(newOp, overflowFlags);
366 
367   // If the operation produced 0 or 1 result, return them immediately.
368   if (numResults == 0)
369     return rewriter.eraseOp(op), success();
370   if (numResults == 1)
371     return rewriter.replaceOp(op, newOp->getResult(0)), success();
372 
373   // Otherwise, it had been converted to an operation producing a structure.
374   // Extract individual results from the structure and return them as list.
375   SmallVector<Value, 4> results;
376   results.reserve(numResults);
377   for (unsigned i = 0; i < numResults; ++i) {
378     results.push_back(rewriter.create<LLVM::ExtractValueOp>(
379         op->getLoc(), newOp->getResult(0), i));
380   }
381   rewriter.replaceOp(op, results);
382   return success();
383 }
384