175e5f0aaSAlex Zinenko //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===// 275e5f0aaSAlex Zinenko // 375e5f0aaSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 475e5f0aaSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 575e5f0aaSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 675e5f0aaSAlex Zinenko // 775e5f0aaSAlex Zinenko //===----------------------------------------------------------------------===// 875e5f0aaSAlex Zinenko 975e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 1067d0d7acSMichele Scuttari 1175e5f0aaSAlex Zinenko #include "mlir/Analysis/DataLayoutAnalysis.h" 12876a480cSMatthias Springer #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 1375e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1475e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 1575e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 1675e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" 17abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1836550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 1975e5f0aaSAlex Zinenko #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 2075e5f0aaSAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 217fb9bbe5SKrzysztof Drewniak #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 2275e5f0aaSAlex Zinenko #include "mlir/Dialect/MemRef/IR/MemRef.h" 23b4d6aadaSOleg Shyshkov #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 2475e5f0aaSAlex Zinenko #include "mlir/IR/AffineMap.h" 25ce6ef990SMax191 #include "mlir/IR/BuiltinTypes.h" 264d67b278SJeff Niu #include "mlir/IR/IRMapping.h" 2767d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h" 286635c12aSBenjamin Kramer #include "llvm/ADT/SmallBitVector.h" 290fb216fbSRamkumar Ramachandra #include "llvm/Support/MathExtras.h" 307d2b180eSKazu Hirata #include <optional> 3175e5f0aaSAlex Zinenko 3267d0d7acSMichele Scuttari namespace mlir { 33cb4ccd38SQuentin Colombet #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS 3467d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 3567d0d7acSMichele Scuttari } // namespace mlir 3667d0d7acSMichele Scuttari 3775e5f0aaSAlex Zinenko using namespace mlir; 3875e5f0aaSAlex Zinenko 3975e5f0aaSAlex Zinenko namespace { 4075e5f0aaSAlex Zinenko 41*e84f6b6aSLuohao Wang static bool isStaticStrideOrOffset(int64_t strideOrOffset) { 42399638f9SAliia Khasanova return !ShapedType::isDynamic(strideOrOffset); 435380e30eSAshay Rane } 445380e30eSAshay Rane 45*e84f6b6aSLuohao Wang static FailureOr<LLVM::LLVMFuncOp> 46*e84f6b6aSLuohao Wang getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) { 479f13b934Sbixia1 bool useGenericFn = typeConverter->getOptions().useGenericFunctions; 489f13b934Sbixia1 499f13b934Sbixia1 if (useGenericFn) 50b28a296cSChristian Ulmann return LLVM::lookupOrCreateGenericFreeFn(module); 519f13b934Sbixia1 52b28a296cSChristian Ulmann return LLVM::lookupOrCreateFreeFn(module); 539f13b934Sbixia1 } 549f13b934Sbixia1 5575e5f0aaSAlex Zinenko struct AllocOpLowering : public AllocLikeOpLLVMLowering { 56ce254598SMatthias Springer AllocOpLowering(const LLVMTypeConverter &converter) 5775e5f0aaSAlex Zinenko : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), 5875e5f0aaSAlex Zinenko converter) {} 5975e5f0aaSAlex Zinenko std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 6075e5f0aaSAlex Zinenko Location loc, Value sizeBytes, 6175e5f0aaSAlex Zinenko Operation *op) const override { 629f13b934Sbixia1 return allocateBufferManuallyAlign( 639f13b934Sbixia1 rewriter, loc, sizeBytes, op, 649f13b934Sbixia1 getAlignment(rewriter, loc, cast<memref::AllocOp>(op))); 6575e5f0aaSAlex Zinenko } 6675e5f0aaSAlex Zinenko }; 6775e5f0aaSAlex Zinenko 6875e5f0aaSAlex Zinenko struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering { 69ce254598SMatthias Springer AlignedAllocOpLowering(const LLVMTypeConverter &converter) 7075e5f0aaSAlex Zinenko : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), 7175e5f0aaSAlex Zinenko converter) {} 7275e5f0aaSAlex Zinenko std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 7375e5f0aaSAlex Zinenko Location loc, Value sizeBytes, 7475e5f0aaSAlex Zinenko Operation *op) const override { 759f13b934Sbixia1 Value ptr = allocateBufferAutoAlign( 769f13b934Sbixia1 rewriter, loc, sizeBytes, op, &defaultLayout, 779f13b934Sbixia1 alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op), 789f13b934Sbixia1 &defaultLayout)); 7973c6248cSKrzysztof Drewniak if (!ptr) 8073c6248cSKrzysztof Drewniak return std::make_tuple(Value(), Value()); 819f13b934Sbixia1 return std::make_tuple(ptr, ptr); 8275e5f0aaSAlex Zinenko } 8375e5f0aaSAlex Zinenko 849f13b934Sbixia1 private: 8575e5f0aaSAlex Zinenko /// Default layout to use in absence of the corresponding analysis. 8675e5f0aaSAlex Zinenko DataLayout defaultLayout; 8775e5f0aaSAlex Zinenko }; 8875e5f0aaSAlex Zinenko 8975e5f0aaSAlex Zinenko struct AllocaOpLowering : public AllocLikeOpLLVMLowering { 90ce254598SMatthias Springer AllocaOpLowering(const LLVMTypeConverter &converter) 9175e5f0aaSAlex Zinenko : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(), 92041f1abeSFabian Mora converter) { 93041f1abeSFabian Mora setRequiresNumElements(); 94041f1abeSFabian Mora } 9575e5f0aaSAlex Zinenko 9675e5f0aaSAlex Zinenko /// Allocates the underlying buffer using the right call. `allocatedBytePtr` 9775e5f0aaSAlex Zinenko /// is set to null for stack allocations. `accessAlignment` is set if 9875e5f0aaSAlex Zinenko /// alignment is needed post allocation (for eg. in conjunction with malloc). 9975e5f0aaSAlex Zinenko std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 100041f1abeSFabian Mora Location loc, Value size, 10175e5f0aaSAlex Zinenko Operation *op) const override { 10275e5f0aaSAlex Zinenko 10375e5f0aaSAlex Zinenko // With alloca, one gets a pointer to the element type right away. 10475e5f0aaSAlex Zinenko // For stack allocations. 10575e5f0aaSAlex Zinenko auto allocaOp = cast<memref::AllocaOp>(op); 10650ea17b8SMarkus Böck auto elementType = 10750ea17b8SMarkus Böck typeConverter->convertType(allocaOp.getType().getElementType()); 108499abb24SKrzysztof Drewniak unsigned addrSpace = 109499abb24SKrzysztof Drewniak *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType()); 110499abb24SKrzysztof Drewniak auto elementPtrType = 111b28a296cSChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace); 11275e5f0aaSAlex Zinenko 113041f1abeSFabian Mora auto allocatedElementPtr = 114041f1abeSFabian Mora rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size, 11550ea17b8SMarkus Böck allocaOp.getAlignment().value_or(0)); 11675e5f0aaSAlex Zinenko 11775e5f0aaSAlex Zinenko return std::make_tuple(allocatedElementPtr, allocatedElementPtr); 11875e5f0aaSAlex Zinenko } 11975e5f0aaSAlex Zinenko }; 12075e5f0aaSAlex Zinenko 12175e5f0aaSAlex Zinenko struct AllocaScopeOpLowering 12275e5f0aaSAlex Zinenko : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> { 12375e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern; 12475e5f0aaSAlex Zinenko 12575e5f0aaSAlex Zinenko LogicalResult 126ef976337SRiver Riddle matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor, 12775e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 12875e5f0aaSAlex Zinenko OpBuilder::InsertionGuard guard(rewriter); 12975e5f0aaSAlex Zinenko Location loc = allocaScopeOp.getLoc(); 13075e5f0aaSAlex Zinenko 13175e5f0aaSAlex Zinenko // Split the current block before the AllocaScopeOp to create the inlining 13275e5f0aaSAlex Zinenko // point. 13375e5f0aaSAlex Zinenko auto *currentBlock = rewriter.getInsertionBlock(); 13475e5f0aaSAlex Zinenko auto *remainingOpsBlock = 13575e5f0aaSAlex Zinenko rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); 13675e5f0aaSAlex Zinenko Block *continueBlock; 13775e5f0aaSAlex Zinenko if (allocaScopeOp.getNumResults() == 0) { 13875e5f0aaSAlex Zinenko continueBlock = remainingOpsBlock; 13975e5f0aaSAlex Zinenko } else { 140e084679fSRiver Riddle continueBlock = rewriter.createBlock( 141e084679fSRiver Riddle remainingOpsBlock, allocaScopeOp.getResultTypes(), 142e084679fSRiver Riddle SmallVector<Location>(allocaScopeOp->getNumResults(), 143e084679fSRiver Riddle allocaScopeOp.getLoc())); 14475e5f0aaSAlex Zinenko rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock); 14575e5f0aaSAlex Zinenko } 14675e5f0aaSAlex Zinenko 14775e5f0aaSAlex Zinenko // Inline body region. 148136d746eSJacques Pienaar Block *beforeBody = &allocaScopeOp.getBodyRegion().front(); 149136d746eSJacques Pienaar Block *afterBody = &allocaScopeOp.getBodyRegion().back(); 150136d746eSJacques Pienaar rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock); 15175e5f0aaSAlex Zinenko 15275e5f0aaSAlex Zinenko // Save stack and then branch into the body of the region. 15375e5f0aaSAlex Zinenko rewriter.setInsertionPointToEnd(currentBlock); 15475e5f0aaSAlex Zinenko auto stackSaveOp = 15575e5f0aaSAlex Zinenko rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType()); 15675e5f0aaSAlex Zinenko rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody); 15775e5f0aaSAlex Zinenko 15875e5f0aaSAlex Zinenko // Replace the alloca_scope return with a branch that jumps out of the body. 15975e5f0aaSAlex Zinenko // Stack restore before leaving the body region. 16075e5f0aaSAlex Zinenko rewriter.setInsertionPointToEnd(afterBody); 16175e5f0aaSAlex Zinenko auto returnOp = 16275e5f0aaSAlex Zinenko cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator()); 16375e5f0aaSAlex Zinenko auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( 164136d746eSJacques Pienaar returnOp, returnOp.getResults(), continueBlock); 16575e5f0aaSAlex Zinenko 16675e5f0aaSAlex Zinenko // Insert stack restore before jumping out the body of the region. 16775e5f0aaSAlex Zinenko rewriter.setInsertionPoint(branchOp); 16875e5f0aaSAlex Zinenko rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp); 16975e5f0aaSAlex Zinenko 17075e5f0aaSAlex Zinenko // Replace the op with values return from the body region. 17175e5f0aaSAlex Zinenko rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments()); 17275e5f0aaSAlex Zinenko 17375e5f0aaSAlex Zinenko return success(); 17475e5f0aaSAlex Zinenko } 17575e5f0aaSAlex Zinenko }; 17675e5f0aaSAlex Zinenko 17775e5f0aaSAlex Zinenko struct AssumeAlignmentOpLowering 17875e5f0aaSAlex Zinenko : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> { 17975e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern< 18075e5f0aaSAlex Zinenko memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern; 181ce254598SMatthias Springer explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter) 182e02d4142SQuentin Colombet : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {} 18375e5f0aaSAlex Zinenko 18475e5f0aaSAlex Zinenko LogicalResult 185ef976337SRiver Riddle matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor, 18675e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 187136d746eSJacques Pienaar Value memref = adaptor.getMemref(); 188136d746eSJacques Pienaar unsigned alignment = op.getAlignment(); 18975e5f0aaSAlex Zinenko auto loc = op.getLoc(); 19075e5f0aaSAlex Zinenko 1915550c821STres Popp auto srcMemRefType = cast<MemRefType>(op.getMemref().getType()); 192e02d4142SQuentin Colombet Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{}, 193e02d4142SQuentin Colombet rewriter); 19475e5f0aaSAlex Zinenko 19533598068SKrzysztof Drewniak // Emit llvm.assume(true) ["align"(memref, alignment)]. 19633598068SKrzysztof Drewniak // This is more direct than ptrtoint-based checks, is explicitly supported, 19733598068SKrzysztof Drewniak // and works with non-integral address spaces. 19833598068SKrzysztof Drewniak Value trueCond = 19933598068SKrzysztof Drewniak rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true)); 20033598068SKrzysztof Drewniak Value alignmentConst = 20133598068SKrzysztof Drewniak createIndexAttrConstant(rewriter, loc, getIndexType(), alignment); 20233598068SKrzysztof Drewniak rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr, 20333598068SKrzysztof Drewniak alignmentConst); 20475e5f0aaSAlex Zinenko 20575e5f0aaSAlex Zinenko rewriter.eraseOp(op); 20675e5f0aaSAlex Zinenko return success(); 20775e5f0aaSAlex Zinenko } 20875e5f0aaSAlex Zinenko }; 20975e5f0aaSAlex Zinenko 21075e5f0aaSAlex Zinenko // A `dealloc` is converted into a call to `free` on the underlying data buffer. 21175e5f0aaSAlex Zinenko // The memref descriptor being an SSA value, there is no need to clean it up 21275e5f0aaSAlex Zinenko // in any way. 21375e5f0aaSAlex Zinenko struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> { 21475e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern; 21575e5f0aaSAlex Zinenko 216ce254598SMatthias Springer explicit DeallocOpLowering(const LLVMTypeConverter &converter) 21775e5f0aaSAlex Zinenko : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {} 21875e5f0aaSAlex Zinenko 21975e5f0aaSAlex Zinenko LogicalResult 220ef976337SRiver Riddle matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, 22175e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 22275e5f0aaSAlex Zinenko // Insert the `free` declaration if it is not already present. 223*e84f6b6aSLuohao Wang FailureOr<LLVM::LLVMFuncOp> freeFunc = 2249f13b934Sbixia1 getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>()); 225*e84f6b6aSLuohao Wang if (failed(freeFunc)) 226*e84f6b6aSLuohao Wang return failure(); 227b58daf91SJohannes Reifferscheid Value allocatedPtr; 228b58daf91SJohannes Reifferscheid if (auto unrankedTy = 229b58daf91SJohannes Reifferscheid llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) { 230b28a296cSChristian Ulmann auto elementPtrTy = LLVM::LLVMPointerType::get( 231b28a296cSChristian Ulmann rewriter.getContext(), unrankedTy.getMemorySpaceAsInt()); 232b58daf91SJohannes Reifferscheid allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr( 233b58daf91SJohannes Reifferscheid rewriter, op.getLoc(), 234b58daf91SJohannes Reifferscheid UnrankedMemRefDescriptor(adaptor.getMemref()) 235b58daf91SJohannes Reifferscheid .memRefDescPtr(rewriter, op.getLoc()), 236b58daf91SJohannes Reifferscheid elementPtrTy); 237b58daf91SJohannes Reifferscheid } else { 238b58daf91SJohannes Reifferscheid allocatedPtr = MemRefDescriptor(adaptor.getMemref()) 239b58daf91SJohannes Reifferscheid .allocatedPtr(rewriter, op.getLoc()); 240b58daf91SJohannes Reifferscheid } 241*e84f6b6aSLuohao Wang rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(), 242*e84f6b6aSLuohao Wang allocatedPtr); 24375e5f0aaSAlex Zinenko return success(); 24475e5f0aaSAlex Zinenko } 24575e5f0aaSAlex Zinenko }; 24675e5f0aaSAlex Zinenko 24775e5f0aaSAlex Zinenko // A `dim` is converted to a constant for static sizes and to an access to the 24875e5f0aaSAlex Zinenko // size stored in the memref descriptor for dynamic sizes. 24975e5f0aaSAlex Zinenko struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> { 25075e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern; 25175e5f0aaSAlex Zinenko 25275e5f0aaSAlex Zinenko LogicalResult 253ef976337SRiver Riddle matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor, 25475e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 255136d746eSJacques Pienaar Type operandType = dimOp.getSource().getType(); 2565550c821STres Popp if (isa<UnrankedMemRefType>(operandType)) { 257499abb24SKrzysztof Drewniak FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef( 258499abb24SKrzysztof Drewniak operandType, dimOp, adaptor.getOperands(), rewriter); 259499abb24SKrzysztof Drewniak if (failed(extractedSize)) 260499abb24SKrzysztof Drewniak return failure(); 261499abb24SKrzysztof Drewniak rewriter.replaceOp(dimOp, {*extractedSize}); 26275e5f0aaSAlex Zinenko return success(); 26375e5f0aaSAlex Zinenko } 2645550c821STres Popp if (isa<MemRefType>(operandType)) { 265ef976337SRiver Riddle rewriter.replaceOp( 266ef976337SRiver Riddle dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, 267ef976337SRiver Riddle adaptor.getOperands(), rewriter)}); 26875e5f0aaSAlex Zinenko return success(); 26975e5f0aaSAlex Zinenko } 27075e5f0aaSAlex Zinenko llvm_unreachable("expected MemRefType or UnrankedMemRefType"); 27175e5f0aaSAlex Zinenko } 27275e5f0aaSAlex Zinenko 27375e5f0aaSAlex Zinenko private: 274499abb24SKrzysztof Drewniak FailureOr<Value> 275499abb24SKrzysztof Drewniak extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp, 276ef976337SRiver Riddle OpAdaptor adaptor, 27775e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const { 27875e5f0aaSAlex Zinenko Location loc = dimOp.getLoc(); 27975e5f0aaSAlex Zinenko 2805550c821STres Popp auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType); 28175e5f0aaSAlex Zinenko auto scalarMemRefType = 28275e5f0aaSAlex Zinenko MemRefType::get({}, unrankedMemRefType.getElementType()); 283499abb24SKrzysztof Drewniak FailureOr<unsigned> maybeAddressSpace = 284499abb24SKrzysztof Drewniak getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType); 285499abb24SKrzysztof Drewniak if (failed(maybeAddressSpace)) { 286499abb24SKrzysztof Drewniak dimOp.emitOpError("memref memory space must be convertible to an integer " 287499abb24SKrzysztof Drewniak "address space"); 288499abb24SKrzysztof Drewniak return failure(); 289499abb24SKrzysztof Drewniak } 290499abb24SKrzysztof Drewniak unsigned addressSpace = *maybeAddressSpace; 29175e5f0aaSAlex Zinenko 29275e5f0aaSAlex Zinenko // Extract pointer to the underlying ranked descriptor and bitcast it to a 29375e5f0aaSAlex Zinenko // memref<element_type> descriptor pointer to minimize the number of GEP 29475e5f0aaSAlex Zinenko // operations. 295136d746eSJacques Pienaar UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource()); 29675e5f0aaSAlex Zinenko Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); 29750ea17b8SMarkus Böck 29850ea17b8SMarkus Böck Type elementType = typeConverter->convertType(scalarMemRefType); 29975e5f0aaSAlex Zinenko 30075e5f0aaSAlex Zinenko // Get pointer to offset field of memref<element_type> descriptor. 301b28a296cSChristian Ulmann auto indexPtrTy = 302b28a296cSChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); 30375e5f0aaSAlex Zinenko Value offsetPtr = rewriter.create<LLVM::GEPOp>( 304b28a296cSChristian Ulmann loc, indexPtrTy, elementType, underlyingRankedDesc, 30550ea17b8SMarkus Böck ArrayRef<LLVM::GEPArg>{0, 2}); 30675e5f0aaSAlex Zinenko 30775e5f0aaSAlex Zinenko // The size value that we have to extract can be obtained using GEPop with 30875e5f0aaSAlex Zinenko // `dimOp.index() + 1` index argument. 30975e5f0aaSAlex Zinenko Value idxPlusOne = rewriter.create<LLVM::AddOp>( 310e98e5995SAlex Zinenko loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1), 311620e2bb2SNicolas Vasilache adaptor.getIndex()); 31250ea17b8SMarkus Böck Value sizePtr = rewriter.create<LLVM::GEPOp>( 31350ea17b8SMarkus Böck loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, 31450ea17b8SMarkus Böck idxPlusOne); 315499abb24SKrzysztof Drewniak return rewriter 316499abb24SKrzysztof Drewniak .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr) 317499abb24SKrzysztof Drewniak .getResult(); 31875e5f0aaSAlex Zinenko } 31975e5f0aaSAlex Zinenko 32022426110SRamkumar Ramachandra std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const { 32122426110SRamkumar Ramachandra if (auto idx = dimOp.getConstantIndex()) 32275e5f0aaSAlex Zinenko return idx; 32375e5f0aaSAlex Zinenko 324136d746eSJacques Pienaar if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>()) 3255550c821STres Popp return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue(); 32675e5f0aaSAlex Zinenko 3271a36588eSKazu Hirata return std::nullopt; 32875e5f0aaSAlex Zinenko } 32975e5f0aaSAlex Zinenko 33075e5f0aaSAlex Zinenko Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp, 331ef976337SRiver Riddle OpAdaptor adaptor, 33275e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const { 33375e5f0aaSAlex Zinenko Location loc = dimOp.getLoc(); 334ef976337SRiver Riddle 33575e5f0aaSAlex Zinenko // Take advantage if index is constant. 3365550c821STres Popp MemRefType memRefType = cast<MemRefType>(operandType); 337e98e5995SAlex Zinenko Type indexType = getIndexType(); 33822426110SRamkumar Ramachandra if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) { 3396d5fc1e3SKazu Hirata int64_t i = *index; 3404bc2357cSQuentin Colombet if (i >= 0 && i < memRefType.getRank()) { 34175e5f0aaSAlex Zinenko if (memRefType.isDynamicDim(i)) { 34275e5f0aaSAlex Zinenko // extract dynamic size from the memref descriptor. 343136d746eSJacques Pienaar MemRefDescriptor descriptor(adaptor.getSource()); 34475e5f0aaSAlex Zinenko return descriptor.size(rewriter, loc, i); 34575e5f0aaSAlex Zinenko } 34675e5f0aaSAlex Zinenko // Use constant for static size. 34775e5f0aaSAlex Zinenko int64_t dimSize = memRefType.getDimSize(i); 348620e2bb2SNicolas Vasilache return createIndexAttrConstant(rewriter, loc, indexType, dimSize); 34975e5f0aaSAlex Zinenko } 3504bc2357cSQuentin Colombet } 351136d746eSJacques Pienaar Value index = adaptor.getIndex(); 35275e5f0aaSAlex Zinenko int64_t rank = memRefType.getRank(); 353136d746eSJacques Pienaar MemRefDescriptor memrefDescriptor(adaptor.getSource()); 35475e5f0aaSAlex Zinenko return memrefDescriptor.size(rewriter, loc, index, rank); 35575e5f0aaSAlex Zinenko } 35675e5f0aaSAlex Zinenko }; 35775e5f0aaSAlex Zinenko 358632a4f88SRiver Riddle /// Common base for load and store operations on MemRefs. Restricts the match 359632a4f88SRiver Riddle /// to supported MemRef types. Provides functionality to emit code accessing a 360632a4f88SRiver Riddle /// specific element of the underlying data buffer. 361632a4f88SRiver Riddle template <typename Derived> 362632a4f88SRiver Riddle struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> { 363632a4f88SRiver Riddle using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern; 364632a4f88SRiver Riddle using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps; 365632a4f88SRiver Riddle using Base = LoadStoreOpLowering<Derived>; 366632a4f88SRiver Riddle 367632a4f88SRiver Riddle LogicalResult match(Derived op) const override { 368632a4f88SRiver Riddle MemRefType type = op.getMemRefType(); 369632a4f88SRiver Riddle return isConvertibleAndHasIdentityMaps(type) ? success() : failure(); 370632a4f88SRiver Riddle } 371632a4f88SRiver Riddle }; 372632a4f88SRiver Riddle 373632a4f88SRiver Riddle /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be 374632a4f88SRiver Riddle /// retried until it succeeds in atomically storing a new value into memory. 375632a4f88SRiver Riddle /// 376632a4f88SRiver Riddle /// +---------------------------------+ 377632a4f88SRiver Riddle /// | <code before the AtomicRMWOp> | 378632a4f88SRiver Riddle /// | <compute initial %loaded> | 379ace01605SRiver Riddle /// | cf.br loop(%loaded) | 380632a4f88SRiver Riddle /// +---------------------------------+ 381632a4f88SRiver Riddle /// | 382632a4f88SRiver Riddle /// -------| | 383632a4f88SRiver Riddle /// | v v 384632a4f88SRiver Riddle /// | +--------------------------------+ 385632a4f88SRiver Riddle /// | | loop(%loaded): | 386632a4f88SRiver Riddle /// | | <body contents> | 387632a4f88SRiver Riddle /// | | %pair = cmpxchg | 388632a4f88SRiver Riddle /// | | %ok = %pair[0] | 389632a4f88SRiver Riddle /// | | %new = %pair[1] | 390ace01605SRiver Riddle /// | | cf.cond_br %ok, end, loop(%new) | 391632a4f88SRiver Riddle /// | +--------------------------------+ 392632a4f88SRiver Riddle /// | | | 393632a4f88SRiver Riddle /// |----------- | 394632a4f88SRiver Riddle /// v 395632a4f88SRiver Riddle /// +--------------------------------+ 396632a4f88SRiver Riddle /// | end: | 397632a4f88SRiver Riddle /// | <code after the AtomicRMWOp> | 398632a4f88SRiver Riddle /// +--------------------------------+ 399632a4f88SRiver Riddle /// 400632a4f88SRiver Riddle struct GenericAtomicRMWOpLowering 401632a4f88SRiver Riddle : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> { 402632a4f88SRiver Riddle using Base::Base; 403632a4f88SRiver Riddle 404632a4f88SRiver Riddle LogicalResult 405632a4f88SRiver Riddle matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor, 406632a4f88SRiver Riddle ConversionPatternRewriter &rewriter) const override { 407632a4f88SRiver Riddle auto loc = atomicOp.getLoc(); 408632a4f88SRiver Riddle Type valueType = typeConverter->convertType(atomicOp.getResult().getType()); 409632a4f88SRiver Riddle 410632a4f88SRiver Riddle // Split the block into initial, loop, and ending parts. 411632a4f88SRiver Riddle auto *initBlock = rewriter.getInsertionBlock(); 412e7833c20SAlexander Belyaev auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp)); 413e7833c20SAlexander Belyaev loopBlock->addArgument(valueType, loc); 414632a4f88SRiver Riddle 415e7833c20SAlexander Belyaev auto *endBlock = 416e7833c20SAlexander Belyaev rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++); 417632a4f88SRiver Riddle 418632a4f88SRiver Riddle // Compute the loaded value and branch to the loop block. 419632a4f88SRiver Riddle rewriter.setInsertionPointToEnd(initBlock); 4205550c821STres Popp auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType()); 421136d746eSJacques Pienaar auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(), 422136d746eSJacques Pienaar adaptor.getIndices(), rewriter); 42350ea17b8SMarkus Böck Value init = rewriter.create<LLVM::LoadOp>( 42450ea17b8SMarkus Böck loc, typeConverter->convertType(memRefType.getElementType()), dataPtr); 425632a4f88SRiver Riddle rewriter.create<LLVM::BrOp>(loc, init, loopBlock); 426632a4f88SRiver Riddle 427632a4f88SRiver Riddle // Prepare the body of the loop block. 428632a4f88SRiver Riddle rewriter.setInsertionPointToStart(loopBlock); 429632a4f88SRiver Riddle 430632a4f88SRiver Riddle // Clone the GenericAtomicRMWOp region and extract the result. 431632a4f88SRiver Riddle auto loopArgument = loopBlock->getArgument(0); 4324d67b278SJeff Niu IRMapping mapping; 433632a4f88SRiver Riddle mapping.map(atomicOp.getCurrentValue(), loopArgument); 434632a4f88SRiver Riddle Block &entryBlock = atomicOp.body().front(); 435632a4f88SRiver Riddle for (auto &nestedOp : entryBlock.without_terminator()) { 436632a4f88SRiver Riddle Operation *clone = rewriter.clone(nestedOp, mapping); 437632a4f88SRiver Riddle mapping.map(nestedOp.getResults(), clone->getResults()); 438632a4f88SRiver Riddle } 439632a4f88SRiver Riddle Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0)); 440632a4f88SRiver Riddle 441632a4f88SRiver Riddle // Prepare the epilog of the loop block. 442632a4f88SRiver Riddle // Append the cmpxchg op to the end of the loop block. 443632a4f88SRiver Riddle auto successOrdering = LLVM::AtomicOrdering::acq_rel; 444632a4f88SRiver Riddle auto failureOrdering = LLVM::AtomicOrdering::monotonic; 445632a4f88SRiver Riddle auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>( 4467f97895fSTobias Gysi loc, dataPtr, loopArgument, result, successOrdering, failureOrdering); 447632a4f88SRiver Riddle // Extract the %new_loaded and %ok values from the pair. 4485c5af910SJeff Niu Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0); 4495c5af910SJeff Niu Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1); 450632a4f88SRiver Riddle 451632a4f88SRiver Riddle // Conditionally branch to the end or back to the loop depending on %ok. 452632a4f88SRiver Riddle rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(), 453632a4f88SRiver Riddle loopBlock, newLoaded); 454632a4f88SRiver Riddle 455632a4f88SRiver Riddle rewriter.setInsertionPointToEnd(endBlock); 456632a4f88SRiver Riddle 457632a4f88SRiver Riddle // The 'result' of the atomic_rmw op is the newly loaded value. 458632a4f88SRiver Riddle rewriter.replaceOp(atomicOp, {newLoaded}); 459632a4f88SRiver Riddle 460632a4f88SRiver Riddle return success(); 461632a4f88SRiver Riddle } 462632a4f88SRiver Riddle }; 463632a4f88SRiver Riddle 46475e5f0aaSAlex Zinenko /// Returns the LLVM type of the global variable given the memref type `type`. 465ce254598SMatthias Springer static Type 466ce254598SMatthias Springer convertGlobalMemrefTypeToLLVM(MemRefType type, 467ce254598SMatthias Springer const LLVMTypeConverter &typeConverter) { 46875e5f0aaSAlex Zinenko // LLVM type for a global memref will be a multi-dimension array. For 46975e5f0aaSAlex Zinenko // declarations or uninitialized global memrefs, we can potentially flatten 47075e5f0aaSAlex Zinenko // this to a 1D array. However, for memref.global's with an initial value, 47175e5f0aaSAlex Zinenko // we do not intend to flatten the ElementsAttribute when going from std -> 47275e5f0aaSAlex Zinenko // LLVM dialect, so the LLVM type needs to me a multi-dimension array. 47375e5f0aaSAlex Zinenko Type elementType = typeConverter.convertType(type.getElementType()); 47475e5f0aaSAlex Zinenko Type arrayTy = elementType; 47575e5f0aaSAlex Zinenko // Shape has the outermost dim at index 0, so need to walk it backwards 47675e5f0aaSAlex Zinenko for (int64_t dim : llvm::reverse(type.getShape())) 47775e5f0aaSAlex Zinenko arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim); 47875e5f0aaSAlex Zinenko return arrayTy; 47975e5f0aaSAlex Zinenko } 48075e5f0aaSAlex Zinenko 48175e5f0aaSAlex Zinenko /// GlobalMemrefOp is lowered to a LLVM Global Variable. 48275e5f0aaSAlex Zinenko struct GlobalMemrefOpLowering 48375e5f0aaSAlex Zinenko : public ConvertOpToLLVMPattern<memref::GlobalOp> { 48475e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern; 48575e5f0aaSAlex Zinenko 48675e5f0aaSAlex Zinenko LogicalResult 487ef976337SRiver Riddle matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor, 48875e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 489136d746eSJacques Pienaar MemRefType type = global.getType(); 49075e5f0aaSAlex Zinenko if (!isConvertibleAndHasIdentityMaps(type)) 49175e5f0aaSAlex Zinenko return failure(); 49275e5f0aaSAlex Zinenko 49375e5f0aaSAlex Zinenko Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); 49475e5f0aaSAlex Zinenko 49575e5f0aaSAlex Zinenko LLVM::Linkage linkage = 49675e5f0aaSAlex Zinenko global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private; 49775e5f0aaSAlex Zinenko 49875e5f0aaSAlex Zinenko Attribute initialValue = nullptr; 49975e5f0aaSAlex Zinenko if (!global.isExternal() && !global.isUninitialized()) { 50068f58812STres Popp auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue()); 50175e5f0aaSAlex Zinenko initialValue = elementsAttr; 50275e5f0aaSAlex Zinenko 50375e5f0aaSAlex Zinenko // For scalar memrefs, the global variable created is of the element type, 50475e5f0aaSAlex Zinenko // so unpack the elements attribute to extract the value. 50575e5f0aaSAlex Zinenko if (type.getRank() == 0) 506937e40a8SRiver Riddle initialValue = elementsAttr.getSplatValue<Attribute>(); 50775e5f0aaSAlex Zinenko } 50875e5f0aaSAlex Zinenko 509136d746eSJacques Pienaar uint64_t alignment = global.getAlignment().value_or(0); 510499abb24SKrzysztof Drewniak FailureOr<unsigned> addressSpace = 511499abb24SKrzysztof Drewniak getTypeConverter()->getMemRefAddressSpace(type); 512499abb24SKrzysztof Drewniak if (failed(addressSpace)) 513499abb24SKrzysztof Drewniak return global.emitOpError( 514499abb24SKrzysztof Drewniak "memory space cannot be converted to an integer address space"); 5158c2ff7b6SWilliam S. Moses auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( 516136d746eSJacques Pienaar global, arrayTy, global.getConstant(), linkage, global.getSymName(), 517499abb24SKrzysztof Drewniak initialValue, alignment, *addressSpace); 5188c2ff7b6SWilliam S. Moses if (!global.isExternal() && global.isUninitialized()) { 51991d5653eSMatthias Springer rewriter.createBlock(&newGlobal.getInitializerRegion()); 5208c2ff7b6SWilliam S. Moses Value undef[] = { 5218c2ff7b6SWilliam S. Moses rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)}; 5228c2ff7b6SWilliam S. Moses rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef); 5238c2ff7b6SWilliam S. Moses } 52475e5f0aaSAlex Zinenko return success(); 52575e5f0aaSAlex Zinenko } 52675e5f0aaSAlex Zinenko }; 52775e5f0aaSAlex Zinenko 52875e5f0aaSAlex Zinenko /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to 52975e5f0aaSAlex Zinenko /// the first element stashed into the descriptor. This reuses 53075e5f0aaSAlex Zinenko /// `AllocLikeOpLowering` to reuse the Memref descriptor construction. 53175e5f0aaSAlex Zinenko struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering { 532ce254598SMatthias Springer GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter) 53375e5f0aaSAlex Zinenko : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(), 53475e5f0aaSAlex Zinenko converter) {} 53575e5f0aaSAlex Zinenko 53675e5f0aaSAlex Zinenko /// Buffer "allocation" for memref.get_global op is getting the address of 53775e5f0aaSAlex Zinenko /// the global variable referenced. 53875e5f0aaSAlex Zinenko std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 53975e5f0aaSAlex Zinenko Location loc, Value sizeBytes, 54075e5f0aaSAlex Zinenko Operation *op) const override { 54175e5f0aaSAlex Zinenko auto getGlobalOp = cast<memref::GetGlobalOp>(op); 5425550c821STres Popp MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType()); 543499abb24SKrzysztof Drewniak 544499abb24SKrzysztof Drewniak // This is called after a type conversion, which would have failed if this 545499abb24SKrzysztof Drewniak // call fails. 54673c6248cSKrzysztof Drewniak FailureOr<unsigned> maybeAddressSpace = 547620e2bb2SNicolas Vasilache getTypeConverter()->getMemRefAddressSpace(type); 54873c6248cSKrzysztof Drewniak if (failed(maybeAddressSpace)) 549620e2bb2SNicolas Vasilache return std::make_tuple(Value(), Value()); 550620e2bb2SNicolas Vasilache unsigned memSpace = *maybeAddressSpace; 55175e5f0aaSAlex Zinenko 55275e5f0aaSAlex Zinenko Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); 553b28a296cSChristian Ulmann auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace); 55450ea17b8SMarkus Böck auto addressOf = 555b28a296cSChristian Ulmann rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName()); 55675e5f0aaSAlex Zinenko 55775e5f0aaSAlex Zinenko // Get the address of the first element in the array by creating a GEP with 55875e5f0aaSAlex Zinenko // the address of the GV as the base, and (rank + 1) number of 0 indices. 559bd7eff1fSMarkus Böck auto gep = rewriter.create<LLVM::GEPOp>( 560b28a296cSChristian Ulmann loc, ptrTy, arrayTy, addressOf, 561bd7eff1fSMarkus Böck SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0)); 56275e5f0aaSAlex Zinenko 56375e5f0aaSAlex Zinenko // We do not expect the memref obtained using `memref.get_global` to be 56475e5f0aaSAlex Zinenko // ever deallocated. Set the allocated pointer to be known bad value to 56575e5f0aaSAlex Zinenko // help debug if that ever happens. 56675e5f0aaSAlex Zinenko auto intPtrType = getIntPtrType(memSpace); 56775e5f0aaSAlex Zinenko Value deadBeefConst = 56875e5f0aaSAlex Zinenko createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef); 56975e5f0aaSAlex Zinenko auto deadBeefPtr = 570b28a296cSChristian Ulmann rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst); 57175e5f0aaSAlex Zinenko 57275e5f0aaSAlex Zinenko // Both allocated and aligned pointers are same. We could potentially stash 57375e5f0aaSAlex Zinenko // a nullptr for the allocated pointer since we do not expect any dealloc. 57475e5f0aaSAlex Zinenko return std::make_tuple(deadBeefPtr, gep); 57575e5f0aaSAlex Zinenko } 57675e5f0aaSAlex Zinenko }; 57775e5f0aaSAlex Zinenko 57875e5f0aaSAlex Zinenko // Load operation is lowered to obtaining a pointer to the indexed element 57975e5f0aaSAlex Zinenko // and loading it. 58075e5f0aaSAlex Zinenko struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> { 58175e5f0aaSAlex Zinenko using Base::Base; 58275e5f0aaSAlex Zinenko 58375e5f0aaSAlex Zinenko LogicalResult 584ef976337SRiver Riddle matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 58575e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 58675e5f0aaSAlex Zinenko auto type = loadOp.getMemRefType(); 58775e5f0aaSAlex Zinenko 588136d746eSJacques Pienaar Value dataPtr = 589136d746eSJacques Pienaar getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(), 590136d746eSJacques Pienaar adaptor.getIndices(), rewriter); 59150ea17b8SMarkus Böck rewriter.replaceOpWithNewOp<LLVM::LoadOp>( 59250ea17b8SMarkus Böck loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0, 59350ea17b8SMarkus Böck false, loadOp.getNontemporal()); 59475e5f0aaSAlex Zinenko return success(); 59575e5f0aaSAlex Zinenko } 59675e5f0aaSAlex Zinenko }; 59775e5f0aaSAlex Zinenko 59875e5f0aaSAlex Zinenko // Store operation is lowered to obtaining a pointer to the indexed element, 59975e5f0aaSAlex Zinenko // and storing the given value to it. 60075e5f0aaSAlex Zinenko struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> { 60175e5f0aaSAlex Zinenko using Base::Base; 60275e5f0aaSAlex Zinenko 60375e5f0aaSAlex Zinenko LogicalResult 604ef976337SRiver Riddle matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, 60575e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 60675e5f0aaSAlex Zinenko auto type = op.getMemRefType(); 60775e5f0aaSAlex Zinenko 608136d746eSJacques Pienaar Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(), 609136d746eSJacques Pienaar adaptor.getIndices(), rewriter); 6101cb91b42SGuray Ozen rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr, 6111cb91b42SGuray Ozen 0, false, op.getNontemporal()); 61275e5f0aaSAlex Zinenko return success(); 61375e5f0aaSAlex Zinenko } 61475e5f0aaSAlex Zinenko }; 61575e5f0aaSAlex Zinenko 61675e5f0aaSAlex Zinenko // The prefetch operation is lowered in a way similar to the load operation 61775e5f0aaSAlex Zinenko // except that the llvm.prefetch operation is used for replacement. 61875e5f0aaSAlex Zinenko struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> { 61975e5f0aaSAlex Zinenko using Base::Base; 62075e5f0aaSAlex Zinenko 62175e5f0aaSAlex Zinenko LogicalResult 622ef976337SRiver Riddle matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor, 62375e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 62475e5f0aaSAlex Zinenko auto type = prefetchOp.getMemRefType(); 62575e5f0aaSAlex Zinenko auto loc = prefetchOp.getLoc(); 62675e5f0aaSAlex Zinenko 627136d746eSJacques Pienaar Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(), 628136d746eSJacques Pienaar adaptor.getIndices(), rewriter); 62975e5f0aaSAlex Zinenko 63075e5f0aaSAlex Zinenko // Replace with llvm.prefetch. 63148b126e3SChristian Ulmann IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()); 63248b126e3SChristian Ulmann IntegerAttr localityHint = prefetchOp.getLocalityHintAttr(); 63348b126e3SChristian Ulmann IntegerAttr isData = 63448b126e3SChristian Ulmann rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()); 63575e5f0aaSAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite, 63675e5f0aaSAlex Zinenko localityHint, isData); 63775e5f0aaSAlex Zinenko return success(); 63875e5f0aaSAlex Zinenko } 63975e5f0aaSAlex Zinenko }; 64075e5f0aaSAlex Zinenko 64115f8f3e2SAlexander Belyaev struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> { 64215f8f3e2SAlexander Belyaev using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern; 64315f8f3e2SAlexander Belyaev 64415f8f3e2SAlexander Belyaev LogicalResult 64515f8f3e2SAlexander Belyaev matchAndRewrite(memref::RankOp op, OpAdaptor adaptor, 64615f8f3e2SAlexander Belyaev ConversionPatternRewriter &rewriter) const override { 64715f8f3e2SAlexander Belyaev Location loc = op.getLoc(); 648136d746eSJacques Pienaar Type operandType = op.getMemref().getType(); 6490a0aff2dSMikhail Goncharov if (dyn_cast<UnrankedMemRefType>(operandType)) { 650136d746eSJacques Pienaar UnrankedMemRefDescriptor desc(adaptor.getMemref()); 65115f8f3e2SAlexander Belyaev rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); 65215f8f3e2SAlexander Belyaev return success(); 65315f8f3e2SAlexander Belyaev } 6545550c821STres Popp if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) { 655620e2bb2SNicolas Vasilache Type indexType = getIndexType(); 656620e2bb2SNicolas Vasilache rewriter.replaceOp(op, 657620e2bb2SNicolas Vasilache {createIndexAttrConstant(rewriter, loc, indexType, 658620e2bb2SNicolas Vasilache rankedMemRefType.getRank())}); 65915f8f3e2SAlexander Belyaev return success(); 66015f8f3e2SAlexander Belyaev } 66115f8f3e2SAlexander Belyaev return failure(); 66215f8f3e2SAlexander Belyaev } 66315f8f3e2SAlexander Belyaev }; 66415f8f3e2SAlexander Belyaev 66575e5f0aaSAlex Zinenko struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { 66675e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern; 66775e5f0aaSAlex Zinenko 66875e5f0aaSAlex Zinenko LogicalResult match(memref::CastOp memRefCastOp) const override { 66975e5f0aaSAlex Zinenko Type srcType = memRefCastOp.getOperand().getType(); 67075e5f0aaSAlex Zinenko Type dstType = memRefCastOp.getType(); 67175e5f0aaSAlex Zinenko 67275e5f0aaSAlex Zinenko // memref::CastOp reduce to bitcast in the ranked MemRef case and can be 67375e5f0aaSAlex Zinenko // used for type erasure. For now they must preserve underlying element type 67475e5f0aaSAlex Zinenko // and require source and result type to have the same rank. Therefore, 67575e5f0aaSAlex Zinenko // perform a sanity check that the underlying structs are the same. Once op 67675e5f0aaSAlex Zinenko // semantics are relaxed we can revisit. 6775550c821STres Popp if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) 67875e5f0aaSAlex Zinenko return success(typeConverter->convertType(srcType) == 67975e5f0aaSAlex Zinenko typeConverter->convertType(dstType)); 68075e5f0aaSAlex Zinenko 68175e5f0aaSAlex Zinenko // At least one of the operands is unranked type 6825550c821STres Popp assert(isa<UnrankedMemRefType>(srcType) || 6835550c821STres Popp isa<UnrankedMemRefType>(dstType)); 68475e5f0aaSAlex Zinenko 68575e5f0aaSAlex Zinenko // Unranked to unranked cast is disallowed 6865550c821STres Popp return !(isa<UnrankedMemRefType>(srcType) && 6875550c821STres Popp isa<UnrankedMemRefType>(dstType)) 68875e5f0aaSAlex Zinenko ? success() 68975e5f0aaSAlex Zinenko : failure(); 69075e5f0aaSAlex Zinenko } 69175e5f0aaSAlex Zinenko 692ef976337SRiver Riddle void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor, 69375e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 69475e5f0aaSAlex Zinenko auto srcType = memRefCastOp.getOperand().getType(); 69575e5f0aaSAlex Zinenko auto dstType = memRefCastOp.getType(); 69675e5f0aaSAlex Zinenko auto targetStructType = typeConverter->convertType(memRefCastOp.getType()); 69775e5f0aaSAlex Zinenko auto loc = memRefCastOp.getLoc(); 69875e5f0aaSAlex Zinenko 69975e5f0aaSAlex Zinenko // For ranked/ranked case, just keep the original descriptor. 7005550c821STres Popp if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) 701136d746eSJacques Pienaar return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()}); 70275e5f0aaSAlex Zinenko 7035550c821STres Popp if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) { 70475e5f0aaSAlex Zinenko // Casting ranked to unranked memref type 70575e5f0aaSAlex Zinenko // Set the rank in the destination from the memref type 70675e5f0aaSAlex Zinenko // Allocate space on the stack and copy the src memref descriptor 70775e5f0aaSAlex Zinenko // Set the ptr in the destination to the stack space 7085550c821STres Popp auto srcMemRefType = cast<MemRefType>(srcType); 70975e5f0aaSAlex Zinenko int64_t rank = srcMemRefType.getRank(); 71075e5f0aaSAlex Zinenko // ptr = AllocaOp sizeof(MemRefDescriptor) 71175e5f0aaSAlex Zinenko auto ptr = getTypeConverter()->promoteOneMemRefDescriptor( 712136d746eSJacques Pienaar loc, adaptor.getSource(), rewriter); 71350ea17b8SMarkus Böck 71475e5f0aaSAlex Zinenko // rank = ConstantOp srcRank 71575e5f0aaSAlex Zinenko auto rankVal = rewriter.create<LLVM::ConstantOp>( 7165b02a480SAdrian Kuegel loc, getIndexType(), rewriter.getIndexAttr(rank)); 71775e5f0aaSAlex Zinenko // undef = UndefOp 71875e5f0aaSAlex Zinenko UnrankedMemRefDescriptor memRefDesc = 71975e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType); 72075e5f0aaSAlex Zinenko // d1 = InsertValueOp undef, rank, 0 72175e5f0aaSAlex Zinenko memRefDesc.setRank(rewriter, loc, rankVal); 722b28a296cSChristian Ulmann // d2 = InsertValueOp d1, ptr, 1 723b28a296cSChristian Ulmann memRefDesc.setMemRefDescPtr(rewriter, loc, ptr); 72475e5f0aaSAlex Zinenko rewriter.replaceOp(memRefCastOp, (Value)memRefDesc); 72575e5f0aaSAlex Zinenko 7265550c821STres Popp } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) { 72775e5f0aaSAlex Zinenko // Casting from unranked type to ranked. 72875e5f0aaSAlex Zinenko // The operation is assumed to be doing a correct cast. If the destination 72975e5f0aaSAlex Zinenko // type mismatches the unranked the type, it is undefined behavior. 730136d746eSJacques Pienaar UnrankedMemRefDescriptor memRefDesc(adaptor.getSource()); 73175e5f0aaSAlex Zinenko // ptr = ExtractValueOp src, 1 73275e5f0aaSAlex Zinenko auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); 73350ea17b8SMarkus Böck 734b28a296cSChristian Ulmann // struct = LoadOp ptr 735b28a296cSChristian Ulmann auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr); 73675e5f0aaSAlex Zinenko rewriter.replaceOp(memRefCastOp, loadOp.getResult()); 73775e5f0aaSAlex Zinenko } else { 73875e5f0aaSAlex Zinenko llvm_unreachable("Unsupported unranked memref to unranked memref cast"); 73975e5f0aaSAlex Zinenko } 74075e5f0aaSAlex Zinenko } 74175e5f0aaSAlex Zinenko }; 74275e5f0aaSAlex Zinenko 743ab95ba70SStephan Herhut /// Pattern to lower a `memref.copy` to llvm. 744ab95ba70SStephan Herhut /// 745ab95ba70SStephan Herhut /// For memrefs with identity layouts, the copy is lowered to the llvm 746ab95ba70SStephan Herhut /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call 747ab95ba70SStephan Herhut /// to the generic `MemrefCopyFn`. 74875e5f0aaSAlex Zinenko struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> { 74975e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern; 75075e5f0aaSAlex Zinenko 75175e5f0aaSAlex Zinenko LogicalResult 752ab95ba70SStephan Herhut lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor, 753ab95ba70SStephan Herhut ConversionPatternRewriter &rewriter) const { 754ab95ba70SStephan Herhut auto loc = op.getLoc(); 7555550c821STres Popp auto srcType = dyn_cast<MemRefType>(op.getSource().getType()); 756ab95ba70SStephan Herhut 757136d746eSJacques Pienaar MemRefDescriptor srcDesc(adaptor.getSource()); 758ab95ba70SStephan Herhut 759ab95ba70SStephan Herhut // Compute number of elements. 760aa3cabe3SStephan Herhut Value numElements = rewriter.create<LLVM::ConstantOp>( 761aa3cabe3SStephan Herhut loc, getIndexType(), rewriter.getIndexAttr(1)); 762ab95ba70SStephan Herhut for (int pos = 0; pos < srcType.getRank(); ++pos) { 763ab95ba70SStephan Herhut auto size = srcDesc.size(rewriter, loc, pos); 764aa3cabe3SStephan Herhut numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size); 765ab95ba70SStephan Herhut } 766aa3cabe3SStephan Herhut 767ab95ba70SStephan Herhut // Get element size. 768ab95ba70SStephan Herhut auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter); 769ab95ba70SStephan Herhut // Compute total. 770ab95ba70SStephan Herhut Value totalSize = 771ab95ba70SStephan Herhut rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes); 772ab95ba70SStephan Herhut 77350ea17b8SMarkus Böck Type elementType = typeConverter->convertType(srcType.getElementType()); 77450ea17b8SMarkus Böck 775ab95ba70SStephan Herhut Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc); 77627cd2a62SBenjamin Kramer Value srcOffset = srcDesc.offset(rewriter, loc); 77750ea17b8SMarkus Böck Value srcPtr = rewriter.create<LLVM::GEPOp>( 77850ea17b8SMarkus Böck loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset); 779136d746eSJacques Pienaar MemRefDescriptor targetDesc(adaptor.getTarget()); 780ab95ba70SStephan Herhut Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc); 78127cd2a62SBenjamin Kramer Value targetOffset = targetDesc.offset(rewriter, loc); 78250ea17b8SMarkus Böck Value targetPtr = rewriter.create<LLVM::GEPOp>( 78350ea17b8SMarkus Böck loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset); 78427cd2a62SBenjamin Kramer rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize, 78548b126e3SChristian Ulmann /*isVolatile=*/false); 786ab95ba70SStephan Herhut rewriter.eraseOp(op); 787ab95ba70SStephan Herhut 788ab95ba70SStephan Herhut return success(); 789ab95ba70SStephan Herhut } 790ab95ba70SStephan Herhut 791ab95ba70SStephan Herhut LogicalResult 792ab95ba70SStephan Herhut lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor, 793ab95ba70SStephan Herhut ConversionPatternRewriter &rewriter) const { 79475e5f0aaSAlex Zinenko auto loc = op.getLoc(); 7955550c821STres Popp auto srcType = cast<BaseMemRefType>(op.getSource().getType()); 7965550c821STres Popp auto targetType = cast<BaseMemRefType>(op.getTarget().getType()); 79775e5f0aaSAlex Zinenko 79875e5f0aaSAlex Zinenko // First make sure we have an unranked memref descriptor representation. 799eb7f3557SMatthias Springer auto makeUnranked = [&, this](Value ranked, MemRefType type) { 8000af643f3SJeff Niu auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), 8010af643f3SJeff Niu type.getRank()); 80275e5f0aaSAlex Zinenko auto *typeConverter = getTypeConverter(); 80375e5f0aaSAlex Zinenko auto ptr = 80475e5f0aaSAlex Zinenko typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter); 80550ea17b8SMarkus Böck 80675e5f0aaSAlex Zinenko auto unrankedType = 80775e5f0aaSAlex Zinenko UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace()); 808b28a296cSChristian Ulmann return UnrankedMemRefDescriptor::pack( 809b28a296cSChristian Ulmann rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr}); 81075e5f0aaSAlex Zinenko }; 81175e5f0aaSAlex Zinenko 812f76e40d1SAndi Drebes // Save stack position before promoting descriptors 813f76e40d1SAndi Drebes auto stackSaveOp = 814f76e40d1SAndi Drebes rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType()); 815f76e40d1SAndi Drebes 8165550c821STres Popp auto srcMemRefType = dyn_cast<MemRefType>(srcType); 817eb7f3557SMatthias Springer Value unrankedSource = 818eb7f3557SMatthias Springer srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType) 819136d746eSJacques Pienaar : adaptor.getSource(); 8205550c821STres Popp auto targetMemRefType = dyn_cast<MemRefType>(targetType); 821eb7f3557SMatthias Springer Value unrankedTarget = 822eb7f3557SMatthias Springer targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType) 823136d746eSJacques Pienaar : adaptor.getTarget(); 82475e5f0aaSAlex Zinenko 82575e5f0aaSAlex Zinenko // Now promote the unranked descriptors to the stack. 82675e5f0aaSAlex Zinenko auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), 82775e5f0aaSAlex Zinenko rewriter.getIndexAttr(1)); 82875e5f0aaSAlex Zinenko auto promote = [&](Value desc) { 829b28a296cSChristian Ulmann auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); 83075e5f0aaSAlex Zinenko auto allocated = 83150ea17b8SMarkus Böck rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one); 83275e5f0aaSAlex Zinenko rewriter.create<LLVM::StoreOp>(loc, desc, allocated); 83375e5f0aaSAlex Zinenko return allocated; 83475e5f0aaSAlex Zinenko }; 83575e5f0aaSAlex Zinenko 83675e5f0aaSAlex Zinenko auto sourcePtr = promote(unrankedSource); 83775e5f0aaSAlex Zinenko auto targetPtr = promote(unrankedTarget); 83875e5f0aaSAlex Zinenko 839c336a061SFelix Schneider // Derive size from llvm.getelementptr which will account for any 840c336a061SFelix Schneider // potential alignment 841c336a061SFelix Schneider auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter); 84275e5f0aaSAlex Zinenko auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( 84375e5f0aaSAlex Zinenko op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType()); 844*e84f6b6aSLuohao Wang if (failed(copyFn)) 845*e84f6b6aSLuohao Wang return failure(); 846*e84f6b6aSLuohao Wang rewriter.create<LLVM::CallOp>(loc, copyFn.value(), 84775e5f0aaSAlex Zinenko ValueRange{elemSize, sourcePtr, targetPtr}); 848f76e40d1SAndi Drebes 849f76e40d1SAndi Drebes // Restore stack used for descriptors 850f76e40d1SAndi Drebes rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp); 851f76e40d1SAndi Drebes 85275e5f0aaSAlex Zinenko rewriter.eraseOp(op); 85375e5f0aaSAlex Zinenko 85475e5f0aaSAlex Zinenko return success(); 85575e5f0aaSAlex Zinenko } 856ab95ba70SStephan Herhut 857ab95ba70SStephan Herhut LogicalResult 858ab95ba70SStephan Herhut matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, 859ab95ba70SStephan Herhut ConversionPatternRewriter &rewriter) const override { 8605550c821STres Popp auto srcType = cast<BaseMemRefType>(op.getSource().getType()); 8615550c821STres Popp auto targetType = cast<BaseMemRefType>(op.getTarget().getType()); 862ab95ba70SStephan Herhut 86346b90a7bSAlex Zinenko auto isContiguousMemrefType = [&](BaseMemRefType type) { 8645550c821STres Popp auto memrefType = dyn_cast<mlir::MemRefType>(type); 86527cd2a62SBenjamin Kramer // We can use memcpy for memrefs if they have an identity layout or are 86627cd2a62SBenjamin Kramer // contiguous with an arbitrary offset. Ignore empty memrefs, which is a 86727cd2a62SBenjamin Kramer // special case handled by memrefCopy. 86827cd2a62SBenjamin Kramer return memrefType && 86927cd2a62SBenjamin Kramer (memrefType.getLayout().isIdentity() || 87027cd2a62SBenjamin Kramer (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 && 871b4d6aadaSOleg Shyshkov memref::isStaticShapeAndContiguousRowMajor(memrefType))); 87227cd2a62SBenjamin Kramer }; 87327cd2a62SBenjamin Kramer 87427cd2a62SBenjamin Kramer if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType)) 875ab95ba70SStephan Herhut return lowerToMemCopyIntrinsic(op, adaptor, rewriter); 876ab95ba70SStephan Herhut 877ab95ba70SStephan Herhut return lowerToMemCopyFunctionCall(op, adaptor, rewriter); 878ab95ba70SStephan Herhut } 87975e5f0aaSAlex Zinenko }; 88075e5f0aaSAlex Zinenko 8817fb9bbe5SKrzysztof Drewniak struct MemorySpaceCastOpLowering 8827fb9bbe5SKrzysztof Drewniak : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> { 8837fb9bbe5SKrzysztof Drewniak using ConvertOpToLLVMPattern< 8847fb9bbe5SKrzysztof Drewniak memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern; 8857fb9bbe5SKrzysztof Drewniak 8867fb9bbe5SKrzysztof Drewniak LogicalResult 8877fb9bbe5SKrzysztof Drewniak matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor, 8887fb9bbe5SKrzysztof Drewniak ConversionPatternRewriter &rewriter) const override { 8897fb9bbe5SKrzysztof Drewniak Location loc = op.getLoc(); 8907fb9bbe5SKrzysztof Drewniak 8917fb9bbe5SKrzysztof Drewniak Type resultType = op.getDest().getType(); 8925550c821STres Popp if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) { 8937fb9bbe5SKrzysztof Drewniak auto resultDescType = 8945550c821STres Popp cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR)); 8957fb9bbe5SKrzysztof Drewniak Type newPtrType = resultDescType.getBody()[0]; 8967fb9bbe5SKrzysztof Drewniak 8977fb9bbe5SKrzysztof Drewniak SmallVector<Value> descVals; 8987fb9bbe5SKrzysztof Drewniak MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR, 8997fb9bbe5SKrzysztof Drewniak descVals); 9007fb9bbe5SKrzysztof Drewniak descVals[0] = 9017fb9bbe5SKrzysztof Drewniak rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]); 9027fb9bbe5SKrzysztof Drewniak descVals[1] = 9037fb9bbe5SKrzysztof Drewniak rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]); 9047fb9bbe5SKrzysztof Drewniak Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(), 9057fb9bbe5SKrzysztof Drewniak resultTypeR, descVals); 9067fb9bbe5SKrzysztof Drewniak rewriter.replaceOp(op, result); 9077fb9bbe5SKrzysztof Drewniak return success(); 9087fb9bbe5SKrzysztof Drewniak } 9095550c821STres Popp if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) { 9107fb9bbe5SKrzysztof Drewniak // Since the type converter won't be doing this for us, get the address 9117fb9bbe5SKrzysztof Drewniak // space. 9125550c821STres Popp auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType()); 9137fb9bbe5SKrzysztof Drewniak FailureOr<unsigned> maybeSourceAddrSpace = 9147fb9bbe5SKrzysztof Drewniak getTypeConverter()->getMemRefAddressSpace(sourceType); 9157fb9bbe5SKrzysztof Drewniak if (failed(maybeSourceAddrSpace)) 9167fb9bbe5SKrzysztof Drewniak return rewriter.notifyMatchFailure(loc, 9177fb9bbe5SKrzysztof Drewniak "non-integer source address space"); 9187fb9bbe5SKrzysztof Drewniak unsigned sourceAddrSpace = *maybeSourceAddrSpace; 9197fb9bbe5SKrzysztof Drewniak FailureOr<unsigned> maybeResultAddrSpace = 9207fb9bbe5SKrzysztof Drewniak getTypeConverter()->getMemRefAddressSpace(resultTypeU); 9217fb9bbe5SKrzysztof Drewniak if (failed(maybeResultAddrSpace)) 9227fb9bbe5SKrzysztof Drewniak return rewriter.notifyMatchFailure(loc, 9237fb9bbe5SKrzysztof Drewniak "non-integer result address space"); 9247fb9bbe5SKrzysztof Drewniak unsigned resultAddrSpace = *maybeResultAddrSpace; 9257fb9bbe5SKrzysztof Drewniak 9267fb9bbe5SKrzysztof Drewniak UnrankedMemRefDescriptor sourceDesc(adaptor.getSource()); 9277fb9bbe5SKrzysztof Drewniak Value rank = sourceDesc.rank(rewriter, loc); 9287fb9bbe5SKrzysztof Drewniak Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc); 9297fb9bbe5SKrzysztof Drewniak 9307fb9bbe5SKrzysztof Drewniak // Create and allocate storage for new memref descriptor. 9317fb9bbe5SKrzysztof Drewniak auto result = UnrankedMemRefDescriptor::undef( 9327fb9bbe5SKrzysztof Drewniak rewriter, loc, typeConverter->convertType(resultTypeU)); 9337fb9bbe5SKrzysztof Drewniak result.setRank(rewriter, loc, rank); 9347fb9bbe5SKrzysztof Drewniak SmallVector<Value, 1> sizes; 9357fb9bbe5SKrzysztof Drewniak UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), 9367fb9bbe5SKrzysztof Drewniak result, resultAddrSpace, sizes); 9377fb9bbe5SKrzysztof Drewniak Value resultUnderlyingSize = sizes.front(); 9387fb9bbe5SKrzysztof Drewniak Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>( 9397fb9bbe5SKrzysztof Drewniak loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize); 9407fb9bbe5SKrzysztof Drewniak result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc); 9417fb9bbe5SKrzysztof Drewniak 9427fb9bbe5SKrzysztof Drewniak // Copy pointers, performing address space casts. 943b28a296cSChristian Ulmann auto sourceElemPtrType = 944b28a296cSChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace); 9457fb9bbe5SKrzysztof Drewniak auto resultElemPtrType = 946b28a296cSChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace); 9477fb9bbe5SKrzysztof Drewniak 9487fb9bbe5SKrzysztof Drewniak Value allocatedPtr = sourceDesc.allocatedPtr( 9497fb9bbe5SKrzysztof Drewniak rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType); 9507fb9bbe5SKrzysztof Drewniak Value alignedPtr = 9517fb9bbe5SKrzysztof Drewniak sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(), 9527fb9bbe5SKrzysztof Drewniak sourceUnderlyingDesc, sourceElemPtrType); 9537fb9bbe5SKrzysztof Drewniak allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>( 9547fb9bbe5SKrzysztof Drewniak loc, resultElemPtrType, allocatedPtr); 9557fb9bbe5SKrzysztof Drewniak alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>( 9567fb9bbe5SKrzysztof Drewniak loc, resultElemPtrType, alignedPtr); 9577fb9bbe5SKrzysztof Drewniak 9587fb9bbe5SKrzysztof Drewniak result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc, 9597fb9bbe5SKrzysztof Drewniak resultElemPtrType, allocatedPtr); 9607fb9bbe5SKrzysztof Drewniak result.setAlignedPtr(rewriter, loc, *getTypeConverter(), 9617fb9bbe5SKrzysztof Drewniak resultUnderlyingDesc, resultElemPtrType, alignedPtr); 9627fb9bbe5SKrzysztof Drewniak 9637fb9bbe5SKrzysztof Drewniak // Copy all the index-valued operands. 9647fb9bbe5SKrzysztof Drewniak Value sourceIndexVals = 9657fb9bbe5SKrzysztof Drewniak sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(), 9667fb9bbe5SKrzysztof Drewniak sourceUnderlyingDesc, sourceElemPtrType); 9677fb9bbe5SKrzysztof Drewniak Value resultIndexVals = 9687fb9bbe5SKrzysztof Drewniak result.offsetBasePtr(rewriter, loc, *getTypeConverter(), 9697fb9bbe5SKrzysztof Drewniak resultUnderlyingDesc, resultElemPtrType); 9707fb9bbe5SKrzysztof Drewniak 9717fb9bbe5SKrzysztof Drewniak int64_t bytesToSkip = 972e843f029SRamkumar Ramachandra 2 * llvm::divideCeil( 9730fb216fbSRamkumar Ramachandra getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8); 9747fb9bbe5SKrzysztof Drewniak Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>( 9757fb9bbe5SKrzysztof Drewniak loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip)); 9767fb9bbe5SKrzysztof Drewniak Value copySize = rewriter.create<LLVM::SubOp>( 9777fb9bbe5SKrzysztof Drewniak loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst); 9787fb9bbe5SKrzysztof Drewniak rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals, 97948b126e3SChristian Ulmann copySize, /*isVolatile=*/false); 9807fb9bbe5SKrzysztof Drewniak 9817fb9bbe5SKrzysztof Drewniak rewriter.replaceOp(op, ValueRange{result}); 9827fb9bbe5SKrzysztof Drewniak return success(); 9837fb9bbe5SKrzysztof Drewniak } 9847fb9bbe5SKrzysztof Drewniak return rewriter.notifyMatchFailure(loc, "unexpected memref type"); 9857fb9bbe5SKrzysztof Drewniak } 9867fb9bbe5SKrzysztof Drewniak }; 9877fb9bbe5SKrzysztof Drewniak 98875e5f0aaSAlex Zinenko /// Extracts allocated, aligned pointers and offset from a ranked or unranked 98975e5f0aaSAlex Zinenko /// memref type. In unranked case, the fields are extracted from the underlying 99075e5f0aaSAlex Zinenko /// ranked descriptor. 99175e5f0aaSAlex Zinenko static void extractPointersAndOffset(Location loc, 99275e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter, 993ce254598SMatthias Springer const LLVMTypeConverter &typeConverter, 99475e5f0aaSAlex Zinenko Value originalOperand, 99575e5f0aaSAlex Zinenko Value convertedOperand, 99675e5f0aaSAlex Zinenko Value *allocatedPtr, Value *alignedPtr, 99775e5f0aaSAlex Zinenko Value *offset = nullptr) { 99875e5f0aaSAlex Zinenko Type operandType = originalOperand.getType(); 9995550c821STres Popp if (isa<MemRefType>(operandType)) { 100075e5f0aaSAlex Zinenko MemRefDescriptor desc(convertedOperand); 100175e5f0aaSAlex Zinenko *allocatedPtr = desc.allocatedPtr(rewriter, loc); 100275e5f0aaSAlex Zinenko *alignedPtr = desc.alignedPtr(rewriter, loc); 100375e5f0aaSAlex Zinenko if (offset != nullptr) 100475e5f0aaSAlex Zinenko *offset = desc.offset(rewriter, loc); 100575e5f0aaSAlex Zinenko return; 100675e5f0aaSAlex Zinenko } 100775e5f0aaSAlex Zinenko 1008499abb24SKrzysztof Drewniak // These will all cause assert()s on unconvertible types. 1009499abb24SKrzysztof Drewniak unsigned memorySpace = *typeConverter.getMemRefAddressSpace( 10105550c821STres Popp cast<UnrankedMemRefType>(operandType)); 1011b28a296cSChristian Ulmann auto elementPtrType = 1012b28a296cSChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace); 101375e5f0aaSAlex Zinenko 101475e5f0aaSAlex Zinenko // Extract pointer to the underlying ranked memref descriptor and cast it to 101575e5f0aaSAlex Zinenko // ElemType**. 101675e5f0aaSAlex Zinenko UnrankedMemRefDescriptor unrankedDesc(convertedOperand); 101775e5f0aaSAlex Zinenko Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); 101875e5f0aaSAlex Zinenko 101975e5f0aaSAlex Zinenko *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr( 102050ea17b8SMarkus Böck rewriter, loc, underlyingDescPtr, elementPtrType); 102175e5f0aaSAlex Zinenko *alignedPtr = UnrankedMemRefDescriptor::alignedPtr( 102250ea17b8SMarkus Böck rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType); 102375e5f0aaSAlex Zinenko if (offset != nullptr) { 102475e5f0aaSAlex Zinenko *offset = UnrankedMemRefDescriptor::offset( 102550ea17b8SMarkus Böck rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType); 102675e5f0aaSAlex Zinenko } 102775e5f0aaSAlex Zinenko } 102875e5f0aaSAlex Zinenko 102975e5f0aaSAlex Zinenko struct MemRefReinterpretCastOpLowering 103075e5f0aaSAlex Zinenko : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> { 103175e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern< 103275e5f0aaSAlex Zinenko memref::ReinterpretCastOp>::ConvertOpToLLVMPattern; 103375e5f0aaSAlex Zinenko 103475e5f0aaSAlex Zinenko LogicalResult 1035ef976337SRiver Riddle matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, 103675e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 1037136d746eSJacques Pienaar Type srcType = castOp.getSource().getType(); 103875e5f0aaSAlex Zinenko 103975e5f0aaSAlex Zinenko Value descriptor; 104075e5f0aaSAlex Zinenko if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, 104175e5f0aaSAlex Zinenko adaptor, &descriptor))) 104275e5f0aaSAlex Zinenko return failure(); 104375e5f0aaSAlex Zinenko rewriter.replaceOp(castOp, {descriptor}); 104475e5f0aaSAlex Zinenko return success(); 104575e5f0aaSAlex Zinenko } 104675e5f0aaSAlex Zinenko 104775e5f0aaSAlex Zinenko private: 104875e5f0aaSAlex Zinenko LogicalResult convertSourceMemRefToDescriptor( 104975e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter, Type srcType, 105075e5f0aaSAlex Zinenko memref::ReinterpretCastOp castOp, 105175e5f0aaSAlex Zinenko memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const { 105275e5f0aaSAlex Zinenko MemRefType targetMemRefType = 10535550c821STres Popp cast<MemRefType>(castOp.getResult().getType()); 10545550c821STres Popp auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>( 10555550c821STres Popp typeConverter->convertType(targetMemRefType)); 105675e5f0aaSAlex Zinenko if (!llvmTargetDescriptorTy) 105775e5f0aaSAlex Zinenko return failure(); 105875e5f0aaSAlex Zinenko 105975e5f0aaSAlex Zinenko // Create descriptor. 106075e5f0aaSAlex Zinenko Location loc = castOp.getLoc(); 106175e5f0aaSAlex Zinenko auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 106275e5f0aaSAlex Zinenko 106375e5f0aaSAlex Zinenko // Set allocated and aligned pointers. 106475e5f0aaSAlex Zinenko Value allocatedPtr, alignedPtr; 106575e5f0aaSAlex Zinenko extractPointersAndOffset(loc, rewriter, *getTypeConverter(), 1066136d746eSJacques Pienaar castOp.getSource(), adaptor.getSource(), 1067136d746eSJacques Pienaar &allocatedPtr, &alignedPtr); 106875e5f0aaSAlex Zinenko desc.setAllocatedPtr(rewriter, loc, allocatedPtr); 106975e5f0aaSAlex Zinenko desc.setAlignedPtr(rewriter, loc, alignedPtr); 107075e5f0aaSAlex Zinenko 107175e5f0aaSAlex Zinenko // Set offset. 107275e5f0aaSAlex Zinenko if (castOp.isDynamicOffset(0)) 1073136d746eSJacques Pienaar desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]); 107475e5f0aaSAlex Zinenko else 107575e5f0aaSAlex Zinenko desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); 107675e5f0aaSAlex Zinenko 107775e5f0aaSAlex Zinenko // Set sizes and strides. 107875e5f0aaSAlex Zinenko unsigned dynSizeId = 0; 107975e5f0aaSAlex Zinenko unsigned dynStrideId = 0; 108075e5f0aaSAlex Zinenko for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { 108175e5f0aaSAlex Zinenko if (castOp.isDynamicSize(i)) 1082136d746eSJacques Pienaar desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]); 108375e5f0aaSAlex Zinenko else 108475e5f0aaSAlex Zinenko desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i)); 108575e5f0aaSAlex Zinenko 108675e5f0aaSAlex Zinenko if (castOp.isDynamicStride(i)) 1087136d746eSJacques Pienaar desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]); 108875e5f0aaSAlex Zinenko else 108975e5f0aaSAlex Zinenko desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); 109075e5f0aaSAlex Zinenko } 109175e5f0aaSAlex Zinenko *descriptor = desc; 109275e5f0aaSAlex Zinenko return success(); 109375e5f0aaSAlex Zinenko } 109475e5f0aaSAlex Zinenko }; 109575e5f0aaSAlex Zinenko 109675e5f0aaSAlex Zinenko struct MemRefReshapeOpLowering 109775e5f0aaSAlex Zinenko : public ConvertOpToLLVMPattern<memref::ReshapeOp> { 109875e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern; 109975e5f0aaSAlex Zinenko 110075e5f0aaSAlex Zinenko LogicalResult 1101ef976337SRiver Riddle matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor, 110275e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 1103136d746eSJacques Pienaar Type srcType = reshapeOp.getSource().getType(); 110475e5f0aaSAlex Zinenko 110575e5f0aaSAlex Zinenko Value descriptor; 110675e5f0aaSAlex Zinenko if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp, 110775e5f0aaSAlex Zinenko adaptor, &descriptor))) 110875e5f0aaSAlex Zinenko return failure(); 1109ef976337SRiver Riddle rewriter.replaceOp(reshapeOp, {descriptor}); 111075e5f0aaSAlex Zinenko return success(); 111175e5f0aaSAlex Zinenko } 111275e5f0aaSAlex Zinenko 111375e5f0aaSAlex Zinenko private: 111475e5f0aaSAlex Zinenko LogicalResult 111575e5f0aaSAlex Zinenko convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, 111675e5f0aaSAlex Zinenko Type srcType, memref::ReshapeOp reshapeOp, 111775e5f0aaSAlex Zinenko memref::ReshapeOp::Adaptor adaptor, 111875e5f0aaSAlex Zinenko Value *descriptor) const { 11195550c821STres Popp auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType()); 11205380e30eSAshay Rane if (shapeMemRefType.hasStaticShape()) { 11215380e30eSAshay Rane MemRefType targetMemRefType = 11225550c821STres Popp cast<MemRefType>(reshapeOp.getResult().getType()); 11235550c821STres Popp auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>( 11245550c821STres Popp typeConverter->convertType(targetMemRefType)); 11255380e30eSAshay Rane if (!llvmTargetDescriptorTy) 112675e5f0aaSAlex Zinenko return failure(); 112775e5f0aaSAlex Zinenko 11285380e30eSAshay Rane // Create descriptor. 11295380e30eSAshay Rane Location loc = reshapeOp.getLoc(); 11305380e30eSAshay Rane auto desc = 11315380e30eSAshay Rane MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 11325380e30eSAshay Rane 11335380e30eSAshay Rane // Set allocated and aligned pointers. 11345380e30eSAshay Rane Value allocatedPtr, alignedPtr; 11355380e30eSAshay Rane extractPointersAndOffset(loc, rewriter, *getTypeConverter(), 1136136d746eSJacques Pienaar reshapeOp.getSource(), adaptor.getSource(), 11375380e30eSAshay Rane &allocatedPtr, &alignedPtr); 11385380e30eSAshay Rane desc.setAllocatedPtr(rewriter, loc, allocatedPtr); 11395380e30eSAshay Rane desc.setAlignedPtr(rewriter, loc, alignedPtr); 11405380e30eSAshay Rane 11415380e30eSAshay Rane // Extract the offset and strides from the type. 11425380e30eSAshay Rane int64_t offset; 11435380e30eSAshay Rane SmallVector<int64_t> strides; 11446aaa8f25SMatthias Springer if (failed(targetMemRefType.getStridesAndOffset(strides, offset))) 11455380e30eSAshay Rane return rewriter.notifyMatchFailure( 11465380e30eSAshay Rane reshapeOp, "failed to get stride and offset exprs"); 11475380e30eSAshay Rane 11485380e30eSAshay Rane if (!isStaticStrideOrOffset(offset)) 11495380e30eSAshay Rane return rewriter.notifyMatchFailure(reshapeOp, 11505380e30eSAshay Rane "dynamic offset is unsupported"); 11515380e30eSAshay Rane 11525380e30eSAshay Rane desc.setConstantOffset(rewriter, loc, offset); 11535fee1799SAshay Rane 11545fee1799SAshay Rane assert(targetMemRefType.getLayout().isIdentity() && 11555fee1799SAshay Rane "Identity layout map is a precondition of a valid reshape op"); 11565fee1799SAshay Rane 1157e98e5995SAlex Zinenko Type indexType = getIndexType(); 11585fee1799SAshay Rane Value stride = nullptr; 11595fee1799SAshay Rane int64_t targetRank = targetMemRefType.getRank(); 11605fee1799SAshay Rane for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) { 1161399638f9SAliia Khasanova if (!ShapedType::isDynamic(strides[i])) { 11625fee1799SAshay Rane // If the stride for this dimension is dynamic, then use the product 11635fee1799SAshay Rane // of the sizes of the inner dimensions. 1164620e2bb2SNicolas Vasilache stride = 1165620e2bb2SNicolas Vasilache createIndexAttrConstant(rewriter, loc, indexType, strides[i]); 11665fee1799SAshay Rane } else if (!stride) { 11675fee1799SAshay Rane // `stride` is null only in the first iteration of the loop. However, 11685fee1799SAshay Rane // since the target memref has an identity layout, we can safely set 11695fee1799SAshay Rane // the innermost stride to 1. 1170620e2bb2SNicolas Vasilache stride = createIndexAttrConstant(rewriter, loc, indexType, 1); 11715fee1799SAshay Rane } 11725fee1799SAshay Rane 11735fee1799SAshay Rane Value dimSize; 11745fee1799SAshay Rane // If the size of this dimension is dynamic, then load it at runtime 11755fee1799SAshay Rane // from the shape operand. 1176eb7f3557SMatthias Springer if (!targetMemRefType.isDynamicDim(i)) { 1177620e2bb2SNicolas Vasilache dimSize = createIndexAttrConstant(rewriter, loc, indexType, 1178eb7f3557SMatthias Springer targetMemRefType.getDimSize(i)); 11795fee1799SAshay Rane } else { 1180136d746eSJacques Pienaar Value shapeOp = reshapeOp.getShape(); 1181620e2bb2SNicolas Vasilache Value index = createIndexAttrConstant(rewriter, loc, indexType, i); 11825fee1799SAshay Rane dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index); 1183d4217e6cSIvan Butygin Type indexType = getIndexType(); 1184d4217e6cSIvan Butygin if (dimSize.getType() != indexType) 1185d4217e6cSIvan Butygin dimSize = typeConverter->materializeTargetConversion( 1186d4217e6cSIvan Butygin rewriter, loc, indexType, dimSize); 1187d4217e6cSIvan Butygin assert(dimSize && "Invalid memref element type"); 11885fee1799SAshay Rane } 11895fee1799SAshay Rane 11905fee1799SAshay Rane desc.setSize(rewriter, loc, i, dimSize); 11915fee1799SAshay Rane desc.setStride(rewriter, loc, i, stride); 11925fee1799SAshay Rane 11935fee1799SAshay Rane // Prepare the stride value for the next dimension. 11945fee1799SAshay Rane stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize); 11955380e30eSAshay Rane } 11965380e30eSAshay Rane 11975380e30eSAshay Rane *descriptor = desc; 11985380e30eSAshay Rane return success(); 11995380e30eSAshay Rane } 12005380e30eSAshay Rane 120175e5f0aaSAlex Zinenko // The shape is a rank-1 tensor with unknown length. 120275e5f0aaSAlex Zinenko Location loc = reshapeOp.getLoc(); 1203136d746eSJacques Pienaar MemRefDescriptor shapeDesc(adaptor.getShape()); 120475e5f0aaSAlex Zinenko Value resultRank = shapeDesc.size(rewriter, loc, 0); 120575e5f0aaSAlex Zinenko 120675e5f0aaSAlex Zinenko // Extract address space and element type. 12075550c821STres Popp auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType()); 1208499abb24SKrzysztof Drewniak unsigned addressSpace = 1209499abb24SKrzysztof Drewniak *getTypeConverter()->getMemRefAddressSpace(targetType); 121075e5f0aaSAlex Zinenko 121175e5f0aaSAlex Zinenko // Create the unranked memref descriptor that holds the ranked one. The 121275e5f0aaSAlex Zinenko // inner descriptor is allocated on stack. 121375e5f0aaSAlex Zinenko auto targetDesc = UnrankedMemRefDescriptor::undef( 121475e5f0aaSAlex Zinenko rewriter, loc, typeConverter->convertType(targetType)); 121575e5f0aaSAlex Zinenko targetDesc.setRank(rewriter, loc, resultRank); 121675e5f0aaSAlex Zinenko SmallVector<Value, 4> sizes; 121775e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), 1218d0f19ce7SKrzysztof Drewniak targetDesc, addressSpace, sizes); 121975e5f0aaSAlex Zinenko Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>( 122050ea17b8SMarkus Böck loc, getVoidPtrType(), IntegerType::get(getContext(), 8), 122150ea17b8SMarkus Böck sizes.front()); 122275e5f0aaSAlex Zinenko targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); 122375e5f0aaSAlex Zinenko 122475e5f0aaSAlex Zinenko // Extract pointers and offset from the source memref. 122575e5f0aaSAlex Zinenko Value allocatedPtr, alignedPtr, offset; 122675e5f0aaSAlex Zinenko extractPointersAndOffset(loc, rewriter, *getTypeConverter(), 1227136d746eSJacques Pienaar reshapeOp.getSource(), adaptor.getSource(), 122875e5f0aaSAlex Zinenko &allocatedPtr, &alignedPtr, &offset); 122975e5f0aaSAlex Zinenko 123075e5f0aaSAlex Zinenko // Set pointers and offset. 1231b28a296cSChristian Ulmann auto elementPtrType = 1232b28a296cSChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); 123350ea17b8SMarkus Böck 123475e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, 123550ea17b8SMarkus Böck elementPtrType, allocatedPtr); 123675e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(), 123750ea17b8SMarkus Böck underlyingDescPtr, elementPtrType, 123850ea17b8SMarkus Böck alignedPtr); 123975e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(), 124050ea17b8SMarkus Böck underlyingDescPtr, elementPtrType, 124175e5f0aaSAlex Zinenko offset); 124275e5f0aaSAlex Zinenko 124375e5f0aaSAlex Zinenko // Use the offset pointer as base for further addressing. Copy over the new 124475e5f0aaSAlex Zinenko // shape and compute strides. For this, we create a loop from rank-1 to 0. 124575e5f0aaSAlex Zinenko Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( 124650ea17b8SMarkus Böck rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType); 124775e5f0aaSAlex Zinenko Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( 124875e5f0aaSAlex Zinenko rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank); 124975e5f0aaSAlex Zinenko Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); 1250e98e5995SAlex Zinenko Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1); 125175e5f0aaSAlex Zinenko Value resultRankMinusOne = 125275e5f0aaSAlex Zinenko rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex); 125375e5f0aaSAlex Zinenko 125475e5f0aaSAlex Zinenko Block *initBlock = rewriter.getInsertionBlock(); 125575e5f0aaSAlex Zinenko Type indexType = getTypeConverter()->getIndexType(); 125675e5f0aaSAlex Zinenko Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); 125775e5f0aaSAlex Zinenko 125875e5f0aaSAlex Zinenko Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, 1259e084679fSRiver Riddle {indexType, indexType}, {loc, loc}); 126075e5f0aaSAlex Zinenko 126175e5f0aaSAlex Zinenko // Move the remaining initBlock ops to condBlock. 126275e5f0aaSAlex Zinenko Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt); 126375e5f0aaSAlex Zinenko rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange()); 126475e5f0aaSAlex Zinenko 126575e5f0aaSAlex Zinenko rewriter.setInsertionPointToEnd(initBlock); 126675e5f0aaSAlex Zinenko rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}), 126775e5f0aaSAlex Zinenko condBlock); 126875e5f0aaSAlex Zinenko rewriter.setInsertionPointToStart(condBlock); 126975e5f0aaSAlex Zinenko Value indexArg = condBlock->getArgument(0); 127075e5f0aaSAlex Zinenko Value strideArg = condBlock->getArgument(1); 127175e5f0aaSAlex Zinenko 1272620e2bb2SNicolas Vasilache Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0); 127375e5f0aaSAlex Zinenko Value pred = rewriter.create<LLVM::ICmpOp>( 127475e5f0aaSAlex Zinenko loc, IntegerType::get(rewriter.getContext(), 1), 127575e5f0aaSAlex Zinenko LLVM::ICmpPredicate::sge, indexArg, zeroIndex); 127675e5f0aaSAlex Zinenko 127775e5f0aaSAlex Zinenko Block *bodyBlock = 127875e5f0aaSAlex Zinenko rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); 127975e5f0aaSAlex Zinenko rewriter.setInsertionPointToStart(bodyBlock); 128075e5f0aaSAlex Zinenko 128175e5f0aaSAlex Zinenko // Copy size from shape to descriptor. 1282b28a296cSChristian Ulmann auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); 128350ea17b8SMarkus Böck Value sizeLoadGep = rewriter.create<LLVM::GEPOp>( 128450ea17b8SMarkus Böck loc, llvmIndexPtrType, 128550ea17b8SMarkus Böck typeConverter->convertType(shapeMemRefType.getElementType()), 1286bd7eff1fSMarkus Böck shapeOperandPtr, indexArg); 128750ea17b8SMarkus Böck Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep); 128875e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), 128975e5f0aaSAlex Zinenko targetSizesBase, indexArg, size); 129075e5f0aaSAlex Zinenko 129175e5f0aaSAlex Zinenko // Write stride value and compute next one. 129275e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), 129375e5f0aaSAlex Zinenko targetStridesBase, indexArg, strideArg); 129475e5f0aaSAlex Zinenko Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size); 129575e5f0aaSAlex Zinenko 129675e5f0aaSAlex Zinenko // Decrement loop counter and branch back. 129775e5f0aaSAlex Zinenko Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex); 129875e5f0aaSAlex Zinenko rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}), 129975e5f0aaSAlex Zinenko condBlock); 130075e5f0aaSAlex Zinenko 130175e5f0aaSAlex Zinenko Block *remainder = 130275e5f0aaSAlex Zinenko rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); 130375e5f0aaSAlex Zinenko 130475e5f0aaSAlex Zinenko // Hook up the cond exit to the remainder. 130575e5f0aaSAlex Zinenko rewriter.setInsertionPointToEnd(condBlock); 13061a36588eSKazu Hirata rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt, 13071a36588eSKazu Hirata remainder, std::nullopt); 130875e5f0aaSAlex Zinenko 130975e5f0aaSAlex Zinenko // Reset position to beginning of new remainder block. 131075e5f0aaSAlex Zinenko rewriter.setInsertionPointToStart(remainder); 131175e5f0aaSAlex Zinenko 131275e5f0aaSAlex Zinenko *descriptor = targetDesc; 131375e5f0aaSAlex Zinenko return success(); 131475e5f0aaSAlex Zinenko } 131575e5f0aaSAlex Zinenko }; 131675e5f0aaSAlex Zinenko 131739148362SQuentin Colombet /// RessociatingReshapeOp must be expanded before we reach this stage. 131839148362SQuentin Colombet /// Report that information. 131946ef86b5SAlexander Belyaev template <typename ReshapeOp> 132046ef86b5SAlexander Belyaev class ReassociatingReshapeOpConversion 132146ef86b5SAlexander Belyaev : public ConvertOpToLLVMPattern<ReshapeOp> { 132246ef86b5SAlexander Belyaev public: 132346ef86b5SAlexander Belyaev using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern; 132446ef86b5SAlexander Belyaev using ReshapeOpAdaptor = typename ReshapeOp::Adaptor; 132546ef86b5SAlexander Belyaev 132646ef86b5SAlexander Belyaev LogicalResult 1327ef976337SRiver Riddle matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor, 132846ef86b5SAlexander Belyaev ConversionPatternRewriter &rewriter) const override { 1329381c3b92SYi Zhang return rewriter.notifyMatchFailure( 133039148362SQuentin Colombet reshapeOp, 133139148362SQuentin Colombet "reassociation operations should have been expanded beforehand"); 133246ef86b5SAlexander Belyaev } 133346ef86b5SAlexander Belyaev }; 1334381c3b92SYi Zhang 1335786cbb09SQuentin Colombet /// Subviews must be expanded before we reach this stage. 1336786cbb09SQuentin Colombet /// Report that information. 133775e5f0aaSAlex Zinenko struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> { 133875e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern; 133975e5f0aaSAlex Zinenko 134075e5f0aaSAlex Zinenko LogicalResult 1341ef976337SRiver Riddle matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor, 134275e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 1343786cbb09SQuentin Colombet return rewriter.notifyMatchFailure( 1344786cbb09SQuentin Colombet subViewOp, "subview operations should have been expanded beforehand"); 134575e5f0aaSAlex Zinenko } 134675e5f0aaSAlex Zinenko }; 134775e5f0aaSAlex Zinenko 134875e5f0aaSAlex Zinenko /// Conversion pattern that transforms a transpose op into: 134975e5f0aaSAlex Zinenko /// 1. A function entry `alloca` operation to allocate a ViewDescriptor. 135075e5f0aaSAlex Zinenko /// 2. A load of the ViewDescriptor from the pointer allocated in 1. 135175e5f0aaSAlex Zinenko /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size 135275e5f0aaSAlex Zinenko /// and stride. Size and stride are permutations of the original values. 135375e5f0aaSAlex Zinenko /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. 135475e5f0aaSAlex Zinenko /// The transpose op is replaced by the alloca'ed pointer. 135575e5f0aaSAlex Zinenko class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> { 135675e5f0aaSAlex Zinenko public: 135775e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern; 135875e5f0aaSAlex Zinenko 135975e5f0aaSAlex Zinenko LogicalResult 1360ef976337SRiver Riddle matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor, 136175e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 136275e5f0aaSAlex Zinenko auto loc = transposeOp.getLoc(); 1363136d746eSJacques Pienaar MemRefDescriptor viewMemRef(adaptor.getIn()); 136475e5f0aaSAlex Zinenko 136575e5f0aaSAlex Zinenko // No permutation, early exit. 1366136d746eSJacques Pienaar if (transposeOp.getPermutation().isIdentity()) 136775e5f0aaSAlex Zinenko return rewriter.replaceOp(transposeOp, {viewMemRef}), success(); 136875e5f0aaSAlex Zinenko 136975e5f0aaSAlex Zinenko auto targetMemRef = MemRefDescriptor::undef( 1370eb7f3557SMatthias Springer rewriter, loc, 1371eb7f3557SMatthias Springer typeConverter->convertType(transposeOp.getIn().getType())); 137275e5f0aaSAlex Zinenko 137375e5f0aaSAlex Zinenko // Copy the base and aligned pointers from the old descriptor to the new 137475e5f0aaSAlex Zinenko // one. 137575e5f0aaSAlex Zinenko targetMemRef.setAllocatedPtr(rewriter, loc, 137675e5f0aaSAlex Zinenko viewMemRef.allocatedPtr(rewriter, loc)); 137775e5f0aaSAlex Zinenko targetMemRef.setAlignedPtr(rewriter, loc, 137875e5f0aaSAlex Zinenko viewMemRef.alignedPtr(rewriter, loc)); 137975e5f0aaSAlex Zinenko 138075e5f0aaSAlex Zinenko // Copy the offset pointer from the old descriptor to the new one. 138175e5f0aaSAlex Zinenko targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc)); 138275e5f0aaSAlex Zinenko 138355088efeSFelix Schneider // Iterate over the dimensions and apply size/stride permutation: 1384b28a296cSChristian Ulmann // When enumerating the results of the permutation map, the enumeration 1385b28a296cSChristian Ulmann // index is the index into the target dimensions and the DimExpr points to 1386b28a296cSChristian Ulmann // the dimension of the source memref. 1387e4853be2SMehdi Amini for (const auto &en : 1388136d746eSJacques Pienaar llvm::enumerate(transposeOp.getPermutation().getResults())) { 138955088efeSFelix Schneider int targetPos = en.index(); 13901609f1c2Slong.chen int sourcePos = cast<AffineDimExpr>(en.value()).getPosition(); 139175e5f0aaSAlex Zinenko targetMemRef.setSize(rewriter, loc, targetPos, 139275e5f0aaSAlex Zinenko viewMemRef.size(rewriter, loc, sourcePos)); 139375e5f0aaSAlex Zinenko targetMemRef.setStride(rewriter, loc, targetPos, 139475e5f0aaSAlex Zinenko viewMemRef.stride(rewriter, loc, sourcePos)); 139575e5f0aaSAlex Zinenko } 139675e5f0aaSAlex Zinenko 139775e5f0aaSAlex Zinenko rewriter.replaceOp(transposeOp, {targetMemRef}); 139875e5f0aaSAlex Zinenko return success(); 139975e5f0aaSAlex Zinenko } 140075e5f0aaSAlex Zinenko }; 140175e5f0aaSAlex Zinenko 140275e5f0aaSAlex Zinenko /// Conversion pattern that transforms an op into: 140375e5f0aaSAlex Zinenko /// 1. An `llvm.mlir.undef` operation to create a memref descriptor 140475e5f0aaSAlex Zinenko /// 2. Updates to the descriptor to introduce the data ptr, offset, size 140575e5f0aaSAlex Zinenko /// and stride. 140675e5f0aaSAlex Zinenko /// The view op is replaced by the descriptor. 140775e5f0aaSAlex Zinenko struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> { 140875e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern; 140975e5f0aaSAlex Zinenko 141075e5f0aaSAlex Zinenko // Build and return the value for the idx^th shape dimension, either by 141175e5f0aaSAlex Zinenko // returning the constant shape dimension or counting the proper dynamic size. 141275e5f0aaSAlex Zinenko Value getSize(ConversionPatternRewriter &rewriter, Location loc, 1413620e2bb2SNicolas Vasilache ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx, 1414620e2bb2SNicolas Vasilache Type indexType) const { 141575e5f0aaSAlex Zinenko assert(idx < shape.size()); 141675e5f0aaSAlex Zinenko if (!ShapedType::isDynamic(shape[idx])) 1417620e2bb2SNicolas Vasilache return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]); 141875e5f0aaSAlex Zinenko // Count the number of dynamic dims in range [0, idx] 1419380a1b20SKazu Hirata unsigned nDynamic = 1420380a1b20SKazu Hirata llvm::count_if(shape.take_front(idx), ShapedType::isDynamic); 142175e5f0aaSAlex Zinenko return dynamicSizes[nDynamic]; 142275e5f0aaSAlex Zinenko } 142375e5f0aaSAlex Zinenko 142475e5f0aaSAlex Zinenko // Build and return the idx^th stride, either by returning the constant stride 142575e5f0aaSAlex Zinenko // or by computing the dynamic stride from the current `runningStride` and 142675e5f0aaSAlex Zinenko // `nextSize`. The caller should keep a running stride and update it with the 142775e5f0aaSAlex Zinenko // result returned by this function. 142875e5f0aaSAlex Zinenko Value getStride(ConversionPatternRewriter &rewriter, Location loc, 142975e5f0aaSAlex Zinenko ArrayRef<int64_t> strides, Value nextSize, 1430620e2bb2SNicolas Vasilache Value runningStride, unsigned idx, Type indexType) const { 143175e5f0aaSAlex Zinenko assert(idx < strides.size()); 1432399638f9SAliia Khasanova if (!ShapedType::isDynamic(strides[idx])) 1433620e2bb2SNicolas Vasilache return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]); 143475e5f0aaSAlex Zinenko if (nextSize) 143575e5f0aaSAlex Zinenko return runningStride 143675e5f0aaSAlex Zinenko ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize) 143775e5f0aaSAlex Zinenko : nextSize; 143875e5f0aaSAlex Zinenko assert(!runningStride); 1439620e2bb2SNicolas Vasilache return createIndexAttrConstant(rewriter, loc, indexType, 1); 144075e5f0aaSAlex Zinenko } 144175e5f0aaSAlex Zinenko 144275e5f0aaSAlex Zinenko LogicalResult 1443ef976337SRiver Riddle matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor, 144475e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override { 144575e5f0aaSAlex Zinenko auto loc = viewOp.getLoc(); 144675e5f0aaSAlex Zinenko 144775e5f0aaSAlex Zinenko auto viewMemRefType = viewOp.getType(); 144875e5f0aaSAlex Zinenko auto targetElementTy = 144975e5f0aaSAlex Zinenko typeConverter->convertType(viewMemRefType.getElementType()); 145075e5f0aaSAlex Zinenko auto targetDescTy = typeConverter->convertType(viewMemRefType); 145175e5f0aaSAlex Zinenko if (!targetDescTy || !targetElementTy || 145275e5f0aaSAlex Zinenko !LLVM::isCompatibleType(targetElementTy) || 145375e5f0aaSAlex Zinenko !LLVM::isCompatibleType(targetDescTy)) 145475e5f0aaSAlex Zinenko return viewOp.emitWarning("Target descriptor type not converted to LLVM"), 145575e5f0aaSAlex Zinenko failure(); 145675e5f0aaSAlex Zinenko 145775e5f0aaSAlex Zinenko int64_t offset; 145875e5f0aaSAlex Zinenko SmallVector<int64_t, 4> strides; 14596aaa8f25SMatthias Springer auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset); 146075e5f0aaSAlex Zinenko if (failed(successStrides)) 146175e5f0aaSAlex Zinenko return viewOp.emitWarning("cannot cast to non-strided shape"), failure(); 146275e5f0aaSAlex Zinenko assert(offset == 0 && "expected offset to be 0"); 146375e5f0aaSAlex Zinenko 1464705f048cSEugene Zhulenev // Target memref must be contiguous in memory (innermost stride is 1), or 1465705f048cSEugene Zhulenev // empty (special case when at least one of the memref dimensions is 0). 1466705f048cSEugene Zhulenev if (!strides.empty() && (strides.back() != 1 && strides.back() != 0)) 1467705f048cSEugene Zhulenev return viewOp.emitWarning("cannot cast to non-contiguous shape"), 1468705f048cSEugene Zhulenev failure(); 1469705f048cSEugene Zhulenev 147075e5f0aaSAlex Zinenko // Create the descriptor. 1471136d746eSJacques Pienaar MemRefDescriptor sourceMemRef(adaptor.getSource()); 147275e5f0aaSAlex Zinenko auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); 147375e5f0aaSAlex Zinenko 147475e5f0aaSAlex Zinenko // Field 1: Copy the allocated pointer, used for malloc/free. 147575e5f0aaSAlex Zinenko Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); 14765550c821STres Popp auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType()); 1477b28a296cSChristian Ulmann targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr); 147875e5f0aaSAlex Zinenko 147975e5f0aaSAlex Zinenko // Field 2: Copy the actual aligned pointer to payload. 148075e5f0aaSAlex Zinenko Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc); 1481136d746eSJacques Pienaar alignedPtr = rewriter.create<LLVM::GEPOp>( 148250ea17b8SMarkus Böck loc, alignedPtr.getType(), 148350ea17b8SMarkus Böck typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr, 148450ea17b8SMarkus Böck adaptor.getByteShift()); 148550ea17b8SMarkus Böck 1486b28a296cSChristian Ulmann targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr); 148775e5f0aaSAlex Zinenko 1488e98e5995SAlex Zinenko Type indexType = getIndexType(); 1489620e2bb2SNicolas Vasilache // Field 3: The offset in the resulting type must be 0. This is 1490620e2bb2SNicolas Vasilache // because of the type change: an offset on srcType* may not be 1491620e2bb2SNicolas Vasilache // expressible as an offset on dstType*. 1492620e2bb2SNicolas Vasilache targetMemRef.setOffset( 1493620e2bb2SNicolas Vasilache rewriter, loc, 1494620e2bb2SNicolas Vasilache createIndexAttrConstant(rewriter, loc, indexType, offset)); 149575e5f0aaSAlex Zinenko 149675e5f0aaSAlex Zinenko // Early exit for 0-D corner case. 149775e5f0aaSAlex Zinenko if (viewMemRefType.getRank() == 0) 149875e5f0aaSAlex Zinenko return rewriter.replaceOp(viewOp, {targetMemRef}), success(); 149975e5f0aaSAlex Zinenko 150075e5f0aaSAlex Zinenko // Fields 4 and 5: Update sizes and strides. 150175e5f0aaSAlex Zinenko Value stride = nullptr, nextSize = nullptr; 150275e5f0aaSAlex Zinenko for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { 150375e5f0aaSAlex Zinenko // Update size. 1504136d746eSJacques Pienaar Value size = getSize(rewriter, loc, viewMemRefType.getShape(), 1505620e2bb2SNicolas Vasilache adaptor.getSizes(), i, indexType); 150675e5f0aaSAlex Zinenko targetMemRef.setSize(rewriter, loc, i, size); 150775e5f0aaSAlex Zinenko // Update stride. 1508620e2bb2SNicolas Vasilache stride = 1509620e2bb2SNicolas Vasilache getStride(rewriter, loc, strides, nextSize, stride, i, indexType); 151075e5f0aaSAlex Zinenko targetMemRef.setStride(rewriter, loc, i, stride); 151175e5f0aaSAlex Zinenko nextSize = size; 151275e5f0aaSAlex Zinenko } 151375e5f0aaSAlex Zinenko 151475e5f0aaSAlex Zinenko rewriter.replaceOp(viewOp, {targetMemRef}); 151575e5f0aaSAlex Zinenko return success(); 151675e5f0aaSAlex Zinenko } 151775e5f0aaSAlex Zinenko }; 151875e5f0aaSAlex Zinenko 1519a6a583daSWilliam S. Moses //===----------------------------------------------------------------------===// 1520a6a583daSWilliam S. Moses // AtomicRMWOpLowering 1521a6a583daSWilliam S. Moses //===----------------------------------------------------------------------===// 1522a6a583daSWilliam S. Moses 152323aa5a74SRiver Riddle /// Try to match the kind of a memref.atomic_rmw to determine whether to use a 1524a6a583daSWilliam S. Moses /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg. 15257d2b180eSKazu Hirata static std::optional<LLVM::AtomicBinOp> 1526a6a583daSWilliam S. Moses matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { 1527136d746eSJacques Pienaar switch (atomicOp.getKind()) { 1528a6a583daSWilliam S. Moses case arith::AtomicRMWKind::addf: 1529a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::fadd; 1530a6a583daSWilliam S. Moses case arith::AtomicRMWKind::addi: 1531a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::add; 1532a6a583daSWilliam S. Moses case arith::AtomicRMWKind::assign: 1533a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::xchg; 1534c46a0433SDaniil Dudkin case arith::AtomicRMWKind::maximumf: 15357db18533SKrzysztof Drewniak return LLVM::AtomicBinOp::fmax; 1536a6a583daSWilliam S. Moses case arith::AtomicRMWKind::maxs: 1537a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::max; 1538a6a583daSWilliam S. Moses case arith::AtomicRMWKind::maxu: 1539a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::umax; 1540c46a0433SDaniil Dudkin case arith::AtomicRMWKind::minimumf: 15417db18533SKrzysztof Drewniak return LLVM::AtomicBinOp::fmin; 1542a6a583daSWilliam S. Moses case arith::AtomicRMWKind::mins: 1543a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::min; 1544a6a583daSWilliam S. Moses case arith::AtomicRMWKind::minu: 1545a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::umin; 1546a6a583daSWilliam S. Moses case arith::AtomicRMWKind::ori: 1547a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::_or; 1548a6a583daSWilliam S. Moses case arith::AtomicRMWKind::andi: 1549a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::_and; 1550a6a583daSWilliam S. Moses default: 15511a36588eSKazu Hirata return std::nullopt; 1552a6a583daSWilliam S. Moses } 1553a6a583daSWilliam S. Moses llvm_unreachable("Invalid AtomicRMWKind"); 1554a6a583daSWilliam S. Moses } 1555a6a583daSWilliam S. Moses 1556a6a583daSWilliam S. Moses struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> { 1557a6a583daSWilliam S. Moses using Base::Base; 1558a6a583daSWilliam S. Moses 1559a6a583daSWilliam S. Moses LogicalResult 1560a6a583daSWilliam S. Moses matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, 1561a6a583daSWilliam S. Moses ConversionPatternRewriter &rewriter) const override { 1562a6a583daSWilliam S. Moses auto maybeKind = matchSimpleAtomicOp(atomicOp); 1563a6a583daSWilliam S. Moses if (!maybeKind) 1564a6a583daSWilliam S. Moses return failure(); 1565a6a583daSWilliam S. Moses auto memRefType = atomicOp.getMemRefType(); 1566ce6ef990SMax191 SmallVector<int64_t> strides; 1567ce6ef990SMax191 int64_t offset; 15686aaa8f25SMatthias Springer if (failed(memRefType.getStridesAndOffset(strides, offset))) 1569ce6ef990SMax191 return failure(); 1570a6a583daSWilliam S. Moses auto dataPtr = 1571136d746eSJacques Pienaar getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(), 1572136d746eSJacques Pienaar adaptor.getIndices(), rewriter); 1573a6a583daSWilliam S. Moses rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>( 15747f97895fSTobias Gysi atomicOp, *maybeKind, dataPtr, adaptor.getValue(), 1575a6a583daSWilliam S. Moses LLVM::AtomicOrdering::acq_rel); 1576a6a583daSWilliam S. Moses return success(); 1577a6a583daSWilliam S. Moses } 1578a6a583daSWilliam S. Moses }; 1579a6a583daSWilliam S. Moses 158007801f71SNicolas Vasilache /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index. 158107801f71SNicolas Vasilache class ConvertExtractAlignedPointerAsIndex 158207801f71SNicolas Vasilache : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> { 158307801f71SNicolas Vasilache public: 158407801f71SNicolas Vasilache using ConvertOpToLLVMPattern< 158507801f71SNicolas Vasilache memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern; 158607801f71SNicolas Vasilache 158707801f71SNicolas Vasilache LogicalResult 158807801f71SNicolas Vasilache matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp, 158907801f71SNicolas Vasilache OpAdaptor adaptor, 159007801f71SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 1591846103c7SSpenser Bauman BaseMemRefType sourceTy = extractOp.getSource().getType(); 1592846103c7SSpenser Bauman 1593846103c7SSpenser Bauman Value alignedPtr; 1594846103c7SSpenser Bauman if (sourceTy.hasRank()) { 159507801f71SNicolas Vasilache MemRefDescriptor desc(adaptor.getSource()); 1596846103c7SSpenser Bauman alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc()); 1597846103c7SSpenser Bauman } else { 1598846103c7SSpenser Bauman auto elementPtrTy = LLVM::LLVMPointerType::get( 1599846103c7SSpenser Bauman rewriter.getContext(), sourceTy.getMemorySpaceAsInt()); 1600846103c7SSpenser Bauman 1601846103c7SSpenser Bauman UnrankedMemRefDescriptor desc(adaptor.getSource()); 1602846103c7SSpenser Bauman Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc()); 1603846103c7SSpenser Bauman 1604846103c7SSpenser Bauman alignedPtr = UnrankedMemRefDescriptor::alignedPtr( 1605846103c7SSpenser Bauman rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr, 1606846103c7SSpenser Bauman elementPtrTy); 1607846103c7SSpenser Bauman } 1608846103c7SSpenser Bauman 160907801f71SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>( 1610846103c7SSpenser Bauman extractOp, getTypeConverter()->getIndexType(), alignedPtr); 161107801f71SNicolas Vasilache return success(); 161207801f71SNicolas Vasilache } 161307801f71SNicolas Vasilache }; 161407801f71SNicolas Vasilache 161555d08a86SQuentin Colombet /// Materialize the MemRef descriptor represented by the results of 161655d08a86SQuentin Colombet /// ExtractStridedMetadataOp. 161755d08a86SQuentin Colombet class ExtractStridedMetadataOpLowering 161855d08a86SQuentin Colombet : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> { 161955d08a86SQuentin Colombet public: 162055d08a86SQuentin Colombet using ConvertOpToLLVMPattern< 162155d08a86SQuentin Colombet memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern; 162255d08a86SQuentin Colombet 162355d08a86SQuentin Colombet LogicalResult 162455d08a86SQuentin Colombet matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, 162555d08a86SQuentin Colombet OpAdaptor adaptor, 162655d08a86SQuentin Colombet ConversionPatternRewriter &rewriter) const override { 162755d08a86SQuentin Colombet 162855d08a86SQuentin Colombet if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType())) 162955d08a86SQuentin Colombet return failure(); 163055d08a86SQuentin Colombet 163155d08a86SQuentin Colombet // Create the descriptor. 1632200266a0SQuentin Colombet MemRefDescriptor sourceMemRef(adaptor.getSource()); 163355d08a86SQuentin Colombet Location loc = extractStridedMetadataOp.getLoc(); 163455d08a86SQuentin Colombet Value source = extractStridedMetadataOp.getSource(); 163555d08a86SQuentin Colombet 16365550c821STres Popp auto sourceMemRefType = cast<MemRefType>(source.getType()); 163755d08a86SQuentin Colombet int64_t rank = sourceMemRefType.getRank(); 163855d08a86SQuentin Colombet SmallVector<Value> results; 163955d08a86SQuentin Colombet results.reserve(2 + rank * 2); 164055d08a86SQuentin Colombet 164155d08a86SQuentin Colombet // Base buffer. 1642200266a0SQuentin Colombet Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc); 1643200266a0SQuentin Colombet Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc); 1644200266a0SQuentin Colombet MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape( 1645200266a0SQuentin Colombet rewriter, loc, *getTypeConverter(), 16465550c821STres Popp cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()), 1647200266a0SQuentin Colombet baseBuffer, alignedBuffer); 1648200266a0SQuentin Colombet results.push_back((Value)dstMemRef); 164955d08a86SQuentin Colombet 165055d08a86SQuentin Colombet // Offset. 165155d08a86SQuentin Colombet results.push_back(sourceMemRef.offset(rewriter, loc)); 165255d08a86SQuentin Colombet 165355d08a86SQuentin Colombet // Sizes. 165455d08a86SQuentin Colombet for (unsigned i = 0; i < rank; ++i) 165555d08a86SQuentin Colombet results.push_back(sourceMemRef.size(rewriter, loc, i)); 165655d08a86SQuentin Colombet // Strides. 165755d08a86SQuentin Colombet for (unsigned i = 0; i < rank; ++i) 165855d08a86SQuentin Colombet results.push_back(sourceMemRef.stride(rewriter, loc, i)); 165955d08a86SQuentin Colombet 166055d08a86SQuentin Colombet rewriter.replaceOp(extractStridedMetadataOp, results); 166155d08a86SQuentin Colombet return success(); 166255d08a86SQuentin Colombet } 166355d08a86SQuentin Colombet }; 166455d08a86SQuentin Colombet 166575e5f0aaSAlex Zinenko } // namespace 166675e5f0aaSAlex Zinenko 1667fcb0294bSAlex Zinenko void mlir::populateFinalizeMemRefToLLVMConversionPatterns( 1668206fad0eSMatthias Springer const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 166975e5f0aaSAlex Zinenko // clang-format off 167075e5f0aaSAlex Zinenko patterns.add< 167175e5f0aaSAlex Zinenko AllocaOpLowering, 167275e5f0aaSAlex Zinenko AllocaScopeOpLowering, 1673a6a583daSWilliam S. Moses AtomicRMWOpLowering, 167475e5f0aaSAlex Zinenko AssumeAlignmentOpLowering, 167507801f71SNicolas Vasilache ConvertExtractAlignedPointerAsIndex, 167675e5f0aaSAlex Zinenko DimOpLowering, 167755d08a86SQuentin Colombet ExtractStridedMetadataOpLowering, 1678632a4f88SRiver Riddle GenericAtomicRMWOpLowering, 167975e5f0aaSAlex Zinenko GlobalMemrefOpLowering, 168075e5f0aaSAlex Zinenko GetGlobalMemrefOpLowering, 168175e5f0aaSAlex Zinenko LoadOpLowering, 168275e5f0aaSAlex Zinenko MemRefCastOpLowering, 1683fcb0294bSAlex Zinenko MemRefCopyOpLowering, 16847fb9bbe5SKrzysztof Drewniak MemorySpaceCastOpLowering, 168575e5f0aaSAlex Zinenko MemRefReinterpretCastOpLowering, 168675e5f0aaSAlex Zinenko MemRefReshapeOpLowering, 168775e5f0aaSAlex Zinenko PrefetchOpLowering, 168815f8f3e2SAlexander Belyaev RankOpLowering, 168946ef86b5SAlexander Belyaev ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, 169046ef86b5SAlexander Belyaev ReassociatingReshapeOpConversion<memref::CollapseShapeOp>, 169175e5f0aaSAlex Zinenko StoreOpLowering, 169275e5f0aaSAlex Zinenko SubViewOpLowering, 169375e5f0aaSAlex Zinenko TransposeOpLowering, 169475e5f0aaSAlex Zinenko ViewOpLowering>(converter); 169575e5f0aaSAlex Zinenko // clang-format on 169675e5f0aaSAlex Zinenko auto allocLowering = converter.getOptions().allocLowering; 169775e5f0aaSAlex Zinenko if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc) 16988037deb7SMartin Erhart patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter); 169975e5f0aaSAlex Zinenko else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc) 17008037deb7SMartin Erhart patterns.add<AllocOpLowering, DeallocOpLowering>(converter); 170175e5f0aaSAlex Zinenko } 170275e5f0aaSAlex Zinenko 170375e5f0aaSAlex Zinenko namespace { 1704cb4ccd38SQuentin Colombet struct FinalizeMemRefToLLVMConversionPass 1705cb4ccd38SQuentin Colombet : public impl::FinalizeMemRefToLLVMConversionPassBase< 1706cb4ccd38SQuentin Colombet FinalizeMemRefToLLVMConversionPass> { 1707cb4ccd38SQuentin Colombet using FinalizeMemRefToLLVMConversionPassBase:: 1708cb4ccd38SQuentin Colombet FinalizeMemRefToLLVMConversionPassBase; 170975e5f0aaSAlex Zinenko 171075e5f0aaSAlex Zinenko void runOnOperation() override { 171175e5f0aaSAlex Zinenko Operation *op = getOperation(); 171275e5f0aaSAlex Zinenko const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); 171375e5f0aaSAlex Zinenko LowerToLLVMOptions options(&getContext(), 171475e5f0aaSAlex Zinenko dataLayoutAnalysis.getAtOrAbove(op)); 171575e5f0aaSAlex Zinenko options.allocLowering = 171675e5f0aaSAlex Zinenko (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc 171775e5f0aaSAlex Zinenko : LowerToLLVMOptions::AllocLowering::Malloc); 1718a8601f11SMichele Scuttari 1719a8601f11SMichele Scuttari options.useGenericFunctions = useGenericFunctions; 1720a8601f11SMichele Scuttari 172175e5f0aaSAlex Zinenko if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 172275e5f0aaSAlex Zinenko options.overrideIndexBitwidth(indexBitwidth); 172375e5f0aaSAlex Zinenko 172475e5f0aaSAlex Zinenko LLVMTypeConverter typeConverter(&getContext(), options, 172575e5f0aaSAlex Zinenko &dataLayoutAnalysis); 172675e5f0aaSAlex Zinenko RewritePatternSet patterns(&getContext()); 1727cb4ccd38SQuentin Colombet populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); 172875e5f0aaSAlex Zinenko LLVMConversionTarget target(getContext()); 172958ceae95SRiver Riddle target.addLegalOp<func::FuncOp>(); 173075e5f0aaSAlex Zinenko if (failed(applyPartialConversion(op, target, std::move(patterns)))) 173175e5f0aaSAlex Zinenko signalPassFailure(); 173275e5f0aaSAlex Zinenko } 173375e5f0aaSAlex Zinenko }; 1734876a480cSMatthias Springer 1735876a480cSMatthias Springer /// Implement the interface to convert MemRef to LLVM. 1736876a480cSMatthias Springer struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface { 1737876a480cSMatthias Springer using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 1738876a480cSMatthias Springer void loadDependentDialects(MLIRContext *context) const final { 1739876a480cSMatthias Springer context->loadDialect<LLVM::LLVMDialect>(); 1740876a480cSMatthias Springer } 1741876a480cSMatthias Springer 1742876a480cSMatthias Springer /// Hook for derived dialect interface to provide conversion patterns 1743876a480cSMatthias Springer /// and mark dialect legal for the conversion target. 1744876a480cSMatthias Springer void populateConvertToLLVMConversionPatterns( 1745876a480cSMatthias Springer ConversionTarget &target, LLVMTypeConverter &typeConverter, 1746876a480cSMatthias Springer RewritePatternSet &patterns) const final { 1747876a480cSMatthias Springer populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); 1748876a480cSMatthias Springer } 1749876a480cSMatthias Springer }; 1750876a480cSMatthias Springer 175175e5f0aaSAlex Zinenko } // namespace 1752876a480cSMatthias Springer 1753876a480cSMatthias Springer void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry ®istry) { 1754876a480cSMatthias Springer registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { 1755876a480cSMatthias Springer dialect->addInterfaces<MemRefToLLVMDialectInterface>(); 1756876a480cSMatthias Springer }); 1757876a480cSMatthias Springer } 1758