xref: /llvm-project/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
175e5f0aaSAlex Zinenko //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
275e5f0aaSAlex Zinenko //
375e5f0aaSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
475e5f0aaSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
575e5f0aaSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
675e5f0aaSAlex Zinenko //
775e5f0aaSAlex Zinenko //===----------------------------------------------------------------------===//
875e5f0aaSAlex Zinenko 
975e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
1067d0d7acSMichele Scuttari 
1175e5f0aaSAlex Zinenko #include "mlir/Analysis/DataLayoutAnalysis.h"
12876a480cSMatthias Springer #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1375e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1475e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1575e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1675e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
17abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1836550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1975e5f0aaSAlex Zinenko #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
2075e5f0aaSAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
217fb9bbe5SKrzysztof Drewniak #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
2275e5f0aaSAlex Zinenko #include "mlir/Dialect/MemRef/IR/MemRef.h"
23b4d6aadaSOleg Shyshkov #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
2475e5f0aaSAlex Zinenko #include "mlir/IR/AffineMap.h"
25ce6ef990SMax191 #include "mlir/IR/BuiltinTypes.h"
264d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
2767d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
286635c12aSBenjamin Kramer #include "llvm/ADT/SmallBitVector.h"
290fb216fbSRamkumar Ramachandra #include "llvm/Support/MathExtras.h"
307d2b180eSKazu Hirata #include <optional>
3175e5f0aaSAlex Zinenko 
3267d0d7acSMichele Scuttari namespace mlir {
33cb4ccd38SQuentin Colombet #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
3467d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
3567d0d7acSMichele Scuttari } // namespace mlir
3667d0d7acSMichele Scuttari 
3775e5f0aaSAlex Zinenko using namespace mlir;
3875e5f0aaSAlex Zinenko 
3975e5f0aaSAlex Zinenko namespace {
4075e5f0aaSAlex Zinenko 
41*e84f6b6aSLuohao Wang static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
42399638f9SAliia Khasanova   return !ShapedType::isDynamic(strideOrOffset);
435380e30eSAshay Rane }
445380e30eSAshay Rane 
45*e84f6b6aSLuohao Wang static FailureOr<LLVM::LLVMFuncOp>
46*e84f6b6aSLuohao Wang getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
479f13b934Sbixia1   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
489f13b934Sbixia1 
499f13b934Sbixia1   if (useGenericFn)
50b28a296cSChristian Ulmann     return LLVM::lookupOrCreateGenericFreeFn(module);
519f13b934Sbixia1 
52b28a296cSChristian Ulmann   return LLVM::lookupOrCreateFreeFn(module);
539f13b934Sbixia1 }
549f13b934Sbixia1 
5575e5f0aaSAlex Zinenko struct AllocOpLowering : public AllocLikeOpLLVMLowering {
56ce254598SMatthias Springer   AllocOpLowering(const LLVMTypeConverter &converter)
5775e5f0aaSAlex Zinenko       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
5875e5f0aaSAlex Zinenko                                 converter) {}
5975e5f0aaSAlex Zinenko   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
6075e5f0aaSAlex Zinenko                                           Location loc, Value sizeBytes,
6175e5f0aaSAlex Zinenko                                           Operation *op) const override {
629f13b934Sbixia1     return allocateBufferManuallyAlign(
639f13b934Sbixia1         rewriter, loc, sizeBytes, op,
649f13b934Sbixia1         getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
6575e5f0aaSAlex Zinenko   }
6675e5f0aaSAlex Zinenko };
6775e5f0aaSAlex Zinenko 
6875e5f0aaSAlex Zinenko struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
69ce254598SMatthias Springer   AlignedAllocOpLowering(const LLVMTypeConverter &converter)
7075e5f0aaSAlex Zinenko       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
7175e5f0aaSAlex Zinenko                                 converter) {}
7275e5f0aaSAlex Zinenko   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
7375e5f0aaSAlex Zinenko                                           Location loc, Value sizeBytes,
7475e5f0aaSAlex Zinenko                                           Operation *op) const override {
759f13b934Sbixia1     Value ptr = allocateBufferAutoAlign(
769f13b934Sbixia1         rewriter, loc, sizeBytes, op, &defaultLayout,
779f13b934Sbixia1         alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
789f13b934Sbixia1                                       &defaultLayout));
7973c6248cSKrzysztof Drewniak     if (!ptr)
8073c6248cSKrzysztof Drewniak       return std::make_tuple(Value(), Value());
819f13b934Sbixia1     return std::make_tuple(ptr, ptr);
8275e5f0aaSAlex Zinenko   }
8375e5f0aaSAlex Zinenko 
849f13b934Sbixia1 private:
8575e5f0aaSAlex Zinenko   /// Default layout to use in absence of the corresponding analysis.
8675e5f0aaSAlex Zinenko   DataLayout defaultLayout;
8775e5f0aaSAlex Zinenko };
8875e5f0aaSAlex Zinenko 
8975e5f0aaSAlex Zinenko struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
90ce254598SMatthias Springer   AllocaOpLowering(const LLVMTypeConverter &converter)
9175e5f0aaSAlex Zinenko       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
92041f1abeSFabian Mora                                 converter) {
93041f1abeSFabian Mora     setRequiresNumElements();
94041f1abeSFabian Mora   }
9575e5f0aaSAlex Zinenko 
9675e5f0aaSAlex Zinenko   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
9775e5f0aaSAlex Zinenko   /// is set to null for stack allocations. `accessAlignment` is set if
9875e5f0aaSAlex Zinenko   /// alignment is needed post allocation (for eg. in conjunction with malloc).
9975e5f0aaSAlex Zinenko   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
100041f1abeSFabian Mora                                           Location loc, Value size,
10175e5f0aaSAlex Zinenko                                           Operation *op) const override {
10275e5f0aaSAlex Zinenko 
10375e5f0aaSAlex Zinenko     // With alloca, one gets a pointer to the element type right away.
10475e5f0aaSAlex Zinenko     // For stack allocations.
10575e5f0aaSAlex Zinenko     auto allocaOp = cast<memref::AllocaOp>(op);
10650ea17b8SMarkus Böck     auto elementType =
10750ea17b8SMarkus Böck         typeConverter->convertType(allocaOp.getType().getElementType());
108499abb24SKrzysztof Drewniak     unsigned addrSpace =
109499abb24SKrzysztof Drewniak         *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
110499abb24SKrzysztof Drewniak     auto elementPtrType =
111b28a296cSChristian Ulmann         LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
11275e5f0aaSAlex Zinenko 
113041f1abeSFabian Mora     auto allocatedElementPtr =
114041f1abeSFabian Mora         rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
11550ea17b8SMarkus Böck                                         allocaOp.getAlignment().value_or(0));
11675e5f0aaSAlex Zinenko 
11775e5f0aaSAlex Zinenko     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
11875e5f0aaSAlex Zinenko   }
11975e5f0aaSAlex Zinenko };
12075e5f0aaSAlex Zinenko 
12175e5f0aaSAlex Zinenko struct AllocaScopeOpLowering
12275e5f0aaSAlex Zinenko     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
12375e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
12475e5f0aaSAlex Zinenko 
12575e5f0aaSAlex Zinenko   LogicalResult
126ef976337SRiver Riddle   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
12775e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
12875e5f0aaSAlex Zinenko     OpBuilder::InsertionGuard guard(rewriter);
12975e5f0aaSAlex Zinenko     Location loc = allocaScopeOp.getLoc();
13075e5f0aaSAlex Zinenko 
13175e5f0aaSAlex Zinenko     // Split the current block before the AllocaScopeOp to create the inlining
13275e5f0aaSAlex Zinenko     // point.
13375e5f0aaSAlex Zinenko     auto *currentBlock = rewriter.getInsertionBlock();
13475e5f0aaSAlex Zinenko     auto *remainingOpsBlock =
13575e5f0aaSAlex Zinenko         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
13675e5f0aaSAlex Zinenko     Block *continueBlock;
13775e5f0aaSAlex Zinenko     if (allocaScopeOp.getNumResults() == 0) {
13875e5f0aaSAlex Zinenko       continueBlock = remainingOpsBlock;
13975e5f0aaSAlex Zinenko     } else {
140e084679fSRiver Riddle       continueBlock = rewriter.createBlock(
141e084679fSRiver Riddle           remainingOpsBlock, allocaScopeOp.getResultTypes(),
142e084679fSRiver Riddle           SmallVector<Location>(allocaScopeOp->getNumResults(),
143e084679fSRiver Riddle                                 allocaScopeOp.getLoc()));
14475e5f0aaSAlex Zinenko       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
14575e5f0aaSAlex Zinenko     }
14675e5f0aaSAlex Zinenko 
14775e5f0aaSAlex Zinenko     // Inline body region.
148136d746eSJacques Pienaar     Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
149136d746eSJacques Pienaar     Block *afterBody = &allocaScopeOp.getBodyRegion().back();
150136d746eSJacques Pienaar     rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
15175e5f0aaSAlex Zinenko 
15275e5f0aaSAlex Zinenko     // Save stack and then branch into the body of the region.
15375e5f0aaSAlex Zinenko     rewriter.setInsertionPointToEnd(currentBlock);
15475e5f0aaSAlex Zinenko     auto stackSaveOp =
15575e5f0aaSAlex Zinenko         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
15675e5f0aaSAlex Zinenko     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
15775e5f0aaSAlex Zinenko 
15875e5f0aaSAlex Zinenko     // Replace the alloca_scope return with a branch that jumps out of the body.
15975e5f0aaSAlex Zinenko     // Stack restore before leaving the body region.
16075e5f0aaSAlex Zinenko     rewriter.setInsertionPointToEnd(afterBody);
16175e5f0aaSAlex Zinenko     auto returnOp =
16275e5f0aaSAlex Zinenko         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
16375e5f0aaSAlex Zinenko     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
164136d746eSJacques Pienaar         returnOp, returnOp.getResults(), continueBlock);
16575e5f0aaSAlex Zinenko 
16675e5f0aaSAlex Zinenko     // Insert stack restore before jumping out the body of the region.
16775e5f0aaSAlex Zinenko     rewriter.setInsertionPoint(branchOp);
16875e5f0aaSAlex Zinenko     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
16975e5f0aaSAlex Zinenko 
17075e5f0aaSAlex Zinenko     // Replace the op with values return from the body region.
17175e5f0aaSAlex Zinenko     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
17275e5f0aaSAlex Zinenko 
17375e5f0aaSAlex Zinenko     return success();
17475e5f0aaSAlex Zinenko   }
17575e5f0aaSAlex Zinenko };
17675e5f0aaSAlex Zinenko 
17775e5f0aaSAlex Zinenko struct AssumeAlignmentOpLowering
17875e5f0aaSAlex Zinenko     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
17975e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<
18075e5f0aaSAlex Zinenko       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
181ce254598SMatthias Springer   explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
182e02d4142SQuentin Colombet       : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}
18375e5f0aaSAlex Zinenko 
18475e5f0aaSAlex Zinenko   LogicalResult
185ef976337SRiver Riddle   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
18675e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
187136d746eSJacques Pienaar     Value memref = adaptor.getMemref();
188136d746eSJacques Pienaar     unsigned alignment = op.getAlignment();
18975e5f0aaSAlex Zinenko     auto loc = op.getLoc();
19075e5f0aaSAlex Zinenko 
1915550c821STres Popp     auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
192e02d4142SQuentin Colombet     Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
193e02d4142SQuentin Colombet                                      rewriter);
19475e5f0aaSAlex Zinenko 
19533598068SKrzysztof Drewniak     // Emit llvm.assume(true) ["align"(memref, alignment)].
19633598068SKrzysztof Drewniak     // This is more direct than ptrtoint-based checks, is explicitly supported,
19733598068SKrzysztof Drewniak     // and works with non-integral address spaces.
19833598068SKrzysztof Drewniak     Value trueCond =
19933598068SKrzysztof Drewniak         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
20033598068SKrzysztof Drewniak     Value alignmentConst =
20133598068SKrzysztof Drewniak         createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
20233598068SKrzysztof Drewniak     rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
20333598068SKrzysztof Drewniak                                     alignmentConst);
20475e5f0aaSAlex Zinenko 
20575e5f0aaSAlex Zinenko     rewriter.eraseOp(op);
20675e5f0aaSAlex Zinenko     return success();
20775e5f0aaSAlex Zinenko   }
20875e5f0aaSAlex Zinenko };
20975e5f0aaSAlex Zinenko 
21075e5f0aaSAlex Zinenko // A `dealloc` is converted into a call to `free` on the underlying data buffer.
21175e5f0aaSAlex Zinenko // The memref descriptor being an SSA value, there is no need to clean it up
21275e5f0aaSAlex Zinenko // in any way.
21375e5f0aaSAlex Zinenko struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
21475e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
21575e5f0aaSAlex Zinenko 
216ce254598SMatthias Springer   explicit DeallocOpLowering(const LLVMTypeConverter &converter)
21775e5f0aaSAlex Zinenko       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
21875e5f0aaSAlex Zinenko 
21975e5f0aaSAlex Zinenko   LogicalResult
220ef976337SRiver Riddle   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
22175e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
22275e5f0aaSAlex Zinenko     // Insert the `free` declaration if it is not already present.
223*e84f6b6aSLuohao Wang     FailureOr<LLVM::LLVMFuncOp> freeFunc =
2249f13b934Sbixia1         getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
225*e84f6b6aSLuohao Wang     if (failed(freeFunc))
226*e84f6b6aSLuohao Wang       return failure();
227b58daf91SJohannes Reifferscheid     Value allocatedPtr;
228b58daf91SJohannes Reifferscheid     if (auto unrankedTy =
229b58daf91SJohannes Reifferscheid             llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
230b28a296cSChristian Ulmann       auto elementPtrTy = LLVM::LLVMPointerType::get(
231b28a296cSChristian Ulmann           rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
232b58daf91SJohannes Reifferscheid       allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
233b58daf91SJohannes Reifferscheid           rewriter, op.getLoc(),
234b58daf91SJohannes Reifferscheid           UnrankedMemRefDescriptor(adaptor.getMemref())
235b58daf91SJohannes Reifferscheid               .memRefDescPtr(rewriter, op.getLoc()),
236b58daf91SJohannes Reifferscheid           elementPtrTy);
237b58daf91SJohannes Reifferscheid     } else {
238b58daf91SJohannes Reifferscheid       allocatedPtr = MemRefDescriptor(adaptor.getMemref())
239b58daf91SJohannes Reifferscheid                          .allocatedPtr(rewriter, op.getLoc());
240b58daf91SJohannes Reifferscheid     }
241*e84f6b6aSLuohao Wang     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
242*e84f6b6aSLuohao Wang                                               allocatedPtr);
24375e5f0aaSAlex Zinenko     return success();
24475e5f0aaSAlex Zinenko   }
24575e5f0aaSAlex Zinenko };
24675e5f0aaSAlex Zinenko 
24775e5f0aaSAlex Zinenko // A `dim` is converted to a constant for static sizes and to an access to the
24875e5f0aaSAlex Zinenko // size stored in the memref descriptor for dynamic sizes.
24975e5f0aaSAlex Zinenko struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
25075e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
25175e5f0aaSAlex Zinenko 
25275e5f0aaSAlex Zinenko   LogicalResult
253ef976337SRiver Riddle   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
25475e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
255136d746eSJacques Pienaar     Type operandType = dimOp.getSource().getType();
2565550c821STres Popp     if (isa<UnrankedMemRefType>(operandType)) {
257499abb24SKrzysztof Drewniak       FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
258499abb24SKrzysztof Drewniak           operandType, dimOp, adaptor.getOperands(), rewriter);
259499abb24SKrzysztof Drewniak       if (failed(extractedSize))
260499abb24SKrzysztof Drewniak         return failure();
261499abb24SKrzysztof Drewniak       rewriter.replaceOp(dimOp, {*extractedSize});
26275e5f0aaSAlex Zinenko       return success();
26375e5f0aaSAlex Zinenko     }
2645550c821STres Popp     if (isa<MemRefType>(operandType)) {
265ef976337SRiver Riddle       rewriter.replaceOp(
266ef976337SRiver Riddle           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
267ef976337SRiver Riddle                                             adaptor.getOperands(), rewriter)});
26875e5f0aaSAlex Zinenko       return success();
26975e5f0aaSAlex Zinenko     }
27075e5f0aaSAlex Zinenko     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
27175e5f0aaSAlex Zinenko   }
27275e5f0aaSAlex Zinenko 
27375e5f0aaSAlex Zinenko private:
274499abb24SKrzysztof Drewniak   FailureOr<Value>
275499abb24SKrzysztof Drewniak   extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
276ef976337SRiver Riddle                               OpAdaptor adaptor,
27775e5f0aaSAlex Zinenko                               ConversionPatternRewriter &rewriter) const {
27875e5f0aaSAlex Zinenko     Location loc = dimOp.getLoc();
27975e5f0aaSAlex Zinenko 
2805550c821STres Popp     auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
28175e5f0aaSAlex Zinenko     auto scalarMemRefType =
28275e5f0aaSAlex Zinenko         MemRefType::get({}, unrankedMemRefType.getElementType());
283499abb24SKrzysztof Drewniak     FailureOr<unsigned> maybeAddressSpace =
284499abb24SKrzysztof Drewniak         getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
285499abb24SKrzysztof Drewniak     if (failed(maybeAddressSpace)) {
286499abb24SKrzysztof Drewniak       dimOp.emitOpError("memref memory space must be convertible to an integer "
287499abb24SKrzysztof Drewniak                         "address space");
288499abb24SKrzysztof Drewniak       return failure();
289499abb24SKrzysztof Drewniak     }
290499abb24SKrzysztof Drewniak     unsigned addressSpace = *maybeAddressSpace;
29175e5f0aaSAlex Zinenko 
29275e5f0aaSAlex Zinenko     // Extract pointer to the underlying ranked descriptor and bitcast it to a
29375e5f0aaSAlex Zinenko     // memref<element_type> descriptor pointer to minimize the number of GEP
29475e5f0aaSAlex Zinenko     // operations.
295136d746eSJacques Pienaar     UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
29675e5f0aaSAlex Zinenko     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
29750ea17b8SMarkus Böck 
29850ea17b8SMarkus Böck     Type elementType = typeConverter->convertType(scalarMemRefType);
29975e5f0aaSAlex Zinenko 
30075e5f0aaSAlex Zinenko     // Get pointer to offset field of memref<element_type> descriptor.
301b28a296cSChristian Ulmann     auto indexPtrTy =
302b28a296cSChristian Ulmann         LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
30375e5f0aaSAlex Zinenko     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
304b28a296cSChristian Ulmann         loc, indexPtrTy, elementType, underlyingRankedDesc,
30550ea17b8SMarkus Böck         ArrayRef<LLVM::GEPArg>{0, 2});
30675e5f0aaSAlex Zinenko 
30775e5f0aaSAlex Zinenko     // The size value that we have to extract can be obtained using GEPop with
30875e5f0aaSAlex Zinenko     // `dimOp.index() + 1` index argument.
30975e5f0aaSAlex Zinenko     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
310e98e5995SAlex Zinenko         loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
311620e2bb2SNicolas Vasilache         adaptor.getIndex());
31250ea17b8SMarkus Böck     Value sizePtr = rewriter.create<LLVM::GEPOp>(
31350ea17b8SMarkus Böck         loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
31450ea17b8SMarkus Böck         idxPlusOne);
315499abb24SKrzysztof Drewniak     return rewriter
316499abb24SKrzysztof Drewniak         .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
317499abb24SKrzysztof Drewniak         .getResult();
31875e5f0aaSAlex Zinenko   }
31975e5f0aaSAlex Zinenko 
32022426110SRamkumar Ramachandra   std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
32122426110SRamkumar Ramachandra     if (auto idx = dimOp.getConstantIndex())
32275e5f0aaSAlex Zinenko       return idx;
32375e5f0aaSAlex Zinenko 
324136d746eSJacques Pienaar     if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
3255550c821STres Popp       return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
32675e5f0aaSAlex Zinenko 
3271a36588eSKazu Hirata     return std::nullopt;
32875e5f0aaSAlex Zinenko   }
32975e5f0aaSAlex Zinenko 
33075e5f0aaSAlex Zinenko   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
331ef976337SRiver Riddle                                   OpAdaptor adaptor,
33275e5f0aaSAlex Zinenko                                   ConversionPatternRewriter &rewriter) const {
33375e5f0aaSAlex Zinenko     Location loc = dimOp.getLoc();
334ef976337SRiver Riddle 
33575e5f0aaSAlex Zinenko     // Take advantage if index is constant.
3365550c821STres Popp     MemRefType memRefType = cast<MemRefType>(operandType);
337e98e5995SAlex Zinenko     Type indexType = getIndexType();
33822426110SRamkumar Ramachandra     if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
3396d5fc1e3SKazu Hirata       int64_t i = *index;
3404bc2357cSQuentin Colombet       if (i >= 0 && i < memRefType.getRank()) {
34175e5f0aaSAlex Zinenko         if (memRefType.isDynamicDim(i)) {
34275e5f0aaSAlex Zinenko           // extract dynamic size from the memref descriptor.
343136d746eSJacques Pienaar           MemRefDescriptor descriptor(adaptor.getSource());
34475e5f0aaSAlex Zinenko           return descriptor.size(rewriter, loc, i);
34575e5f0aaSAlex Zinenko         }
34675e5f0aaSAlex Zinenko         // Use constant for static size.
34775e5f0aaSAlex Zinenko         int64_t dimSize = memRefType.getDimSize(i);
348620e2bb2SNicolas Vasilache         return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
34975e5f0aaSAlex Zinenko       }
3504bc2357cSQuentin Colombet     }
351136d746eSJacques Pienaar     Value index = adaptor.getIndex();
35275e5f0aaSAlex Zinenko     int64_t rank = memRefType.getRank();
353136d746eSJacques Pienaar     MemRefDescriptor memrefDescriptor(adaptor.getSource());
35475e5f0aaSAlex Zinenko     return memrefDescriptor.size(rewriter, loc, index, rank);
35575e5f0aaSAlex Zinenko   }
35675e5f0aaSAlex Zinenko };
35775e5f0aaSAlex Zinenko 
358632a4f88SRiver Riddle /// Common base for load and store operations on MemRefs. Restricts the match
359632a4f88SRiver Riddle /// to supported MemRef types. Provides functionality to emit code accessing a
360632a4f88SRiver Riddle /// specific element of the underlying data buffer.
361632a4f88SRiver Riddle template <typename Derived>
362632a4f88SRiver Riddle struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
363632a4f88SRiver Riddle   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
364632a4f88SRiver Riddle   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
365632a4f88SRiver Riddle   using Base = LoadStoreOpLowering<Derived>;
366632a4f88SRiver Riddle 
367632a4f88SRiver Riddle   LogicalResult match(Derived op) const override {
368632a4f88SRiver Riddle     MemRefType type = op.getMemRefType();
369632a4f88SRiver Riddle     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
370632a4f88SRiver Riddle   }
371632a4f88SRiver Riddle };
372632a4f88SRiver Riddle 
373632a4f88SRiver Riddle /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
374632a4f88SRiver Riddle /// retried until it succeeds in atomically storing a new value into memory.
375632a4f88SRiver Riddle ///
376632a4f88SRiver Riddle ///      +---------------------------------+
377632a4f88SRiver Riddle ///      |   <code before the AtomicRMWOp> |
378632a4f88SRiver Riddle ///      |   <compute initial %loaded>     |
379ace01605SRiver Riddle ///      |   cf.br loop(%loaded)              |
380632a4f88SRiver Riddle ///      +---------------------------------+
381632a4f88SRiver Riddle ///             |
382632a4f88SRiver Riddle ///  -------|   |
383632a4f88SRiver Riddle ///  |      v   v
384632a4f88SRiver Riddle ///  |   +--------------------------------+
385632a4f88SRiver Riddle ///  |   | loop(%loaded):                 |
386632a4f88SRiver Riddle ///  |   |   <body contents>              |
387632a4f88SRiver Riddle ///  |   |   %pair = cmpxchg              |
388632a4f88SRiver Riddle ///  |   |   %ok = %pair[0]               |
389632a4f88SRiver Riddle ///  |   |   %new = %pair[1]              |
390ace01605SRiver Riddle ///  |   |   cf.cond_br %ok, end, loop(%new) |
391632a4f88SRiver Riddle ///  |   +--------------------------------+
392632a4f88SRiver Riddle ///  |          |        |
393632a4f88SRiver Riddle ///  |-----------        |
394632a4f88SRiver Riddle ///                      v
395632a4f88SRiver Riddle ///      +--------------------------------+
396632a4f88SRiver Riddle ///      | end:                           |
397632a4f88SRiver Riddle ///      |   <code after the AtomicRMWOp> |
398632a4f88SRiver Riddle ///      +--------------------------------+
399632a4f88SRiver Riddle ///
400632a4f88SRiver Riddle struct GenericAtomicRMWOpLowering
401632a4f88SRiver Riddle     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
402632a4f88SRiver Riddle   using Base::Base;
403632a4f88SRiver Riddle 
404632a4f88SRiver Riddle   LogicalResult
405632a4f88SRiver Riddle   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
406632a4f88SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
407632a4f88SRiver Riddle     auto loc = atomicOp.getLoc();
408632a4f88SRiver Riddle     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
409632a4f88SRiver Riddle 
410632a4f88SRiver Riddle     // Split the block into initial, loop, and ending parts.
411632a4f88SRiver Riddle     auto *initBlock = rewriter.getInsertionBlock();
412e7833c20SAlexander Belyaev     auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
413e7833c20SAlexander Belyaev     loopBlock->addArgument(valueType, loc);
414632a4f88SRiver Riddle 
415e7833c20SAlexander Belyaev     auto *endBlock =
416e7833c20SAlexander Belyaev         rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);
417632a4f88SRiver Riddle 
418632a4f88SRiver Riddle     // Compute the loaded value and branch to the loop block.
419632a4f88SRiver Riddle     rewriter.setInsertionPointToEnd(initBlock);
4205550c821STres Popp     auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
421136d746eSJacques Pienaar     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
422136d746eSJacques Pienaar                                         adaptor.getIndices(), rewriter);
42350ea17b8SMarkus Böck     Value init = rewriter.create<LLVM::LoadOp>(
42450ea17b8SMarkus Böck         loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
425632a4f88SRiver Riddle     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
426632a4f88SRiver Riddle 
427632a4f88SRiver Riddle     // Prepare the body of the loop block.
428632a4f88SRiver Riddle     rewriter.setInsertionPointToStart(loopBlock);
429632a4f88SRiver Riddle 
430632a4f88SRiver Riddle     // Clone the GenericAtomicRMWOp region and extract the result.
431632a4f88SRiver Riddle     auto loopArgument = loopBlock->getArgument(0);
4324d67b278SJeff Niu     IRMapping mapping;
433632a4f88SRiver Riddle     mapping.map(atomicOp.getCurrentValue(), loopArgument);
434632a4f88SRiver Riddle     Block &entryBlock = atomicOp.body().front();
435632a4f88SRiver Riddle     for (auto &nestedOp : entryBlock.without_terminator()) {
436632a4f88SRiver Riddle       Operation *clone = rewriter.clone(nestedOp, mapping);
437632a4f88SRiver Riddle       mapping.map(nestedOp.getResults(), clone->getResults());
438632a4f88SRiver Riddle     }
439632a4f88SRiver Riddle     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
440632a4f88SRiver Riddle 
441632a4f88SRiver Riddle     // Prepare the epilog of the loop block.
442632a4f88SRiver Riddle     // Append the cmpxchg op to the end of the loop block.
443632a4f88SRiver Riddle     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
444632a4f88SRiver Riddle     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
445632a4f88SRiver Riddle     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
4467f97895fSTobias Gysi         loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
447632a4f88SRiver Riddle     // Extract the %new_loaded and %ok values from the pair.
4485c5af910SJeff Niu     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
4495c5af910SJeff Niu     Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
450632a4f88SRiver Riddle 
451632a4f88SRiver Riddle     // Conditionally branch to the end or back to the loop depending on %ok.
452632a4f88SRiver Riddle     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
453632a4f88SRiver Riddle                                     loopBlock, newLoaded);
454632a4f88SRiver Riddle 
455632a4f88SRiver Riddle     rewriter.setInsertionPointToEnd(endBlock);
456632a4f88SRiver Riddle 
457632a4f88SRiver Riddle     // The 'result' of the atomic_rmw op is the newly loaded value.
458632a4f88SRiver Riddle     rewriter.replaceOp(atomicOp, {newLoaded});
459632a4f88SRiver Riddle 
460632a4f88SRiver Riddle     return success();
461632a4f88SRiver Riddle   }
462632a4f88SRiver Riddle };
463632a4f88SRiver Riddle 
46475e5f0aaSAlex Zinenko /// Returns the LLVM type of the global variable given the memref type `type`.
465ce254598SMatthias Springer static Type
466ce254598SMatthias Springer convertGlobalMemrefTypeToLLVM(MemRefType type,
467ce254598SMatthias Springer                               const LLVMTypeConverter &typeConverter) {
46875e5f0aaSAlex Zinenko   // LLVM type for a global memref will be a multi-dimension array. For
46975e5f0aaSAlex Zinenko   // declarations or uninitialized global memrefs, we can potentially flatten
47075e5f0aaSAlex Zinenko   // this to a 1D array. However, for memref.global's with an initial value,
47175e5f0aaSAlex Zinenko   // we do not intend to flatten the ElementsAttribute when going from std ->
47275e5f0aaSAlex Zinenko   // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
47375e5f0aaSAlex Zinenko   Type elementType = typeConverter.convertType(type.getElementType());
47475e5f0aaSAlex Zinenko   Type arrayTy = elementType;
47575e5f0aaSAlex Zinenko   // Shape has the outermost dim at index 0, so need to walk it backwards
47675e5f0aaSAlex Zinenko   for (int64_t dim : llvm::reverse(type.getShape()))
47775e5f0aaSAlex Zinenko     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
47875e5f0aaSAlex Zinenko   return arrayTy;
47975e5f0aaSAlex Zinenko }
48075e5f0aaSAlex Zinenko 
48175e5f0aaSAlex Zinenko /// GlobalMemrefOp is lowered to a LLVM Global Variable.
48275e5f0aaSAlex Zinenko struct GlobalMemrefOpLowering
48375e5f0aaSAlex Zinenko     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
48475e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
48575e5f0aaSAlex Zinenko 
48675e5f0aaSAlex Zinenko   LogicalResult
487ef976337SRiver Riddle   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
48875e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
489136d746eSJacques Pienaar     MemRefType type = global.getType();
49075e5f0aaSAlex Zinenko     if (!isConvertibleAndHasIdentityMaps(type))
49175e5f0aaSAlex Zinenko       return failure();
49275e5f0aaSAlex Zinenko 
49375e5f0aaSAlex Zinenko     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
49475e5f0aaSAlex Zinenko 
49575e5f0aaSAlex Zinenko     LLVM::Linkage linkage =
49675e5f0aaSAlex Zinenko         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
49775e5f0aaSAlex Zinenko 
49875e5f0aaSAlex Zinenko     Attribute initialValue = nullptr;
49975e5f0aaSAlex Zinenko     if (!global.isExternal() && !global.isUninitialized()) {
50068f58812STres Popp       auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
50175e5f0aaSAlex Zinenko       initialValue = elementsAttr;
50275e5f0aaSAlex Zinenko 
50375e5f0aaSAlex Zinenko       // For scalar memrefs, the global variable created is of the element type,
50475e5f0aaSAlex Zinenko       // so unpack the elements attribute to extract the value.
50575e5f0aaSAlex Zinenko       if (type.getRank() == 0)
506937e40a8SRiver Riddle         initialValue = elementsAttr.getSplatValue<Attribute>();
50775e5f0aaSAlex Zinenko     }
50875e5f0aaSAlex Zinenko 
509136d746eSJacques Pienaar     uint64_t alignment = global.getAlignment().value_or(0);
510499abb24SKrzysztof Drewniak     FailureOr<unsigned> addressSpace =
511499abb24SKrzysztof Drewniak         getTypeConverter()->getMemRefAddressSpace(type);
512499abb24SKrzysztof Drewniak     if (failed(addressSpace))
513499abb24SKrzysztof Drewniak       return global.emitOpError(
514499abb24SKrzysztof Drewniak           "memory space cannot be converted to an integer address space");
5158c2ff7b6SWilliam S. Moses     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
516136d746eSJacques Pienaar         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
517499abb24SKrzysztof Drewniak         initialValue, alignment, *addressSpace);
5188c2ff7b6SWilliam S. Moses     if (!global.isExternal() && global.isUninitialized()) {
51991d5653eSMatthias Springer       rewriter.createBlock(&newGlobal.getInitializerRegion());
5208c2ff7b6SWilliam S. Moses       Value undef[] = {
5218c2ff7b6SWilliam S. Moses           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
5228c2ff7b6SWilliam S. Moses       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
5238c2ff7b6SWilliam S. Moses     }
52475e5f0aaSAlex Zinenko     return success();
52575e5f0aaSAlex Zinenko   }
52675e5f0aaSAlex Zinenko };
52775e5f0aaSAlex Zinenko 
52875e5f0aaSAlex Zinenko /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
52975e5f0aaSAlex Zinenko /// the first element stashed into the descriptor. This reuses
53075e5f0aaSAlex Zinenko /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
53175e5f0aaSAlex Zinenko struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
532ce254598SMatthias Springer   GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
53375e5f0aaSAlex Zinenko       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
53475e5f0aaSAlex Zinenko                                 converter) {}
53575e5f0aaSAlex Zinenko 
53675e5f0aaSAlex Zinenko   /// Buffer "allocation" for memref.get_global op is getting the address of
53775e5f0aaSAlex Zinenko   /// the global variable referenced.
53875e5f0aaSAlex Zinenko   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
53975e5f0aaSAlex Zinenko                                           Location loc, Value sizeBytes,
54075e5f0aaSAlex Zinenko                                           Operation *op) const override {
54175e5f0aaSAlex Zinenko     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
5425550c821STres Popp     MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
543499abb24SKrzysztof Drewniak 
544499abb24SKrzysztof Drewniak     // This is called after a type conversion, which would have failed if this
545499abb24SKrzysztof Drewniak     // call fails.
54673c6248cSKrzysztof Drewniak     FailureOr<unsigned> maybeAddressSpace =
547620e2bb2SNicolas Vasilache         getTypeConverter()->getMemRefAddressSpace(type);
54873c6248cSKrzysztof Drewniak     if (failed(maybeAddressSpace))
549620e2bb2SNicolas Vasilache       return std::make_tuple(Value(), Value());
550620e2bb2SNicolas Vasilache     unsigned memSpace = *maybeAddressSpace;
55175e5f0aaSAlex Zinenko 
55275e5f0aaSAlex Zinenko     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
553b28a296cSChristian Ulmann     auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
55450ea17b8SMarkus Böck     auto addressOf =
555b28a296cSChristian Ulmann         rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());
55675e5f0aaSAlex Zinenko 
55775e5f0aaSAlex Zinenko     // Get the address of the first element in the array by creating a GEP with
55875e5f0aaSAlex Zinenko     // the address of the GV as the base, and (rank + 1) number of 0 indices.
559bd7eff1fSMarkus Böck     auto gep = rewriter.create<LLVM::GEPOp>(
560b28a296cSChristian Ulmann         loc, ptrTy, arrayTy, addressOf,
561bd7eff1fSMarkus Böck         SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
56275e5f0aaSAlex Zinenko 
56375e5f0aaSAlex Zinenko     // We do not expect the memref obtained using `memref.get_global` to be
56475e5f0aaSAlex Zinenko     // ever deallocated. Set the allocated pointer to be known bad value to
56575e5f0aaSAlex Zinenko     // help debug if that ever happens.
56675e5f0aaSAlex Zinenko     auto intPtrType = getIntPtrType(memSpace);
56775e5f0aaSAlex Zinenko     Value deadBeefConst =
56875e5f0aaSAlex Zinenko         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
56975e5f0aaSAlex Zinenko     auto deadBeefPtr =
570b28a296cSChristian Ulmann         rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
57175e5f0aaSAlex Zinenko 
57275e5f0aaSAlex Zinenko     // Both allocated and aligned pointers are same. We could potentially stash
57375e5f0aaSAlex Zinenko     // a nullptr for the allocated pointer since we do not expect any dealloc.
57475e5f0aaSAlex Zinenko     return std::make_tuple(deadBeefPtr, gep);
57575e5f0aaSAlex Zinenko   }
57675e5f0aaSAlex Zinenko };
57775e5f0aaSAlex Zinenko 
57875e5f0aaSAlex Zinenko // Load operation is lowered to obtaining a pointer to the indexed element
57975e5f0aaSAlex Zinenko // and loading it.
58075e5f0aaSAlex Zinenko struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
58175e5f0aaSAlex Zinenko   using Base::Base;
58275e5f0aaSAlex Zinenko 
58375e5f0aaSAlex Zinenko   LogicalResult
584ef976337SRiver Riddle   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
58575e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
58675e5f0aaSAlex Zinenko     auto type = loadOp.getMemRefType();
58775e5f0aaSAlex Zinenko 
588136d746eSJacques Pienaar     Value dataPtr =
589136d746eSJacques Pienaar         getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
590136d746eSJacques Pienaar                              adaptor.getIndices(), rewriter);
59150ea17b8SMarkus Böck     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
59250ea17b8SMarkus Böck         loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
59350ea17b8SMarkus Böck         false, loadOp.getNontemporal());
59475e5f0aaSAlex Zinenko     return success();
59575e5f0aaSAlex Zinenko   }
59675e5f0aaSAlex Zinenko };
59775e5f0aaSAlex Zinenko 
59875e5f0aaSAlex Zinenko // Store operation is lowered to obtaining a pointer to the indexed element,
59975e5f0aaSAlex Zinenko // and storing the given value to it.
60075e5f0aaSAlex Zinenko struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
60175e5f0aaSAlex Zinenko   using Base::Base;
60275e5f0aaSAlex Zinenko 
60375e5f0aaSAlex Zinenko   LogicalResult
604ef976337SRiver Riddle   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
60575e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
60675e5f0aaSAlex Zinenko     auto type = op.getMemRefType();
60775e5f0aaSAlex Zinenko 
608136d746eSJacques Pienaar     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
609136d746eSJacques Pienaar                                          adaptor.getIndices(), rewriter);
6101cb91b42SGuray Ozen     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
6111cb91b42SGuray Ozen                                                0, false, op.getNontemporal());
61275e5f0aaSAlex Zinenko     return success();
61375e5f0aaSAlex Zinenko   }
61475e5f0aaSAlex Zinenko };
61575e5f0aaSAlex Zinenko 
61675e5f0aaSAlex Zinenko // The prefetch operation is lowered in a way similar to the load operation
61775e5f0aaSAlex Zinenko // except that the llvm.prefetch operation is used for replacement.
61875e5f0aaSAlex Zinenko struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
61975e5f0aaSAlex Zinenko   using Base::Base;
62075e5f0aaSAlex Zinenko 
62175e5f0aaSAlex Zinenko   LogicalResult
622ef976337SRiver Riddle   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
62375e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
62475e5f0aaSAlex Zinenko     auto type = prefetchOp.getMemRefType();
62575e5f0aaSAlex Zinenko     auto loc = prefetchOp.getLoc();
62675e5f0aaSAlex Zinenko 
627136d746eSJacques Pienaar     Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
628136d746eSJacques Pienaar                                          adaptor.getIndices(), rewriter);
62975e5f0aaSAlex Zinenko 
63075e5f0aaSAlex Zinenko     // Replace with llvm.prefetch.
63148b126e3SChristian Ulmann     IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
63248b126e3SChristian Ulmann     IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
63348b126e3SChristian Ulmann     IntegerAttr isData =
63448b126e3SChristian Ulmann         rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
63575e5f0aaSAlex Zinenko     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
63675e5f0aaSAlex Zinenko                                                 localityHint, isData);
63775e5f0aaSAlex Zinenko     return success();
63875e5f0aaSAlex Zinenko   }
63975e5f0aaSAlex Zinenko };
64075e5f0aaSAlex Zinenko 
64115f8f3e2SAlexander Belyaev struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
64215f8f3e2SAlexander Belyaev   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
64315f8f3e2SAlexander Belyaev 
64415f8f3e2SAlexander Belyaev   LogicalResult
64515f8f3e2SAlexander Belyaev   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
64615f8f3e2SAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
64715f8f3e2SAlexander Belyaev     Location loc = op.getLoc();
648136d746eSJacques Pienaar     Type operandType = op.getMemref().getType();
6490a0aff2dSMikhail Goncharov     if (dyn_cast<UnrankedMemRefType>(operandType)) {
650136d746eSJacques Pienaar       UnrankedMemRefDescriptor desc(adaptor.getMemref());
65115f8f3e2SAlexander Belyaev       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
65215f8f3e2SAlexander Belyaev       return success();
65315f8f3e2SAlexander Belyaev     }
6545550c821STres Popp     if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
655620e2bb2SNicolas Vasilache       Type indexType = getIndexType();
656620e2bb2SNicolas Vasilache       rewriter.replaceOp(op,
657620e2bb2SNicolas Vasilache                          {createIndexAttrConstant(rewriter, loc, indexType,
658620e2bb2SNicolas Vasilache                                                   rankedMemRefType.getRank())});
65915f8f3e2SAlexander Belyaev       return success();
66015f8f3e2SAlexander Belyaev     }
66115f8f3e2SAlexander Belyaev     return failure();
66215f8f3e2SAlexander Belyaev   }
66315f8f3e2SAlexander Belyaev };
66415f8f3e2SAlexander Belyaev 
66575e5f0aaSAlex Zinenko struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
66675e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
66775e5f0aaSAlex Zinenko 
66875e5f0aaSAlex Zinenko   LogicalResult match(memref::CastOp memRefCastOp) const override {
66975e5f0aaSAlex Zinenko     Type srcType = memRefCastOp.getOperand().getType();
67075e5f0aaSAlex Zinenko     Type dstType = memRefCastOp.getType();
67175e5f0aaSAlex Zinenko 
67275e5f0aaSAlex Zinenko     // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
67375e5f0aaSAlex Zinenko     // used for type erasure. For now they must preserve underlying element type
67475e5f0aaSAlex Zinenko     // and require source and result type to have the same rank. Therefore,
67575e5f0aaSAlex Zinenko     // perform a sanity check that the underlying structs are the same. Once op
67675e5f0aaSAlex Zinenko     // semantics are relaxed we can revisit.
6775550c821STres Popp     if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
67875e5f0aaSAlex Zinenko       return success(typeConverter->convertType(srcType) ==
67975e5f0aaSAlex Zinenko                      typeConverter->convertType(dstType));
68075e5f0aaSAlex Zinenko 
68175e5f0aaSAlex Zinenko     // At least one of the operands is unranked type
6825550c821STres Popp     assert(isa<UnrankedMemRefType>(srcType) ||
6835550c821STres Popp            isa<UnrankedMemRefType>(dstType));
68475e5f0aaSAlex Zinenko 
68575e5f0aaSAlex Zinenko     // Unranked to unranked cast is disallowed
6865550c821STres Popp     return !(isa<UnrankedMemRefType>(srcType) &&
6875550c821STres Popp              isa<UnrankedMemRefType>(dstType))
68875e5f0aaSAlex Zinenko                ? success()
68975e5f0aaSAlex Zinenko                : failure();
69075e5f0aaSAlex Zinenko   }
69175e5f0aaSAlex Zinenko 
692ef976337SRiver Riddle   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
69375e5f0aaSAlex Zinenko                ConversionPatternRewriter &rewriter) const override {
69475e5f0aaSAlex Zinenko     auto srcType = memRefCastOp.getOperand().getType();
69575e5f0aaSAlex Zinenko     auto dstType = memRefCastOp.getType();
69675e5f0aaSAlex Zinenko     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
69775e5f0aaSAlex Zinenko     auto loc = memRefCastOp.getLoc();
69875e5f0aaSAlex Zinenko 
69975e5f0aaSAlex Zinenko     // For ranked/ranked case, just keep the original descriptor.
7005550c821STres Popp     if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
701136d746eSJacques Pienaar       return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
70275e5f0aaSAlex Zinenko 
7035550c821STres Popp     if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
70475e5f0aaSAlex Zinenko       // Casting ranked to unranked memref type
70575e5f0aaSAlex Zinenko       // Set the rank in the destination from the memref type
70675e5f0aaSAlex Zinenko       // Allocate space on the stack and copy the src memref descriptor
70775e5f0aaSAlex Zinenko       // Set the ptr in the destination to the stack space
7085550c821STres Popp       auto srcMemRefType = cast<MemRefType>(srcType);
70975e5f0aaSAlex Zinenko       int64_t rank = srcMemRefType.getRank();
71075e5f0aaSAlex Zinenko       // ptr = AllocaOp sizeof(MemRefDescriptor)
71175e5f0aaSAlex Zinenko       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
712136d746eSJacques Pienaar           loc, adaptor.getSource(), rewriter);
71350ea17b8SMarkus Böck 
71475e5f0aaSAlex Zinenko       // rank = ConstantOp srcRank
71575e5f0aaSAlex Zinenko       auto rankVal = rewriter.create<LLVM::ConstantOp>(
7165b02a480SAdrian Kuegel           loc, getIndexType(), rewriter.getIndexAttr(rank));
71775e5f0aaSAlex Zinenko       // undef = UndefOp
71875e5f0aaSAlex Zinenko       UnrankedMemRefDescriptor memRefDesc =
71975e5f0aaSAlex Zinenko           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
72075e5f0aaSAlex Zinenko       // d1 = InsertValueOp undef, rank, 0
72175e5f0aaSAlex Zinenko       memRefDesc.setRank(rewriter, loc, rankVal);
722b28a296cSChristian Ulmann       // d2 = InsertValueOp d1, ptr, 1
723b28a296cSChristian Ulmann       memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
72475e5f0aaSAlex Zinenko       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
72575e5f0aaSAlex Zinenko 
7265550c821STres Popp     } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
72775e5f0aaSAlex Zinenko       // Casting from unranked type to ranked.
72875e5f0aaSAlex Zinenko       // The operation is assumed to be doing a correct cast. If the destination
72975e5f0aaSAlex Zinenko       // type mismatches the unranked the type, it is undefined behavior.
730136d746eSJacques Pienaar       UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
73175e5f0aaSAlex Zinenko       // ptr = ExtractValueOp src, 1
73275e5f0aaSAlex Zinenko       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
73350ea17b8SMarkus Böck 
734b28a296cSChristian Ulmann       // struct = LoadOp ptr
735b28a296cSChristian Ulmann       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
73675e5f0aaSAlex Zinenko       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
73775e5f0aaSAlex Zinenko     } else {
73875e5f0aaSAlex Zinenko       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
73975e5f0aaSAlex Zinenko     }
74075e5f0aaSAlex Zinenko   }
74175e5f0aaSAlex Zinenko };
74275e5f0aaSAlex Zinenko 
743ab95ba70SStephan Herhut /// Pattern to lower a `memref.copy` to llvm.
744ab95ba70SStephan Herhut ///
745ab95ba70SStephan Herhut /// For memrefs with identity layouts, the copy is lowered to the llvm
746ab95ba70SStephan Herhut /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
747ab95ba70SStephan Herhut /// to the generic `MemrefCopyFn`.
74875e5f0aaSAlex Zinenko struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
74975e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
75075e5f0aaSAlex Zinenko 
75175e5f0aaSAlex Zinenko   LogicalResult
752ab95ba70SStephan Herhut   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
753ab95ba70SStephan Herhut                           ConversionPatternRewriter &rewriter) const {
754ab95ba70SStephan Herhut     auto loc = op.getLoc();
7555550c821STres Popp     auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
756ab95ba70SStephan Herhut 
757136d746eSJacques Pienaar     MemRefDescriptor srcDesc(adaptor.getSource());
758ab95ba70SStephan Herhut 
759ab95ba70SStephan Herhut     // Compute number of elements.
760aa3cabe3SStephan Herhut     Value numElements = rewriter.create<LLVM::ConstantOp>(
761aa3cabe3SStephan Herhut         loc, getIndexType(), rewriter.getIndexAttr(1));
762ab95ba70SStephan Herhut     for (int pos = 0; pos < srcType.getRank(); ++pos) {
763ab95ba70SStephan Herhut       auto size = srcDesc.size(rewriter, loc, pos);
764aa3cabe3SStephan Herhut       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
765ab95ba70SStephan Herhut     }
766aa3cabe3SStephan Herhut 
767ab95ba70SStephan Herhut     // Get element size.
768ab95ba70SStephan Herhut     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
769ab95ba70SStephan Herhut     // Compute total.
770ab95ba70SStephan Herhut     Value totalSize =
771ab95ba70SStephan Herhut         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
772ab95ba70SStephan Herhut 
77350ea17b8SMarkus Böck     Type elementType = typeConverter->convertType(srcType.getElementType());
77450ea17b8SMarkus Böck 
775ab95ba70SStephan Herhut     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
77627cd2a62SBenjamin Kramer     Value srcOffset = srcDesc.offset(rewriter, loc);
77750ea17b8SMarkus Böck     Value srcPtr = rewriter.create<LLVM::GEPOp>(
77850ea17b8SMarkus Böck         loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
779136d746eSJacques Pienaar     MemRefDescriptor targetDesc(adaptor.getTarget());
780ab95ba70SStephan Herhut     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
78127cd2a62SBenjamin Kramer     Value targetOffset = targetDesc.offset(rewriter, loc);
78250ea17b8SMarkus Böck     Value targetPtr = rewriter.create<LLVM::GEPOp>(
78350ea17b8SMarkus Böck         loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
78427cd2a62SBenjamin Kramer     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
78548b126e3SChristian Ulmann                                     /*isVolatile=*/false);
786ab95ba70SStephan Herhut     rewriter.eraseOp(op);
787ab95ba70SStephan Herhut 
788ab95ba70SStephan Herhut     return success();
789ab95ba70SStephan Herhut   }
790ab95ba70SStephan Herhut 
791ab95ba70SStephan Herhut   LogicalResult
792ab95ba70SStephan Herhut   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
793ab95ba70SStephan Herhut                              ConversionPatternRewriter &rewriter) const {
79475e5f0aaSAlex Zinenko     auto loc = op.getLoc();
7955550c821STres Popp     auto srcType = cast<BaseMemRefType>(op.getSource().getType());
7965550c821STres Popp     auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
79775e5f0aaSAlex Zinenko 
79875e5f0aaSAlex Zinenko     // First make sure we have an unranked memref descriptor representation.
799eb7f3557SMatthias Springer     auto makeUnranked = [&, this](Value ranked, MemRefType type) {
8000af643f3SJeff Niu       auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
8010af643f3SJeff Niu                                                     type.getRank());
80275e5f0aaSAlex Zinenko       auto *typeConverter = getTypeConverter();
80375e5f0aaSAlex Zinenko       auto ptr =
80475e5f0aaSAlex Zinenko           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
80550ea17b8SMarkus Böck 
80675e5f0aaSAlex Zinenko       auto unrankedType =
80775e5f0aaSAlex Zinenko           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
808b28a296cSChristian Ulmann       return UnrankedMemRefDescriptor::pack(
809b28a296cSChristian Ulmann           rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
81075e5f0aaSAlex Zinenko     };
81175e5f0aaSAlex Zinenko 
812f76e40d1SAndi Drebes     // Save stack position before promoting descriptors
813f76e40d1SAndi Drebes     auto stackSaveOp =
814f76e40d1SAndi Drebes         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
815f76e40d1SAndi Drebes 
8165550c821STres Popp     auto srcMemRefType = dyn_cast<MemRefType>(srcType);
817eb7f3557SMatthias Springer     Value unrankedSource =
818eb7f3557SMatthias Springer         srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
819136d746eSJacques Pienaar                       : adaptor.getSource();
8205550c821STres Popp     auto targetMemRefType = dyn_cast<MemRefType>(targetType);
821eb7f3557SMatthias Springer     Value unrankedTarget =
822eb7f3557SMatthias Springer         targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
823136d746eSJacques Pienaar                          : adaptor.getTarget();
82475e5f0aaSAlex Zinenko 
82575e5f0aaSAlex Zinenko     // Now promote the unranked descriptors to the stack.
82675e5f0aaSAlex Zinenko     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
82775e5f0aaSAlex Zinenko                                                  rewriter.getIndexAttr(1));
82875e5f0aaSAlex Zinenko     auto promote = [&](Value desc) {
829b28a296cSChristian Ulmann       auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
83075e5f0aaSAlex Zinenko       auto allocated =
83150ea17b8SMarkus Böck           rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
83275e5f0aaSAlex Zinenko       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
83375e5f0aaSAlex Zinenko       return allocated;
83475e5f0aaSAlex Zinenko     };
83575e5f0aaSAlex Zinenko 
83675e5f0aaSAlex Zinenko     auto sourcePtr = promote(unrankedSource);
83775e5f0aaSAlex Zinenko     auto targetPtr = promote(unrankedTarget);
83875e5f0aaSAlex Zinenko 
839c336a061SFelix Schneider     // Derive size from llvm.getelementptr which will account for any
840c336a061SFelix Schneider     // potential alignment
841c336a061SFelix Schneider     auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
84275e5f0aaSAlex Zinenko     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
84375e5f0aaSAlex Zinenko         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
844*e84f6b6aSLuohao Wang     if (failed(copyFn))
845*e84f6b6aSLuohao Wang       return failure();
846*e84f6b6aSLuohao Wang     rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
84775e5f0aaSAlex Zinenko                                   ValueRange{elemSize, sourcePtr, targetPtr});
848f76e40d1SAndi Drebes 
849f76e40d1SAndi Drebes     // Restore stack used for descriptors
850f76e40d1SAndi Drebes     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
851f76e40d1SAndi Drebes 
85275e5f0aaSAlex Zinenko     rewriter.eraseOp(op);
85375e5f0aaSAlex Zinenko 
85475e5f0aaSAlex Zinenko     return success();
85575e5f0aaSAlex Zinenko   }
856ab95ba70SStephan Herhut 
857ab95ba70SStephan Herhut   LogicalResult
858ab95ba70SStephan Herhut   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
859ab95ba70SStephan Herhut                   ConversionPatternRewriter &rewriter) const override {
8605550c821STres Popp     auto srcType = cast<BaseMemRefType>(op.getSource().getType());
8615550c821STres Popp     auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
862ab95ba70SStephan Herhut 
86346b90a7bSAlex Zinenko     auto isContiguousMemrefType = [&](BaseMemRefType type) {
8645550c821STres Popp       auto memrefType = dyn_cast<mlir::MemRefType>(type);
86527cd2a62SBenjamin Kramer       // We can use memcpy for memrefs if they have an identity layout or are
86627cd2a62SBenjamin Kramer       // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
86727cd2a62SBenjamin Kramer       // special case handled by memrefCopy.
86827cd2a62SBenjamin Kramer       return memrefType &&
86927cd2a62SBenjamin Kramer              (memrefType.getLayout().isIdentity() ||
87027cd2a62SBenjamin Kramer               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
871b4d6aadaSOleg Shyshkov                memref::isStaticShapeAndContiguousRowMajor(memrefType)));
87227cd2a62SBenjamin Kramer     };
87327cd2a62SBenjamin Kramer 
87427cd2a62SBenjamin Kramer     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
875ab95ba70SStephan Herhut       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
876ab95ba70SStephan Herhut 
877ab95ba70SStephan Herhut     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
878ab95ba70SStephan Herhut   }
87975e5f0aaSAlex Zinenko };
88075e5f0aaSAlex Zinenko 
8817fb9bbe5SKrzysztof Drewniak struct MemorySpaceCastOpLowering
8827fb9bbe5SKrzysztof Drewniak     : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
8837fb9bbe5SKrzysztof Drewniak   using ConvertOpToLLVMPattern<
8847fb9bbe5SKrzysztof Drewniak       memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
8857fb9bbe5SKrzysztof Drewniak 
8867fb9bbe5SKrzysztof Drewniak   LogicalResult
8877fb9bbe5SKrzysztof Drewniak   matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
8887fb9bbe5SKrzysztof Drewniak                   ConversionPatternRewriter &rewriter) const override {
8897fb9bbe5SKrzysztof Drewniak     Location loc = op.getLoc();
8907fb9bbe5SKrzysztof Drewniak 
8917fb9bbe5SKrzysztof Drewniak     Type resultType = op.getDest().getType();
8925550c821STres Popp     if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
8937fb9bbe5SKrzysztof Drewniak       auto resultDescType =
8945550c821STres Popp           cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
8957fb9bbe5SKrzysztof Drewniak       Type newPtrType = resultDescType.getBody()[0];
8967fb9bbe5SKrzysztof Drewniak 
8977fb9bbe5SKrzysztof Drewniak       SmallVector<Value> descVals;
8987fb9bbe5SKrzysztof Drewniak       MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
8997fb9bbe5SKrzysztof Drewniak                                descVals);
9007fb9bbe5SKrzysztof Drewniak       descVals[0] =
9017fb9bbe5SKrzysztof Drewniak           rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
9027fb9bbe5SKrzysztof Drewniak       descVals[1] =
9037fb9bbe5SKrzysztof Drewniak           rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
9047fb9bbe5SKrzysztof Drewniak       Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
9057fb9bbe5SKrzysztof Drewniak                                             resultTypeR, descVals);
9067fb9bbe5SKrzysztof Drewniak       rewriter.replaceOp(op, result);
9077fb9bbe5SKrzysztof Drewniak       return success();
9087fb9bbe5SKrzysztof Drewniak     }
9095550c821STres Popp     if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
9107fb9bbe5SKrzysztof Drewniak       // Since the type converter won't be doing this for us, get the address
9117fb9bbe5SKrzysztof Drewniak       // space.
9125550c821STres Popp       auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
9137fb9bbe5SKrzysztof Drewniak       FailureOr<unsigned> maybeSourceAddrSpace =
9147fb9bbe5SKrzysztof Drewniak           getTypeConverter()->getMemRefAddressSpace(sourceType);
9157fb9bbe5SKrzysztof Drewniak       if (failed(maybeSourceAddrSpace))
9167fb9bbe5SKrzysztof Drewniak         return rewriter.notifyMatchFailure(loc,
9177fb9bbe5SKrzysztof Drewniak                                            "non-integer source address space");
9187fb9bbe5SKrzysztof Drewniak       unsigned sourceAddrSpace = *maybeSourceAddrSpace;
9197fb9bbe5SKrzysztof Drewniak       FailureOr<unsigned> maybeResultAddrSpace =
9207fb9bbe5SKrzysztof Drewniak           getTypeConverter()->getMemRefAddressSpace(resultTypeU);
9217fb9bbe5SKrzysztof Drewniak       if (failed(maybeResultAddrSpace))
9227fb9bbe5SKrzysztof Drewniak         return rewriter.notifyMatchFailure(loc,
9237fb9bbe5SKrzysztof Drewniak                                            "non-integer result address space");
9247fb9bbe5SKrzysztof Drewniak       unsigned resultAddrSpace = *maybeResultAddrSpace;
9257fb9bbe5SKrzysztof Drewniak 
9267fb9bbe5SKrzysztof Drewniak       UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
9277fb9bbe5SKrzysztof Drewniak       Value rank = sourceDesc.rank(rewriter, loc);
9287fb9bbe5SKrzysztof Drewniak       Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
9297fb9bbe5SKrzysztof Drewniak 
9307fb9bbe5SKrzysztof Drewniak       // Create and allocate storage for new memref descriptor.
9317fb9bbe5SKrzysztof Drewniak       auto result = UnrankedMemRefDescriptor::undef(
9327fb9bbe5SKrzysztof Drewniak           rewriter, loc, typeConverter->convertType(resultTypeU));
9337fb9bbe5SKrzysztof Drewniak       result.setRank(rewriter, loc, rank);
9347fb9bbe5SKrzysztof Drewniak       SmallVector<Value, 1> sizes;
9357fb9bbe5SKrzysztof Drewniak       UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
9367fb9bbe5SKrzysztof Drewniak                                              result, resultAddrSpace, sizes);
9377fb9bbe5SKrzysztof Drewniak       Value resultUnderlyingSize = sizes.front();
9387fb9bbe5SKrzysztof Drewniak       Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
9397fb9bbe5SKrzysztof Drewniak           loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
9407fb9bbe5SKrzysztof Drewniak       result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
9417fb9bbe5SKrzysztof Drewniak 
9427fb9bbe5SKrzysztof Drewniak       // Copy pointers, performing address space casts.
943b28a296cSChristian Ulmann       auto sourceElemPtrType =
944b28a296cSChristian Ulmann           LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
9457fb9bbe5SKrzysztof Drewniak       auto resultElemPtrType =
946b28a296cSChristian Ulmann           LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);
9477fb9bbe5SKrzysztof Drewniak 
9487fb9bbe5SKrzysztof Drewniak       Value allocatedPtr = sourceDesc.allocatedPtr(
9497fb9bbe5SKrzysztof Drewniak           rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
9507fb9bbe5SKrzysztof Drewniak       Value alignedPtr =
9517fb9bbe5SKrzysztof Drewniak           sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
9527fb9bbe5SKrzysztof Drewniak                                 sourceUnderlyingDesc, sourceElemPtrType);
9537fb9bbe5SKrzysztof Drewniak       allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
9547fb9bbe5SKrzysztof Drewniak           loc, resultElemPtrType, allocatedPtr);
9557fb9bbe5SKrzysztof Drewniak       alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
9567fb9bbe5SKrzysztof Drewniak           loc, resultElemPtrType, alignedPtr);
9577fb9bbe5SKrzysztof Drewniak 
9587fb9bbe5SKrzysztof Drewniak       result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
9597fb9bbe5SKrzysztof Drewniak                              resultElemPtrType, allocatedPtr);
9607fb9bbe5SKrzysztof Drewniak       result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
9617fb9bbe5SKrzysztof Drewniak                            resultUnderlyingDesc, resultElemPtrType, alignedPtr);
9627fb9bbe5SKrzysztof Drewniak 
9637fb9bbe5SKrzysztof Drewniak       // Copy all the index-valued operands.
9647fb9bbe5SKrzysztof Drewniak       Value sourceIndexVals =
9657fb9bbe5SKrzysztof Drewniak           sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
9667fb9bbe5SKrzysztof Drewniak                                    sourceUnderlyingDesc, sourceElemPtrType);
9677fb9bbe5SKrzysztof Drewniak       Value resultIndexVals =
9687fb9bbe5SKrzysztof Drewniak           result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
9697fb9bbe5SKrzysztof Drewniak                                resultUnderlyingDesc, resultElemPtrType);
9707fb9bbe5SKrzysztof Drewniak 
9717fb9bbe5SKrzysztof Drewniak       int64_t bytesToSkip =
972e843f029SRamkumar Ramachandra           2 * llvm::divideCeil(
9730fb216fbSRamkumar Ramachandra                   getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
9747fb9bbe5SKrzysztof Drewniak       Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
9757fb9bbe5SKrzysztof Drewniak           loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
9767fb9bbe5SKrzysztof Drewniak       Value copySize = rewriter.create<LLVM::SubOp>(
9777fb9bbe5SKrzysztof Drewniak           loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
9787fb9bbe5SKrzysztof Drewniak       rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
97948b126e3SChristian Ulmann                                       copySize, /*isVolatile=*/false);
9807fb9bbe5SKrzysztof Drewniak 
9817fb9bbe5SKrzysztof Drewniak       rewriter.replaceOp(op, ValueRange{result});
9827fb9bbe5SKrzysztof Drewniak       return success();
9837fb9bbe5SKrzysztof Drewniak     }
9847fb9bbe5SKrzysztof Drewniak     return rewriter.notifyMatchFailure(loc, "unexpected memref type");
9857fb9bbe5SKrzysztof Drewniak   }
9867fb9bbe5SKrzysztof Drewniak };
9877fb9bbe5SKrzysztof Drewniak 
98875e5f0aaSAlex Zinenko /// Extracts allocated, aligned pointers and offset from a ranked or unranked
98975e5f0aaSAlex Zinenko /// memref type. In unranked case, the fields are extracted from the underlying
99075e5f0aaSAlex Zinenko /// ranked descriptor.
99175e5f0aaSAlex Zinenko static void extractPointersAndOffset(Location loc,
99275e5f0aaSAlex Zinenko                                      ConversionPatternRewriter &rewriter,
993ce254598SMatthias Springer                                      const LLVMTypeConverter &typeConverter,
99475e5f0aaSAlex Zinenko                                      Value originalOperand,
99575e5f0aaSAlex Zinenko                                      Value convertedOperand,
99675e5f0aaSAlex Zinenko                                      Value *allocatedPtr, Value *alignedPtr,
99775e5f0aaSAlex Zinenko                                      Value *offset = nullptr) {
99875e5f0aaSAlex Zinenko   Type operandType = originalOperand.getType();
9995550c821STres Popp   if (isa<MemRefType>(operandType)) {
100075e5f0aaSAlex Zinenko     MemRefDescriptor desc(convertedOperand);
100175e5f0aaSAlex Zinenko     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
100275e5f0aaSAlex Zinenko     *alignedPtr = desc.alignedPtr(rewriter, loc);
100375e5f0aaSAlex Zinenko     if (offset != nullptr)
100475e5f0aaSAlex Zinenko       *offset = desc.offset(rewriter, loc);
100575e5f0aaSAlex Zinenko     return;
100675e5f0aaSAlex Zinenko   }
100775e5f0aaSAlex Zinenko 
1008499abb24SKrzysztof Drewniak   // These will all cause assert()s on unconvertible types.
1009499abb24SKrzysztof Drewniak   unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
10105550c821STres Popp       cast<UnrankedMemRefType>(operandType));
1011b28a296cSChristian Ulmann   auto elementPtrType =
1012b28a296cSChristian Ulmann       LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
101375e5f0aaSAlex Zinenko 
101475e5f0aaSAlex Zinenko   // Extract pointer to the underlying ranked memref descriptor and cast it to
101575e5f0aaSAlex Zinenko   // ElemType**.
101675e5f0aaSAlex Zinenko   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
101775e5f0aaSAlex Zinenko   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
101875e5f0aaSAlex Zinenko 
101975e5f0aaSAlex Zinenko   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
102050ea17b8SMarkus Böck       rewriter, loc, underlyingDescPtr, elementPtrType);
102175e5f0aaSAlex Zinenko   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
102250ea17b8SMarkus Böck       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
102375e5f0aaSAlex Zinenko   if (offset != nullptr) {
102475e5f0aaSAlex Zinenko     *offset = UnrankedMemRefDescriptor::offset(
102550ea17b8SMarkus Böck         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
102675e5f0aaSAlex Zinenko   }
102775e5f0aaSAlex Zinenko }
102875e5f0aaSAlex Zinenko 
102975e5f0aaSAlex Zinenko struct MemRefReinterpretCastOpLowering
103075e5f0aaSAlex Zinenko     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
103175e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<
103275e5f0aaSAlex Zinenko       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
103375e5f0aaSAlex Zinenko 
103475e5f0aaSAlex Zinenko   LogicalResult
1035ef976337SRiver Riddle   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
103675e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
1037136d746eSJacques Pienaar     Type srcType = castOp.getSource().getType();
103875e5f0aaSAlex Zinenko 
103975e5f0aaSAlex Zinenko     Value descriptor;
104075e5f0aaSAlex Zinenko     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
104175e5f0aaSAlex Zinenko                                                adaptor, &descriptor)))
104275e5f0aaSAlex Zinenko       return failure();
104375e5f0aaSAlex Zinenko     rewriter.replaceOp(castOp, {descriptor});
104475e5f0aaSAlex Zinenko     return success();
104575e5f0aaSAlex Zinenko   }
104675e5f0aaSAlex Zinenko 
104775e5f0aaSAlex Zinenko private:
104875e5f0aaSAlex Zinenko   LogicalResult convertSourceMemRefToDescriptor(
104975e5f0aaSAlex Zinenko       ConversionPatternRewriter &rewriter, Type srcType,
105075e5f0aaSAlex Zinenko       memref::ReinterpretCastOp castOp,
105175e5f0aaSAlex Zinenko       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
105275e5f0aaSAlex Zinenko     MemRefType targetMemRefType =
10535550c821STres Popp         cast<MemRefType>(castOp.getResult().getType());
10545550c821STres Popp     auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
10555550c821STres Popp         typeConverter->convertType(targetMemRefType));
105675e5f0aaSAlex Zinenko     if (!llvmTargetDescriptorTy)
105775e5f0aaSAlex Zinenko       return failure();
105875e5f0aaSAlex Zinenko 
105975e5f0aaSAlex Zinenko     // Create descriptor.
106075e5f0aaSAlex Zinenko     Location loc = castOp.getLoc();
106175e5f0aaSAlex Zinenko     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
106275e5f0aaSAlex Zinenko 
106375e5f0aaSAlex Zinenko     // Set allocated and aligned pointers.
106475e5f0aaSAlex Zinenko     Value allocatedPtr, alignedPtr;
106575e5f0aaSAlex Zinenko     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1066136d746eSJacques Pienaar                              castOp.getSource(), adaptor.getSource(),
1067136d746eSJacques Pienaar                              &allocatedPtr, &alignedPtr);
106875e5f0aaSAlex Zinenko     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
106975e5f0aaSAlex Zinenko     desc.setAlignedPtr(rewriter, loc, alignedPtr);
107075e5f0aaSAlex Zinenko 
107175e5f0aaSAlex Zinenko     // Set offset.
107275e5f0aaSAlex Zinenko     if (castOp.isDynamicOffset(0))
1073136d746eSJacques Pienaar       desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
107475e5f0aaSAlex Zinenko     else
107575e5f0aaSAlex Zinenko       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
107675e5f0aaSAlex Zinenko 
107775e5f0aaSAlex Zinenko     // Set sizes and strides.
107875e5f0aaSAlex Zinenko     unsigned dynSizeId = 0;
107975e5f0aaSAlex Zinenko     unsigned dynStrideId = 0;
108075e5f0aaSAlex Zinenko     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
108175e5f0aaSAlex Zinenko       if (castOp.isDynamicSize(i))
1082136d746eSJacques Pienaar         desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
108375e5f0aaSAlex Zinenko       else
108475e5f0aaSAlex Zinenko         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
108575e5f0aaSAlex Zinenko 
108675e5f0aaSAlex Zinenko       if (castOp.isDynamicStride(i))
1087136d746eSJacques Pienaar         desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
108875e5f0aaSAlex Zinenko       else
108975e5f0aaSAlex Zinenko         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
109075e5f0aaSAlex Zinenko     }
109175e5f0aaSAlex Zinenko     *descriptor = desc;
109275e5f0aaSAlex Zinenko     return success();
109375e5f0aaSAlex Zinenko   }
109475e5f0aaSAlex Zinenko };
109575e5f0aaSAlex Zinenko 
109675e5f0aaSAlex Zinenko struct MemRefReshapeOpLowering
109775e5f0aaSAlex Zinenko     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
109875e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
109975e5f0aaSAlex Zinenko 
110075e5f0aaSAlex Zinenko   LogicalResult
1101ef976337SRiver Riddle   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
110275e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
1103136d746eSJacques Pienaar     Type srcType = reshapeOp.getSource().getType();
110475e5f0aaSAlex Zinenko 
110575e5f0aaSAlex Zinenko     Value descriptor;
110675e5f0aaSAlex Zinenko     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
110775e5f0aaSAlex Zinenko                                                adaptor, &descriptor)))
110875e5f0aaSAlex Zinenko       return failure();
1109ef976337SRiver Riddle     rewriter.replaceOp(reshapeOp, {descriptor});
111075e5f0aaSAlex Zinenko     return success();
111175e5f0aaSAlex Zinenko   }
111275e5f0aaSAlex Zinenko 
111375e5f0aaSAlex Zinenko private:
111475e5f0aaSAlex Zinenko   LogicalResult
111575e5f0aaSAlex Zinenko   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
111675e5f0aaSAlex Zinenko                                   Type srcType, memref::ReshapeOp reshapeOp,
111775e5f0aaSAlex Zinenko                                   memref::ReshapeOp::Adaptor adaptor,
111875e5f0aaSAlex Zinenko                                   Value *descriptor) const {
11195550c821STres Popp     auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
11205380e30eSAshay Rane     if (shapeMemRefType.hasStaticShape()) {
11215380e30eSAshay Rane       MemRefType targetMemRefType =
11225550c821STres Popp           cast<MemRefType>(reshapeOp.getResult().getType());
11235550c821STres Popp       auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
11245550c821STres Popp           typeConverter->convertType(targetMemRefType));
11255380e30eSAshay Rane       if (!llvmTargetDescriptorTy)
112675e5f0aaSAlex Zinenko         return failure();
112775e5f0aaSAlex Zinenko 
11285380e30eSAshay Rane       // Create descriptor.
11295380e30eSAshay Rane       Location loc = reshapeOp.getLoc();
11305380e30eSAshay Rane       auto desc =
11315380e30eSAshay Rane           MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11325380e30eSAshay Rane 
11335380e30eSAshay Rane       // Set allocated and aligned pointers.
11345380e30eSAshay Rane       Value allocatedPtr, alignedPtr;
11355380e30eSAshay Rane       extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1136136d746eSJacques Pienaar                                reshapeOp.getSource(), adaptor.getSource(),
11375380e30eSAshay Rane                                &allocatedPtr, &alignedPtr);
11385380e30eSAshay Rane       desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
11395380e30eSAshay Rane       desc.setAlignedPtr(rewriter, loc, alignedPtr);
11405380e30eSAshay Rane 
11415380e30eSAshay Rane       // Extract the offset and strides from the type.
11425380e30eSAshay Rane       int64_t offset;
11435380e30eSAshay Rane       SmallVector<int64_t> strides;
11446aaa8f25SMatthias Springer       if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
11455380e30eSAshay Rane         return rewriter.notifyMatchFailure(
11465380e30eSAshay Rane             reshapeOp, "failed to get stride and offset exprs");
11475380e30eSAshay Rane 
11485380e30eSAshay Rane       if (!isStaticStrideOrOffset(offset))
11495380e30eSAshay Rane         return rewriter.notifyMatchFailure(reshapeOp,
11505380e30eSAshay Rane                                            "dynamic offset is unsupported");
11515380e30eSAshay Rane 
11525380e30eSAshay Rane       desc.setConstantOffset(rewriter, loc, offset);
11535fee1799SAshay Rane 
11545fee1799SAshay Rane       assert(targetMemRefType.getLayout().isIdentity() &&
11555fee1799SAshay Rane              "Identity layout map is a precondition of a valid reshape op");
11565fee1799SAshay Rane 
1157e98e5995SAlex Zinenko       Type indexType = getIndexType();
11585fee1799SAshay Rane       Value stride = nullptr;
11595fee1799SAshay Rane       int64_t targetRank = targetMemRefType.getRank();
11605fee1799SAshay Rane       for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1161399638f9SAliia Khasanova         if (!ShapedType::isDynamic(strides[i])) {
11625fee1799SAshay Rane           // If the stride for this dimension is dynamic, then use the product
11635fee1799SAshay Rane           // of the sizes of the inner dimensions.
1164620e2bb2SNicolas Vasilache           stride =
1165620e2bb2SNicolas Vasilache               createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
11665fee1799SAshay Rane         } else if (!stride) {
11675fee1799SAshay Rane           // `stride` is null only in the first iteration of the loop.  However,
11685fee1799SAshay Rane           // since the target memref has an identity layout, we can safely set
11695fee1799SAshay Rane           // the innermost stride to 1.
1170620e2bb2SNicolas Vasilache           stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
11715fee1799SAshay Rane         }
11725fee1799SAshay Rane 
11735fee1799SAshay Rane         Value dimSize;
11745fee1799SAshay Rane         // If the size of this dimension is dynamic, then load it at runtime
11755fee1799SAshay Rane         // from the shape operand.
1176eb7f3557SMatthias Springer         if (!targetMemRefType.isDynamicDim(i)) {
1177620e2bb2SNicolas Vasilache           dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1178eb7f3557SMatthias Springer                                             targetMemRefType.getDimSize(i));
11795fee1799SAshay Rane         } else {
1180136d746eSJacques Pienaar           Value shapeOp = reshapeOp.getShape();
1181620e2bb2SNicolas Vasilache           Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
11825fee1799SAshay Rane           dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1183d4217e6cSIvan Butygin           Type indexType = getIndexType();
1184d4217e6cSIvan Butygin           if (dimSize.getType() != indexType)
1185d4217e6cSIvan Butygin             dimSize = typeConverter->materializeTargetConversion(
1186d4217e6cSIvan Butygin                 rewriter, loc, indexType, dimSize);
1187d4217e6cSIvan Butygin           assert(dimSize && "Invalid memref element type");
11885fee1799SAshay Rane         }
11895fee1799SAshay Rane 
11905fee1799SAshay Rane         desc.setSize(rewriter, loc, i, dimSize);
11915fee1799SAshay Rane         desc.setStride(rewriter, loc, i, stride);
11925fee1799SAshay Rane 
11935fee1799SAshay Rane         // Prepare the stride value for the next dimension.
11945fee1799SAshay Rane         stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
11955380e30eSAshay Rane       }
11965380e30eSAshay Rane 
11975380e30eSAshay Rane       *descriptor = desc;
11985380e30eSAshay Rane       return success();
11995380e30eSAshay Rane     }
12005380e30eSAshay Rane 
120175e5f0aaSAlex Zinenko     // The shape is a rank-1 tensor with unknown length.
120275e5f0aaSAlex Zinenko     Location loc = reshapeOp.getLoc();
1203136d746eSJacques Pienaar     MemRefDescriptor shapeDesc(adaptor.getShape());
120475e5f0aaSAlex Zinenko     Value resultRank = shapeDesc.size(rewriter, loc, 0);
120575e5f0aaSAlex Zinenko 
120675e5f0aaSAlex Zinenko     // Extract address space and element type.
12075550c821STres Popp     auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1208499abb24SKrzysztof Drewniak     unsigned addressSpace =
1209499abb24SKrzysztof Drewniak         *getTypeConverter()->getMemRefAddressSpace(targetType);
121075e5f0aaSAlex Zinenko 
121175e5f0aaSAlex Zinenko     // Create the unranked memref descriptor that holds the ranked one. The
121275e5f0aaSAlex Zinenko     // inner descriptor is allocated on stack.
121375e5f0aaSAlex Zinenko     auto targetDesc = UnrankedMemRefDescriptor::undef(
121475e5f0aaSAlex Zinenko         rewriter, loc, typeConverter->convertType(targetType));
121575e5f0aaSAlex Zinenko     targetDesc.setRank(rewriter, loc, resultRank);
121675e5f0aaSAlex Zinenko     SmallVector<Value, 4> sizes;
121775e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1218d0f19ce7SKrzysztof Drewniak                                            targetDesc, addressSpace, sizes);
121975e5f0aaSAlex Zinenko     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
122050ea17b8SMarkus Böck         loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
122150ea17b8SMarkus Böck         sizes.front());
122275e5f0aaSAlex Zinenko     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
122375e5f0aaSAlex Zinenko 
122475e5f0aaSAlex Zinenko     // Extract pointers and offset from the source memref.
122575e5f0aaSAlex Zinenko     Value allocatedPtr, alignedPtr, offset;
122675e5f0aaSAlex Zinenko     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1227136d746eSJacques Pienaar                              reshapeOp.getSource(), adaptor.getSource(),
122875e5f0aaSAlex Zinenko                              &allocatedPtr, &alignedPtr, &offset);
122975e5f0aaSAlex Zinenko 
123075e5f0aaSAlex Zinenko     // Set pointers and offset.
1231b28a296cSChristian Ulmann     auto elementPtrType =
1232b28a296cSChristian Ulmann         LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
123350ea17b8SMarkus Böck 
123475e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
123550ea17b8SMarkus Böck                                               elementPtrType, allocatedPtr);
123675e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
123750ea17b8SMarkus Böck                                             underlyingDescPtr, elementPtrType,
123850ea17b8SMarkus Böck                                             alignedPtr);
123975e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
124050ea17b8SMarkus Böck                                         underlyingDescPtr, elementPtrType,
124175e5f0aaSAlex Zinenko                                         offset);
124275e5f0aaSAlex Zinenko 
124375e5f0aaSAlex Zinenko     // Use the offset pointer as base for further addressing. Copy over the new
124475e5f0aaSAlex Zinenko     // shape and compute strides. For this, we create a loop from rank-1 to 0.
124575e5f0aaSAlex Zinenko     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
124650ea17b8SMarkus Böck         rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
124775e5f0aaSAlex Zinenko     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
124875e5f0aaSAlex Zinenko         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
124975e5f0aaSAlex Zinenko     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1250e98e5995SAlex Zinenko     Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
125175e5f0aaSAlex Zinenko     Value resultRankMinusOne =
125275e5f0aaSAlex Zinenko         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
125375e5f0aaSAlex Zinenko 
125475e5f0aaSAlex Zinenko     Block *initBlock = rewriter.getInsertionBlock();
125575e5f0aaSAlex Zinenko     Type indexType = getTypeConverter()->getIndexType();
125675e5f0aaSAlex Zinenko     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
125775e5f0aaSAlex Zinenko 
125875e5f0aaSAlex Zinenko     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1259e084679fSRiver Riddle                                             {indexType, indexType}, {loc, loc});
126075e5f0aaSAlex Zinenko 
126175e5f0aaSAlex Zinenko     // Move the remaining initBlock ops to condBlock.
126275e5f0aaSAlex Zinenko     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
126375e5f0aaSAlex Zinenko     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
126475e5f0aaSAlex Zinenko 
126575e5f0aaSAlex Zinenko     rewriter.setInsertionPointToEnd(initBlock);
126675e5f0aaSAlex Zinenko     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
126775e5f0aaSAlex Zinenko                                 condBlock);
126875e5f0aaSAlex Zinenko     rewriter.setInsertionPointToStart(condBlock);
126975e5f0aaSAlex Zinenko     Value indexArg = condBlock->getArgument(0);
127075e5f0aaSAlex Zinenko     Value strideArg = condBlock->getArgument(1);
127175e5f0aaSAlex Zinenko 
1272620e2bb2SNicolas Vasilache     Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
127375e5f0aaSAlex Zinenko     Value pred = rewriter.create<LLVM::ICmpOp>(
127475e5f0aaSAlex Zinenko         loc, IntegerType::get(rewriter.getContext(), 1),
127575e5f0aaSAlex Zinenko         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
127675e5f0aaSAlex Zinenko 
127775e5f0aaSAlex Zinenko     Block *bodyBlock =
127875e5f0aaSAlex Zinenko         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
127975e5f0aaSAlex Zinenko     rewriter.setInsertionPointToStart(bodyBlock);
128075e5f0aaSAlex Zinenko 
128175e5f0aaSAlex Zinenko     // Copy size from shape to descriptor.
1282b28a296cSChristian Ulmann     auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
128350ea17b8SMarkus Böck     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
128450ea17b8SMarkus Böck         loc, llvmIndexPtrType,
128550ea17b8SMarkus Böck         typeConverter->convertType(shapeMemRefType.getElementType()),
1286bd7eff1fSMarkus Böck         shapeOperandPtr, indexArg);
128750ea17b8SMarkus Böck     Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
128875e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
128975e5f0aaSAlex Zinenko                                       targetSizesBase, indexArg, size);
129075e5f0aaSAlex Zinenko 
129175e5f0aaSAlex Zinenko     // Write stride value and compute next one.
129275e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
129375e5f0aaSAlex Zinenko                                         targetStridesBase, indexArg, strideArg);
129475e5f0aaSAlex Zinenko     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
129575e5f0aaSAlex Zinenko 
129675e5f0aaSAlex Zinenko     // Decrement loop counter and branch back.
129775e5f0aaSAlex Zinenko     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
129875e5f0aaSAlex Zinenko     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
129975e5f0aaSAlex Zinenko                                 condBlock);
130075e5f0aaSAlex Zinenko 
130175e5f0aaSAlex Zinenko     Block *remainder =
130275e5f0aaSAlex Zinenko         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
130375e5f0aaSAlex Zinenko 
130475e5f0aaSAlex Zinenko     // Hook up the cond exit to the remainder.
130575e5f0aaSAlex Zinenko     rewriter.setInsertionPointToEnd(condBlock);
13061a36588eSKazu Hirata     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
13071a36588eSKazu Hirata                                     remainder, std::nullopt);
130875e5f0aaSAlex Zinenko 
130975e5f0aaSAlex Zinenko     // Reset position to beginning of new remainder block.
131075e5f0aaSAlex Zinenko     rewriter.setInsertionPointToStart(remainder);
131175e5f0aaSAlex Zinenko 
131275e5f0aaSAlex Zinenko     *descriptor = targetDesc;
131375e5f0aaSAlex Zinenko     return success();
131475e5f0aaSAlex Zinenko   }
131575e5f0aaSAlex Zinenko };
131675e5f0aaSAlex Zinenko 
131739148362SQuentin Colombet /// RessociatingReshapeOp must be expanded before we reach this stage.
131839148362SQuentin Colombet /// Report that information.
131946ef86b5SAlexander Belyaev template <typename ReshapeOp>
132046ef86b5SAlexander Belyaev class ReassociatingReshapeOpConversion
132146ef86b5SAlexander Belyaev     : public ConvertOpToLLVMPattern<ReshapeOp> {
132246ef86b5SAlexander Belyaev public:
132346ef86b5SAlexander Belyaev   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
132446ef86b5SAlexander Belyaev   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
132546ef86b5SAlexander Belyaev 
132646ef86b5SAlexander Belyaev   LogicalResult
1327ef976337SRiver Riddle   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
132846ef86b5SAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
1329381c3b92SYi Zhang     return rewriter.notifyMatchFailure(
133039148362SQuentin Colombet         reshapeOp,
133139148362SQuentin Colombet         "reassociation operations should have been expanded beforehand");
133246ef86b5SAlexander Belyaev   }
133346ef86b5SAlexander Belyaev };
1334381c3b92SYi Zhang 
1335786cbb09SQuentin Colombet /// Subviews must be expanded before we reach this stage.
1336786cbb09SQuentin Colombet /// Report that information.
133775e5f0aaSAlex Zinenko struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
133875e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
133975e5f0aaSAlex Zinenko 
134075e5f0aaSAlex Zinenko   LogicalResult
1341ef976337SRiver Riddle   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
134275e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
1343786cbb09SQuentin Colombet     return rewriter.notifyMatchFailure(
1344786cbb09SQuentin Colombet         subViewOp, "subview operations should have been expanded beforehand");
134575e5f0aaSAlex Zinenko   }
134675e5f0aaSAlex Zinenko };
134775e5f0aaSAlex Zinenko 
134875e5f0aaSAlex Zinenko /// Conversion pattern that transforms a transpose op into:
134975e5f0aaSAlex Zinenko ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
135075e5f0aaSAlex Zinenko ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
135175e5f0aaSAlex Zinenko ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
135275e5f0aaSAlex Zinenko ///      and stride. Size and stride are permutations of the original values.
135375e5f0aaSAlex Zinenko ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
135475e5f0aaSAlex Zinenko /// The transpose op is replaced by the alloca'ed pointer.
135575e5f0aaSAlex Zinenko class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
135675e5f0aaSAlex Zinenko public:
135775e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
135875e5f0aaSAlex Zinenko 
135975e5f0aaSAlex Zinenko   LogicalResult
1360ef976337SRiver Riddle   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
136175e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
136275e5f0aaSAlex Zinenko     auto loc = transposeOp.getLoc();
1363136d746eSJacques Pienaar     MemRefDescriptor viewMemRef(adaptor.getIn());
136475e5f0aaSAlex Zinenko 
136575e5f0aaSAlex Zinenko     // No permutation, early exit.
1366136d746eSJacques Pienaar     if (transposeOp.getPermutation().isIdentity())
136775e5f0aaSAlex Zinenko       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
136875e5f0aaSAlex Zinenko 
136975e5f0aaSAlex Zinenko     auto targetMemRef = MemRefDescriptor::undef(
1370eb7f3557SMatthias Springer         rewriter, loc,
1371eb7f3557SMatthias Springer         typeConverter->convertType(transposeOp.getIn().getType()));
137275e5f0aaSAlex Zinenko 
137375e5f0aaSAlex Zinenko     // Copy the base and aligned pointers from the old descriptor to the new
137475e5f0aaSAlex Zinenko     // one.
137575e5f0aaSAlex Zinenko     targetMemRef.setAllocatedPtr(rewriter, loc,
137675e5f0aaSAlex Zinenko                                  viewMemRef.allocatedPtr(rewriter, loc));
137775e5f0aaSAlex Zinenko     targetMemRef.setAlignedPtr(rewriter, loc,
137875e5f0aaSAlex Zinenko                                viewMemRef.alignedPtr(rewriter, loc));
137975e5f0aaSAlex Zinenko 
138075e5f0aaSAlex Zinenko     // Copy the offset pointer from the old descriptor to the new one.
138175e5f0aaSAlex Zinenko     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
138275e5f0aaSAlex Zinenko 
138355088efeSFelix Schneider     // Iterate over the dimensions and apply size/stride permutation:
1384b28a296cSChristian Ulmann     // When enumerating the results of the permutation map, the enumeration
1385b28a296cSChristian Ulmann     // index is the index into the target dimensions and the DimExpr points to
1386b28a296cSChristian Ulmann     // the dimension of the source memref.
1387e4853be2SMehdi Amini     for (const auto &en :
1388136d746eSJacques Pienaar          llvm::enumerate(transposeOp.getPermutation().getResults())) {
138955088efeSFelix Schneider       int targetPos = en.index();
13901609f1c2Slong.chen       int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
139175e5f0aaSAlex Zinenko       targetMemRef.setSize(rewriter, loc, targetPos,
139275e5f0aaSAlex Zinenko                            viewMemRef.size(rewriter, loc, sourcePos));
139375e5f0aaSAlex Zinenko       targetMemRef.setStride(rewriter, loc, targetPos,
139475e5f0aaSAlex Zinenko                              viewMemRef.stride(rewriter, loc, sourcePos));
139575e5f0aaSAlex Zinenko     }
139675e5f0aaSAlex Zinenko 
139775e5f0aaSAlex Zinenko     rewriter.replaceOp(transposeOp, {targetMemRef});
139875e5f0aaSAlex Zinenko     return success();
139975e5f0aaSAlex Zinenko   }
140075e5f0aaSAlex Zinenko };
140175e5f0aaSAlex Zinenko 
140275e5f0aaSAlex Zinenko /// Conversion pattern that transforms an op into:
140375e5f0aaSAlex Zinenko ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
140475e5f0aaSAlex Zinenko ///   2. Updates to the descriptor to introduce the data ptr, offset, size
140575e5f0aaSAlex Zinenko ///      and stride.
140675e5f0aaSAlex Zinenko /// The view op is replaced by the descriptor.
140775e5f0aaSAlex Zinenko struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
140875e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
140975e5f0aaSAlex Zinenko 
141075e5f0aaSAlex Zinenko   // Build and return the value for the idx^th shape dimension, either by
141175e5f0aaSAlex Zinenko   // returning the constant shape dimension or counting the proper dynamic size.
141275e5f0aaSAlex Zinenko   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1413620e2bb2SNicolas Vasilache                 ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1414620e2bb2SNicolas Vasilache                 Type indexType) const {
141575e5f0aaSAlex Zinenko     assert(idx < shape.size());
141675e5f0aaSAlex Zinenko     if (!ShapedType::isDynamic(shape[idx]))
1417620e2bb2SNicolas Vasilache       return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
141875e5f0aaSAlex Zinenko     // Count the number of dynamic dims in range [0, idx]
1419380a1b20SKazu Hirata     unsigned nDynamic =
1420380a1b20SKazu Hirata         llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
142175e5f0aaSAlex Zinenko     return dynamicSizes[nDynamic];
142275e5f0aaSAlex Zinenko   }
142375e5f0aaSAlex Zinenko 
142475e5f0aaSAlex Zinenko   // Build and return the idx^th stride, either by returning the constant stride
142575e5f0aaSAlex Zinenko   // or by computing the dynamic stride from the current `runningStride` and
142675e5f0aaSAlex Zinenko   // `nextSize`. The caller should keep a running stride and update it with the
142775e5f0aaSAlex Zinenko   // result returned by this function.
142875e5f0aaSAlex Zinenko   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
142975e5f0aaSAlex Zinenko                   ArrayRef<int64_t> strides, Value nextSize,
1430620e2bb2SNicolas Vasilache                   Value runningStride, unsigned idx, Type indexType) const {
143175e5f0aaSAlex Zinenko     assert(idx < strides.size());
1432399638f9SAliia Khasanova     if (!ShapedType::isDynamic(strides[idx]))
1433620e2bb2SNicolas Vasilache       return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
143475e5f0aaSAlex Zinenko     if (nextSize)
143575e5f0aaSAlex Zinenko       return runningStride
143675e5f0aaSAlex Zinenko                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
143775e5f0aaSAlex Zinenko                  : nextSize;
143875e5f0aaSAlex Zinenko     assert(!runningStride);
1439620e2bb2SNicolas Vasilache     return createIndexAttrConstant(rewriter, loc, indexType, 1);
144075e5f0aaSAlex Zinenko   }
144175e5f0aaSAlex Zinenko 
144275e5f0aaSAlex Zinenko   LogicalResult
1443ef976337SRiver Riddle   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
144475e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
144575e5f0aaSAlex Zinenko     auto loc = viewOp.getLoc();
144675e5f0aaSAlex Zinenko 
144775e5f0aaSAlex Zinenko     auto viewMemRefType = viewOp.getType();
144875e5f0aaSAlex Zinenko     auto targetElementTy =
144975e5f0aaSAlex Zinenko         typeConverter->convertType(viewMemRefType.getElementType());
145075e5f0aaSAlex Zinenko     auto targetDescTy = typeConverter->convertType(viewMemRefType);
145175e5f0aaSAlex Zinenko     if (!targetDescTy || !targetElementTy ||
145275e5f0aaSAlex Zinenko         !LLVM::isCompatibleType(targetElementTy) ||
145375e5f0aaSAlex Zinenko         !LLVM::isCompatibleType(targetDescTy))
145475e5f0aaSAlex Zinenko       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
145575e5f0aaSAlex Zinenko              failure();
145675e5f0aaSAlex Zinenko 
145775e5f0aaSAlex Zinenko     int64_t offset;
145875e5f0aaSAlex Zinenko     SmallVector<int64_t, 4> strides;
14596aaa8f25SMatthias Springer     auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
146075e5f0aaSAlex Zinenko     if (failed(successStrides))
146175e5f0aaSAlex Zinenko       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
146275e5f0aaSAlex Zinenko     assert(offset == 0 && "expected offset to be 0");
146375e5f0aaSAlex Zinenko 
1464705f048cSEugene Zhulenev     // Target memref must be contiguous in memory (innermost stride is 1), or
1465705f048cSEugene Zhulenev     // empty (special case when at least one of the memref dimensions is 0).
1466705f048cSEugene Zhulenev     if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1467705f048cSEugene Zhulenev       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1468705f048cSEugene Zhulenev              failure();
1469705f048cSEugene Zhulenev 
147075e5f0aaSAlex Zinenko     // Create the descriptor.
1471136d746eSJacques Pienaar     MemRefDescriptor sourceMemRef(adaptor.getSource());
147275e5f0aaSAlex Zinenko     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
147375e5f0aaSAlex Zinenko 
147475e5f0aaSAlex Zinenko     // Field 1: Copy the allocated pointer, used for malloc/free.
147575e5f0aaSAlex Zinenko     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
14765550c821STres Popp     auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1477b28a296cSChristian Ulmann     targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
147875e5f0aaSAlex Zinenko 
147975e5f0aaSAlex Zinenko     // Field 2: Copy the actual aligned pointer to payload.
148075e5f0aaSAlex Zinenko     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1481136d746eSJacques Pienaar     alignedPtr = rewriter.create<LLVM::GEPOp>(
148250ea17b8SMarkus Böck         loc, alignedPtr.getType(),
148350ea17b8SMarkus Böck         typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
148450ea17b8SMarkus Böck         adaptor.getByteShift());
148550ea17b8SMarkus Böck 
1486b28a296cSChristian Ulmann     targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
148775e5f0aaSAlex Zinenko 
1488e98e5995SAlex Zinenko     Type indexType = getIndexType();
1489620e2bb2SNicolas Vasilache     // Field 3: The offset in the resulting type must be 0. This is
1490620e2bb2SNicolas Vasilache     // because of the type change: an offset on srcType* may not be
1491620e2bb2SNicolas Vasilache     // expressible as an offset on dstType*.
1492620e2bb2SNicolas Vasilache     targetMemRef.setOffset(
1493620e2bb2SNicolas Vasilache         rewriter, loc,
1494620e2bb2SNicolas Vasilache         createIndexAttrConstant(rewriter, loc, indexType, offset));
149575e5f0aaSAlex Zinenko 
149675e5f0aaSAlex Zinenko     // Early exit for 0-D corner case.
149775e5f0aaSAlex Zinenko     if (viewMemRefType.getRank() == 0)
149875e5f0aaSAlex Zinenko       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
149975e5f0aaSAlex Zinenko 
150075e5f0aaSAlex Zinenko     // Fields 4 and 5: Update sizes and strides.
150175e5f0aaSAlex Zinenko     Value stride = nullptr, nextSize = nullptr;
150275e5f0aaSAlex Zinenko     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
150375e5f0aaSAlex Zinenko       // Update size.
1504136d746eSJacques Pienaar       Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1505620e2bb2SNicolas Vasilache                            adaptor.getSizes(), i, indexType);
150675e5f0aaSAlex Zinenko       targetMemRef.setSize(rewriter, loc, i, size);
150775e5f0aaSAlex Zinenko       // Update stride.
1508620e2bb2SNicolas Vasilache       stride =
1509620e2bb2SNicolas Vasilache           getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
151075e5f0aaSAlex Zinenko       targetMemRef.setStride(rewriter, loc, i, stride);
151175e5f0aaSAlex Zinenko       nextSize = size;
151275e5f0aaSAlex Zinenko     }
151375e5f0aaSAlex Zinenko 
151475e5f0aaSAlex Zinenko     rewriter.replaceOp(viewOp, {targetMemRef});
151575e5f0aaSAlex Zinenko     return success();
151675e5f0aaSAlex Zinenko   }
151775e5f0aaSAlex Zinenko };
151875e5f0aaSAlex Zinenko 
1519a6a583daSWilliam S. Moses //===----------------------------------------------------------------------===//
1520a6a583daSWilliam S. Moses // AtomicRMWOpLowering
1521a6a583daSWilliam S. Moses //===----------------------------------------------------------------------===//
1522a6a583daSWilliam S. Moses 
152323aa5a74SRiver Riddle /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1524a6a583daSWilliam S. Moses /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
15257d2b180eSKazu Hirata static std::optional<LLVM::AtomicBinOp>
1526a6a583daSWilliam S. Moses matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1527136d746eSJacques Pienaar   switch (atomicOp.getKind()) {
1528a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::addf:
1529a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::fadd;
1530a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::addi:
1531a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::add;
1532a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::assign:
1533a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::xchg;
1534c46a0433SDaniil Dudkin   case arith::AtomicRMWKind::maximumf:
15357db18533SKrzysztof Drewniak     return LLVM::AtomicBinOp::fmax;
1536a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::maxs:
1537a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::max;
1538a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::maxu:
1539a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::umax;
1540c46a0433SDaniil Dudkin   case arith::AtomicRMWKind::minimumf:
15417db18533SKrzysztof Drewniak     return LLVM::AtomicBinOp::fmin;
1542a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::mins:
1543a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::min;
1544a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::minu:
1545a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::umin;
1546a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::ori:
1547a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::_or;
1548a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::andi:
1549a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::_and;
1550a6a583daSWilliam S. Moses   default:
15511a36588eSKazu Hirata     return std::nullopt;
1552a6a583daSWilliam S. Moses   }
1553a6a583daSWilliam S. Moses   llvm_unreachable("Invalid AtomicRMWKind");
1554a6a583daSWilliam S. Moses }
1555a6a583daSWilliam S. Moses 
1556a6a583daSWilliam S. Moses struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1557a6a583daSWilliam S. Moses   using Base::Base;
1558a6a583daSWilliam S. Moses 
1559a6a583daSWilliam S. Moses   LogicalResult
1560a6a583daSWilliam S. Moses   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1561a6a583daSWilliam S. Moses                   ConversionPatternRewriter &rewriter) const override {
1562a6a583daSWilliam S. Moses     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1563a6a583daSWilliam S. Moses     if (!maybeKind)
1564a6a583daSWilliam S. Moses       return failure();
1565a6a583daSWilliam S. Moses     auto memRefType = atomicOp.getMemRefType();
1566ce6ef990SMax191     SmallVector<int64_t> strides;
1567ce6ef990SMax191     int64_t offset;
15686aaa8f25SMatthias Springer     if (failed(memRefType.getStridesAndOffset(strides, offset)))
1569ce6ef990SMax191       return failure();
1570a6a583daSWilliam S. Moses     auto dataPtr =
1571136d746eSJacques Pienaar         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1572136d746eSJacques Pienaar                              adaptor.getIndices(), rewriter);
1573a6a583daSWilliam S. Moses     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
15747f97895fSTobias Gysi         atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1575a6a583daSWilliam S. Moses         LLVM::AtomicOrdering::acq_rel);
1576a6a583daSWilliam S. Moses     return success();
1577a6a583daSWilliam S. Moses   }
1578a6a583daSWilliam S. Moses };
1579a6a583daSWilliam S. Moses 
158007801f71SNicolas Vasilache /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
158107801f71SNicolas Vasilache class ConvertExtractAlignedPointerAsIndex
158207801f71SNicolas Vasilache     : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
158307801f71SNicolas Vasilache public:
158407801f71SNicolas Vasilache   using ConvertOpToLLVMPattern<
158507801f71SNicolas Vasilache       memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
158607801f71SNicolas Vasilache 
158707801f71SNicolas Vasilache   LogicalResult
158807801f71SNicolas Vasilache   matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
158907801f71SNicolas Vasilache                   OpAdaptor adaptor,
159007801f71SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
1591846103c7SSpenser Bauman     BaseMemRefType sourceTy = extractOp.getSource().getType();
1592846103c7SSpenser Bauman 
1593846103c7SSpenser Bauman     Value alignedPtr;
1594846103c7SSpenser Bauman     if (sourceTy.hasRank()) {
159507801f71SNicolas Vasilache       MemRefDescriptor desc(adaptor.getSource());
1596846103c7SSpenser Bauman       alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1597846103c7SSpenser Bauman     } else {
1598846103c7SSpenser Bauman       auto elementPtrTy = LLVM::LLVMPointerType::get(
1599846103c7SSpenser Bauman           rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1600846103c7SSpenser Bauman 
1601846103c7SSpenser Bauman       UnrankedMemRefDescriptor desc(adaptor.getSource());
1602846103c7SSpenser Bauman       Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1603846103c7SSpenser Bauman 
1604846103c7SSpenser Bauman       alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1605846103c7SSpenser Bauman           rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1606846103c7SSpenser Bauman           elementPtrTy);
1607846103c7SSpenser Bauman     }
1608846103c7SSpenser Bauman 
160907801f71SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1610846103c7SSpenser Bauman         extractOp, getTypeConverter()->getIndexType(), alignedPtr);
161107801f71SNicolas Vasilache     return success();
161207801f71SNicolas Vasilache   }
161307801f71SNicolas Vasilache };
161407801f71SNicolas Vasilache 
161555d08a86SQuentin Colombet /// Materialize the MemRef descriptor represented by the results of
161655d08a86SQuentin Colombet /// ExtractStridedMetadataOp.
161755d08a86SQuentin Colombet class ExtractStridedMetadataOpLowering
161855d08a86SQuentin Colombet     : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
161955d08a86SQuentin Colombet public:
162055d08a86SQuentin Colombet   using ConvertOpToLLVMPattern<
162155d08a86SQuentin Colombet       memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
162255d08a86SQuentin Colombet 
162355d08a86SQuentin Colombet   LogicalResult
162455d08a86SQuentin Colombet   matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
162555d08a86SQuentin Colombet                   OpAdaptor adaptor,
162655d08a86SQuentin Colombet                   ConversionPatternRewriter &rewriter) const override {
162755d08a86SQuentin Colombet 
162855d08a86SQuentin Colombet     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
162955d08a86SQuentin Colombet       return failure();
163055d08a86SQuentin Colombet 
163155d08a86SQuentin Colombet     // Create the descriptor.
1632200266a0SQuentin Colombet     MemRefDescriptor sourceMemRef(adaptor.getSource());
163355d08a86SQuentin Colombet     Location loc = extractStridedMetadataOp.getLoc();
163455d08a86SQuentin Colombet     Value source = extractStridedMetadataOp.getSource();
163555d08a86SQuentin Colombet 
16365550c821STres Popp     auto sourceMemRefType = cast<MemRefType>(source.getType());
163755d08a86SQuentin Colombet     int64_t rank = sourceMemRefType.getRank();
163855d08a86SQuentin Colombet     SmallVector<Value> results;
163955d08a86SQuentin Colombet     results.reserve(2 + rank * 2);
164055d08a86SQuentin Colombet 
164155d08a86SQuentin Colombet     // Base buffer.
1642200266a0SQuentin Colombet     Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1643200266a0SQuentin Colombet     Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1644200266a0SQuentin Colombet     MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
1645200266a0SQuentin Colombet         rewriter, loc, *getTypeConverter(),
16465550c821STres Popp         cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1647200266a0SQuentin Colombet         baseBuffer, alignedBuffer);
1648200266a0SQuentin Colombet     results.push_back((Value)dstMemRef);
164955d08a86SQuentin Colombet 
165055d08a86SQuentin Colombet     // Offset.
165155d08a86SQuentin Colombet     results.push_back(sourceMemRef.offset(rewriter, loc));
165255d08a86SQuentin Colombet 
165355d08a86SQuentin Colombet     // Sizes.
165455d08a86SQuentin Colombet     for (unsigned i = 0; i < rank; ++i)
165555d08a86SQuentin Colombet       results.push_back(sourceMemRef.size(rewriter, loc, i));
165655d08a86SQuentin Colombet     // Strides.
165755d08a86SQuentin Colombet     for (unsigned i = 0; i < rank; ++i)
165855d08a86SQuentin Colombet       results.push_back(sourceMemRef.stride(rewriter, loc, i));
165955d08a86SQuentin Colombet 
166055d08a86SQuentin Colombet     rewriter.replaceOp(extractStridedMetadataOp, results);
166155d08a86SQuentin Colombet     return success();
166255d08a86SQuentin Colombet   }
166355d08a86SQuentin Colombet };
166455d08a86SQuentin Colombet 
166575e5f0aaSAlex Zinenko } // namespace
166675e5f0aaSAlex Zinenko 
1667fcb0294bSAlex Zinenko void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
1668206fad0eSMatthias Springer     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
166975e5f0aaSAlex Zinenko   // clang-format off
167075e5f0aaSAlex Zinenko   patterns.add<
167175e5f0aaSAlex Zinenko       AllocaOpLowering,
167275e5f0aaSAlex Zinenko       AllocaScopeOpLowering,
1673a6a583daSWilliam S. Moses       AtomicRMWOpLowering,
167475e5f0aaSAlex Zinenko       AssumeAlignmentOpLowering,
167507801f71SNicolas Vasilache       ConvertExtractAlignedPointerAsIndex,
167675e5f0aaSAlex Zinenko       DimOpLowering,
167755d08a86SQuentin Colombet       ExtractStridedMetadataOpLowering,
1678632a4f88SRiver Riddle       GenericAtomicRMWOpLowering,
167975e5f0aaSAlex Zinenko       GlobalMemrefOpLowering,
168075e5f0aaSAlex Zinenko       GetGlobalMemrefOpLowering,
168175e5f0aaSAlex Zinenko       LoadOpLowering,
168275e5f0aaSAlex Zinenko       MemRefCastOpLowering,
1683fcb0294bSAlex Zinenko       MemRefCopyOpLowering,
16847fb9bbe5SKrzysztof Drewniak       MemorySpaceCastOpLowering,
168575e5f0aaSAlex Zinenko       MemRefReinterpretCastOpLowering,
168675e5f0aaSAlex Zinenko       MemRefReshapeOpLowering,
168775e5f0aaSAlex Zinenko       PrefetchOpLowering,
168815f8f3e2SAlexander Belyaev       RankOpLowering,
168946ef86b5SAlexander Belyaev       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
169046ef86b5SAlexander Belyaev       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
169175e5f0aaSAlex Zinenko       StoreOpLowering,
169275e5f0aaSAlex Zinenko       SubViewOpLowering,
169375e5f0aaSAlex Zinenko       TransposeOpLowering,
169475e5f0aaSAlex Zinenko       ViewOpLowering>(converter);
169575e5f0aaSAlex Zinenko   // clang-format on
169675e5f0aaSAlex Zinenko   auto allocLowering = converter.getOptions().allocLowering;
169775e5f0aaSAlex Zinenko   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
16988037deb7SMartin Erhart     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
169975e5f0aaSAlex Zinenko   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
17008037deb7SMartin Erhart     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
170175e5f0aaSAlex Zinenko }
170275e5f0aaSAlex Zinenko 
170375e5f0aaSAlex Zinenko namespace {
1704cb4ccd38SQuentin Colombet struct FinalizeMemRefToLLVMConversionPass
1705cb4ccd38SQuentin Colombet     : public impl::FinalizeMemRefToLLVMConversionPassBase<
1706cb4ccd38SQuentin Colombet           FinalizeMemRefToLLVMConversionPass> {
1707cb4ccd38SQuentin Colombet   using FinalizeMemRefToLLVMConversionPassBase::
1708cb4ccd38SQuentin Colombet       FinalizeMemRefToLLVMConversionPassBase;
170975e5f0aaSAlex Zinenko 
171075e5f0aaSAlex Zinenko   void runOnOperation() override {
171175e5f0aaSAlex Zinenko     Operation *op = getOperation();
171275e5f0aaSAlex Zinenko     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
171375e5f0aaSAlex Zinenko     LowerToLLVMOptions options(&getContext(),
171475e5f0aaSAlex Zinenko                                dataLayoutAnalysis.getAtOrAbove(op));
171575e5f0aaSAlex Zinenko     options.allocLowering =
171675e5f0aaSAlex Zinenko         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
171775e5f0aaSAlex Zinenko                          : LowerToLLVMOptions::AllocLowering::Malloc);
1718a8601f11SMichele Scuttari 
1719a8601f11SMichele Scuttari     options.useGenericFunctions = useGenericFunctions;
1720a8601f11SMichele Scuttari 
172175e5f0aaSAlex Zinenko     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
172275e5f0aaSAlex Zinenko       options.overrideIndexBitwidth(indexBitwidth);
172375e5f0aaSAlex Zinenko 
172475e5f0aaSAlex Zinenko     LLVMTypeConverter typeConverter(&getContext(), options,
172575e5f0aaSAlex Zinenko                                     &dataLayoutAnalysis);
172675e5f0aaSAlex Zinenko     RewritePatternSet patterns(&getContext());
1727cb4ccd38SQuentin Colombet     populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
172875e5f0aaSAlex Zinenko     LLVMConversionTarget target(getContext());
172958ceae95SRiver Riddle     target.addLegalOp<func::FuncOp>();
173075e5f0aaSAlex Zinenko     if (failed(applyPartialConversion(op, target, std::move(patterns))))
173175e5f0aaSAlex Zinenko       signalPassFailure();
173275e5f0aaSAlex Zinenko   }
173375e5f0aaSAlex Zinenko };
1734876a480cSMatthias Springer 
1735876a480cSMatthias Springer /// Implement the interface to convert MemRef to LLVM.
1736876a480cSMatthias Springer struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
1737876a480cSMatthias Springer   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
1738876a480cSMatthias Springer   void loadDependentDialects(MLIRContext *context) const final {
1739876a480cSMatthias Springer     context->loadDialect<LLVM::LLVMDialect>();
1740876a480cSMatthias Springer   }
1741876a480cSMatthias Springer 
1742876a480cSMatthias Springer   /// Hook for derived dialect interface to provide conversion patterns
1743876a480cSMatthias Springer   /// and mark dialect legal for the conversion target.
1744876a480cSMatthias Springer   void populateConvertToLLVMConversionPatterns(
1745876a480cSMatthias Springer       ConversionTarget &target, LLVMTypeConverter &typeConverter,
1746876a480cSMatthias Springer       RewritePatternSet &patterns) const final {
1747876a480cSMatthias Springer     populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1748876a480cSMatthias Springer   }
1749876a480cSMatthias Springer };
1750876a480cSMatthias Springer 
175175e5f0aaSAlex Zinenko } // namespace
1752876a480cSMatthias Springer 
1753876a480cSMatthias Springer void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
1754876a480cSMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
1755876a480cSMatthias Springer     dialect->addInterfaces<MemRefToLLVMDialectInterface>();
1756876a480cSMatthias Springer   });
1757876a480cSMatthias Springer }
1758