157470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 257470abcSAlexander Belyaev // 357470abcSAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 457470abcSAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information. 557470abcSAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 657470abcSAlexander Belyaev // 757470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 857470abcSAlexander Belyaev 9abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 10ffdbecccSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1157470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 129a3d60e0SAart Bik #include "mlir/Dialect/Func/IR/FuncOps.h" 13eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 149a3d60e0SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15eda6f907SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h" 16ec55f0bdSMatthias Springer #include "mlir/IR/Matchers.h" 17a1fe1f5fSKazu Hirata #include <optional> 1857470abcSAlexander Belyaev 1957470abcSAlexander Belyaev using namespace mlir; 2057470abcSAlexander Belyaev using namespace mlir::bufferization; 2157470abcSAlexander Belyaev 2257470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 23fa7c8cb4SMatthias Springer // Helper functions 24fa7c8cb4SMatthias Springer //===----------------------------------------------------------------------===// 25fa7c8cb4SMatthias Springer 26c515c780SMatthias Gehre FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue( 27c515c780SMatthias Gehre OpBuilder &b, Value value, MemRefType destType, 28c515c780SMatthias Gehre const BufferizationOptions &options) { 29c1fa60b4STres Popp auto srcType = llvm::cast<MemRefType>(value.getType()); 30fa7c8cb4SMatthias Springer 31fa7c8cb4SMatthias Springer // Element type, rank and memory space must match. 32fa7c8cb4SMatthias Springer if (srcType.getElementType() != destType.getElementType()) 33fa7c8cb4SMatthias Springer return failure(); 345d04f0c9SMatthias Springer if (srcType.getMemorySpace() != destType.getMemorySpace()) 35fa7c8cb4SMatthias Springer return failure(); 36fa7c8cb4SMatthias Springer if (srcType.getRank() != destType.getRank()) 37fa7c8cb4SMatthias Springer return failure(); 38fa7c8cb4SMatthias Springer 39fa7c8cb4SMatthias Springer // In case the affine maps are different, we may need to use a copy if we go 40fa7c8cb4SMatthias Springer // from dynamic to static offset or stride (the canonicalization cannot know 41fa7c8cb4SMatthias Springer // at this point that it is really cast compatible). 42fa7c8cb4SMatthias Springer auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) { 43fa7c8cb4SMatthias Springer int64_t sourceOffset, targetOffset; 44fa7c8cb4SMatthias Springer SmallVector<int64_t, 4> sourceStrides, targetStrides; 45*6aaa8f25SMatthias Springer if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) || 46*6aaa8f25SMatthias Springer failed(target.getStridesAndOffset(targetStrides, targetOffset))) 47fa7c8cb4SMatthias Springer return false; 48fa7c8cb4SMatthias Springer auto dynamicToStatic = [](int64_t a, int64_t b) { 49399638f9SAliia Khasanova return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b); 50fa7c8cb4SMatthias Springer }; 51fa7c8cb4SMatthias Springer if (dynamicToStatic(sourceOffset, targetOffset)) 52fa7c8cb4SMatthias Springer return false; 53fa7c8cb4SMatthias Springer for (auto it : zip(sourceStrides, targetStrides)) 54fa7c8cb4SMatthias Springer if (dynamicToStatic(std::get<0>(it), std::get<1>(it))) 55fa7c8cb4SMatthias Springer return false; 56fa7c8cb4SMatthias Springer return true; 57fa7c8cb4SMatthias Springer }; 58fa7c8cb4SMatthias Springer 59fa7c8cb4SMatthias Springer // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To 60fa7c8cb4SMatthias Springer // ensure that we only generate casts that always succeed at runtime, we check 61fa7c8cb4SMatthias Springer // a fix extra conditions in `isGuaranteedCastCompatible`. 62fa7c8cb4SMatthias Springer if (memref::CastOp::areCastCompatible(srcType, destType) && 63fa7c8cb4SMatthias Springer isGuaranteedCastCompatible(srcType, destType)) { 64fa7c8cb4SMatthias Springer Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value); 65fa7c8cb4SMatthias Springer return casted; 66fa7c8cb4SMatthias Springer } 67fa7c8cb4SMatthias Springer 68fa7c8cb4SMatthias Springer auto loc = value.getLoc(); 69fa7c8cb4SMatthias Springer SmallVector<Value, 4> dynamicOperands; 70fa7c8cb4SMatthias Springer for (int i = 0; i < destType.getRank(); ++i) { 71399638f9SAliia Khasanova if (destType.getShape()[i] != ShapedType::kDynamic) 72fa7c8cb4SMatthias Springer continue; 73b23c8225SMatthias Springer Value size = b.create<memref::DimOp>(loc, value, i); 74fa7c8cb4SMatthias Springer dynamicOperands.push_back(size); 75fa7c8cb4SMatthias Springer } 76c515c780SMatthias Gehre 77c515c780SMatthias Gehre FailureOr<Value> copy = 78c515c780SMatthias Gehre options.createAlloc(b, loc, destType, dynamicOperands); 79c515c780SMatthias Gehre if (failed(copy)) 80c515c780SMatthias Gehre return failure(); 81c515c780SMatthias Gehre if (failed(options.createMemCpy(b, loc, value, *copy))) 82c515c780SMatthias Gehre return failure(); 83fa7c8cb4SMatthias Springer return copy; 84fa7c8cb4SMatthias Springer } 85fa7c8cb4SMatthias Springer 86d820acddSMatthias Springer /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the 87d820acddSMatthias Springer /// to_memref op are different, a memref.cast is needed. 88c515c780SMatthias Gehre LogicalResult mlir::bufferization::foldToMemrefToTensorPair( 89c515c780SMatthias Gehre RewriterBase &rewriter, ToMemrefOp toMemref, 90c515c780SMatthias Gehre const BufferizationOptions &options) { 9199260e95SMatthias Springer auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>(); 92d820acddSMatthias Springer if (!memrefToTensor) 93d820acddSMatthias Springer return failure(); 94d820acddSMatthias Springer 9599260e95SMatthias Springer Type srcType = memrefToTensor.getMemref().getType(); 96d820acddSMatthias Springer Type destType = toMemref.getType(); 97d820acddSMatthias Springer 98d820acddSMatthias Springer // Directly rewrite if the type did not change. 99d820acddSMatthias Springer if (srcType == destType) { 10099260e95SMatthias Springer rewriter.replaceOp(toMemref, memrefToTensor.getMemref()); 101d820acddSMatthias Springer return success(); 102d820acddSMatthias Springer } 103d820acddSMatthias Springer 104c1fa60b4STres Popp auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType); 105c1fa60b4STres Popp auto rankedDestType = llvm::dyn_cast<MemRefType>(destType); 106c1fa60b4STres Popp auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType); 107d820acddSMatthias Springer 108d820acddSMatthias Springer // Ranked memref -> Ranked memref cast. 109d820acddSMatthias Springer if (rankedSrcType && rankedDestType) { 110d820acddSMatthias Springer FailureOr<Value> replacement = castOrReallocMemRefValue( 111c515c780SMatthias Gehre rewriter, memrefToTensor.getMemref(), rankedDestType, options); 112d820acddSMatthias Springer if (failed(replacement)) 113d820acddSMatthias Springer return failure(); 114d820acddSMatthias Springer 115d820acddSMatthias Springer rewriter.replaceOp(toMemref, *replacement); 116d820acddSMatthias Springer return success(); 117d820acddSMatthias Springer } 118d820acddSMatthias Springer 119d820acddSMatthias Springer // Unranked memref -> Ranked memref cast: May require a copy. 120d820acddSMatthias Springer // TODO: Not implemented at the moment. 121d820acddSMatthias Springer if (unrankedSrcType && rankedDestType) 122d820acddSMatthias Springer return failure(); 123d820acddSMatthias Springer 124d820acddSMatthias Springer // Unranked memref -> unranked memref cast 125d820acddSMatthias Springer // Ranked memref -> unranked memref cast: No copy needed. 126d820acddSMatthias Springer assert(memref::CastOp::areCastCompatible(srcType, destType) && 127d820acddSMatthias Springer "expected that types are cast compatible"); 128d820acddSMatthias Springer rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType, 12999260e95SMatthias Springer memrefToTensor.getMemref()); 130d820acddSMatthias Springer return success(); 131d820acddSMatthias Springer } 132d820acddSMatthias Springer 133b3ebe3beSMatthias Springer void mlir::bufferization::populateDynamicDimSizes( 134b3ebe3beSMatthias Springer OpBuilder &b, Location loc, Value shapedValue, 135b3ebe3beSMatthias Springer SmallVector<Value> &dynamicDims) { 136c1fa60b4STres Popp auto shapedType = llvm::cast<ShapedType>(shapedValue.getType()); 137b3ebe3beSMatthias Springer for (int64_t i = 0; i < shapedType.getRank(); ++i) { 138b3ebe3beSMatthias Springer if (shapedType.isDynamicDim(i)) { 139c1fa60b4STres Popp if (llvm::isa<MemRefType>(shapedType)) { 140b3ebe3beSMatthias Springer dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i)); 141b3ebe3beSMatthias Springer } else { 142c1fa60b4STres Popp assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor"); 143b3ebe3beSMatthias Springer dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i)); 144b3ebe3beSMatthias Springer } 145b3ebe3beSMatthias Springer } 146b3ebe3beSMatthias Springer } 147b3ebe3beSMatthias Springer } 148b3ebe3beSMatthias Springer 149fa7c8cb4SMatthias Springer //===----------------------------------------------------------------------===// 150ffdbecccSMatthias Springer // AllocTensorOp 151ffdbecccSMatthias Springer //===----------------------------------------------------------------------===// 152ffdbecccSMatthias Springer 153ffdbecccSMatthias Springer LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter, 154b55d55ecSMatthias Springer const BufferizationOptions &options) { 155b3ebe3beSMatthias Springer OpBuilder::InsertionGuard g(rewriter); 156b3ebe3beSMatthias Springer Location loc = getLoc(); 157ffdbecccSMatthias Springer 158b3ebe3beSMatthias Springer // Nothing to do for dead AllocTensorOps. 159b3ebe3beSMatthias Springer if (getOperation()->getUses().empty()) { 160b3ebe3beSMatthias Springer rewriter.eraseOp(getOperation()); 161b3ebe3beSMatthias Springer return success(); 162b3ebe3beSMatthias Springer } 163b3ebe3beSMatthias Springer 164c06f01ffSMatthias Springer // Get "copy" buffer. 165b3ebe3beSMatthias Springer Value copyBuffer; 1665d50f51cSMatthias Springer if (getCopy()) { 1675d50f51cSMatthias Springer FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options); 1685d50f51cSMatthias Springer if (failed(maybeCopyBuffer)) 1695d50f51cSMatthias Springer return failure(); 1705d50f51cSMatthias Springer copyBuffer = *maybeCopyBuffer; 1715d50f51cSMatthias Springer } 172c06f01ffSMatthias Springer 173c06f01ffSMatthias Springer // Create memory allocation. 174123c4b02SMatthias Springer auto allocType = bufferization::getBufferType(getResult(), options); 175111c9196SMatthias Springer if (failed(allocType)) 176111c9196SMatthias Springer return failure(); 17799260e95SMatthias Springer SmallVector<Value> dynamicDims = getDynamicSizes(); 17899260e95SMatthias Springer if (getCopy()) { 179b3ebe3beSMatthias Springer assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`"); 180b3ebe3beSMatthias Springer populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims); 181b3ebe3beSMatthias Springer } 182111c9196SMatthias Springer FailureOr<Value> alloc = options.createAlloc( 18368f58812STres Popp rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims); 184ffdbecccSMatthias Springer if (failed(alloc)) 185ffdbecccSMatthias Springer return failure(); 186b3ebe3beSMatthias Springer 187b3ebe3beSMatthias Springer // Create memory copy (if any). 18899260e95SMatthias Springer if (getCopy()) { 189b55d55ecSMatthias Springer if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc))) 19056d68e8dSMatthias Springer return failure(); 19156d68e8dSMatthias Springer } 192b3ebe3beSMatthias Springer 193b3ebe3beSMatthias Springer // Replace op. 194ffdbecccSMatthias Springer replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc); 195b3ebe3beSMatthias Springer 196ffdbecccSMatthias Springer return success(); 197ffdbecccSMatthias Springer } 198ffdbecccSMatthias Springer 19934d65e81SMatthias Springer bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult, 20056d68e8dSMatthias Springer const AnalysisState &state) { 20156d68e8dSMatthias Springer // AllocTensorOps do not write unless they have a `copy` value. 20299260e95SMatthias Springer return static_cast<bool>(getCopy()); 20356d68e8dSMatthias Springer } 20456d68e8dSMatthias Springer 20556d68e8dSMatthias Springer bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand, 20656d68e8dSMatthias Springer const AnalysisState &state) { 20756d68e8dSMatthias Springer assert(opOperand.getOperandNumber() == getNumOperands() - 1 && 20856d68e8dSMatthias Springer "expected copy operand"); 20956d68e8dSMatthias Springer return true; 21056d68e8dSMatthias Springer } 21156d68e8dSMatthias Springer 21256d68e8dSMatthias Springer bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand, 21356d68e8dSMatthias Springer const AnalysisState &state) { 21456d68e8dSMatthias Springer assert(opOperand.getOperandNumber() == getNumOperands() - 1 && 21556d68e8dSMatthias Springer "expected copy operand"); 21656d68e8dSMatthias Springer return false; 21756d68e8dSMatthias Springer } 21856d68e8dSMatthias Springer 219a02ad6c1SMatthias Springer AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand, 22056d68e8dSMatthias Springer const AnalysisState &state) { 22156d68e8dSMatthias Springer // This is a new allocation. It does not alias with any other buffer. 22256d68e8dSMatthias Springer return {}; 22356d68e8dSMatthias Springer } 22456d68e8dSMatthias Springer 225878950b8SMatthias Springer FailureOr<BaseMemRefType> 226878950b8SMatthias Springer AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, 227878950b8SMatthias Springer SmallVector<Value> &invocationStack) { 228111c9196SMatthias Springer assert(value == getResult() && "invalid value"); 229111c9196SMatthias Springer 230111c9196SMatthias Springer // Compute memory space of this allocation. 2319bb63374SLei Zhang Attribute memorySpace; 232111c9196SMatthias Springer if (getMemorySpace().has_value()) { 233111c9196SMatthias Springer memorySpace = *getMemorySpace(); 234111c9196SMatthias Springer } else if (getCopy()) { 235123c4b02SMatthias Springer auto copyBufferType = 236878950b8SMatthias Springer bufferization::getBufferType(getCopy(), options, invocationStack); 237111c9196SMatthias Springer if (failed(copyBufferType)) 238111c9196SMatthias Springer return failure(); 2399bb63374SLei Zhang memorySpace = copyBufferType->getMemorySpace(); 240067d2779Sian Bearman } else if (auto ms = options.defaultMemorySpaceFn(getType())) { 241067d2779Sian Bearman memorySpace = *ms; 242111c9196SMatthias Springer } else { 243111c9196SMatthias Springer return getOperation()->emitError("could not infer memory space"); 244111c9196SMatthias Springer } 245111c9196SMatthias Springer 246111c9196SMatthias Springer return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace); 247111c9196SMatthias Springer } 248111c9196SMatthias Springer 249ffdbecccSMatthias Springer LogicalResult AllocTensorOp::verify() { 25099260e95SMatthias Springer if (getCopy() && !getDynamicSizes().empty()) 25156d68e8dSMatthias Springer return emitError("dynamic sizes not needed when copying a tensor"); 252bfde1783SAndrzej Warzyński if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size()) 253ec55f0bdSMatthias Springer return emitError("expected ") 254ec55f0bdSMatthias Springer << getType().getNumDynamicDims() << " dynamic sizes"; 25599260e95SMatthias Springer if (getCopy() && getCopy().getType() != getType()) 25656d68e8dSMatthias Springer return emitError("expected that `copy` and return type match"); 257ffdbecccSMatthias Springer return success(); 258ffdbecccSMatthias Springer } 259ffdbecccSMatthias Springer 26056d68e8dSMatthias Springer void AllocTensorOp::build(OpBuilder &builder, OperationState &result, 26156d68e8dSMatthias Springer RankedTensorType type, ValueRange dynamicSizes) { 262c06f01ffSMatthias Springer build(builder, result, type, dynamicSizes, /*copy=*/Value(), 26326ef3868Sbixia1 /*size_hint=*/Value(), 2640d0a94a7SMatthias Springer /*memory_space=*/IntegerAttr()); 265c06f01ffSMatthias Springer } 266c06f01ffSMatthias Springer 267c06f01ffSMatthias Springer void AllocTensorOp::build(OpBuilder &builder, OperationState &result, 268c06f01ffSMatthias Springer RankedTensorType type, ValueRange dynamicSizes, 269c06f01ffSMatthias Springer Value copy) { 27026ef3868Sbixia1 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(), 2710d0a94a7SMatthias Springer /*memory_space=*/IntegerAttr()); 27256d68e8dSMatthias Springer } 27356d68e8dSMatthias Springer 27426ef3868Sbixia1 void AllocTensorOp::build(OpBuilder &builder, OperationState &result, 27526ef3868Sbixia1 TensorType type, ValueRange dynamicSizes, Value copy, 27626ef3868Sbixia1 IntegerAttr memorySpace) { 27726ef3868Sbixia1 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(), 27826ef3868Sbixia1 memorySpace); 27926ef3868Sbixia1 } 28026ef3868Sbixia1 281ffdbecccSMatthias Springer namespace { 282ffdbecccSMatthias Springer /// Change the type of the result of a `bufferization.alloc_tensor` by making 283ffdbecccSMatthias Springer /// the result type statically sized along dimension that in the original 284ffdbecccSMatthias Springer /// operation where defined as dynamic, but the size was defined using a 285ffdbecccSMatthias Springer /// `constant` op. For example: 286ffdbecccSMatthias Springer /// 287ffdbecccSMatthias Springer /// %c5 = arith.constant 5: index 288ec55f0bdSMatthias Springer /// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32> 289ffdbecccSMatthias Springer /// 290ffdbecccSMatthias Springer /// to 291ffdbecccSMatthias Springer /// 292ec55f0bdSMatthias Springer /// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32> 293ffdbecccSMatthias Springer struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> { 294ffdbecccSMatthias Springer using OpRewritePattern<AllocTensorOp>::OpRewritePattern; 295ffdbecccSMatthias Springer 296ffdbecccSMatthias Springer LogicalResult matchAndRewrite(AllocTensorOp op, 297ffdbecccSMatthias Springer PatternRewriter &rewriter) const override { 29899260e95SMatthias Springer if (op.getCopy()) 29956d68e8dSMatthias Springer return failure(); 300ec55f0bdSMatthias Springer SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape()); 301ec55f0bdSMatthias Springer SmallVector<Value> newDynamicSizes; 302ec55f0bdSMatthias Springer unsigned int dynValCounter = 0; 303ec55f0bdSMatthias Springer for (int64_t i = 0; i < op.getType().getRank(); ++i) { 304ec55f0bdSMatthias Springer if (!op.isDynamicDim(i)) 305ffdbecccSMatthias Springer continue; 30699260e95SMatthias Springer Value value = op.getDynamicSizes()[dynValCounter++]; 307ec55f0bdSMatthias Springer APInt intVal; 308ec55f0bdSMatthias Springer if (matchPattern(value, m_ConstantInt(&intVal))) { 3095a71f7a4SMehdi Amini int64_t dim = intVal.getSExtValue(); 3105a71f7a4SMehdi Amini if (dim >= 0) 311ec55f0bdSMatthias Springer newShape[i] = intVal.getSExtValue(); 3125a71f7a4SMehdi Amini else 3135a71f7a4SMehdi Amini newDynamicSizes.push_back(value); 314ec55f0bdSMatthias Springer } else { 315ec55f0bdSMatthias Springer newDynamicSizes.push_back(value); 316ffdbecccSMatthias Springer } 317ffdbecccSMatthias Springer } 318ec55f0bdSMatthias Springer RankedTensorType newType = RankedTensorType::get( 319ec55f0bdSMatthias Springer newShape, op.getType().getElementType(), op.getType().getEncoding()); 320ffdbecccSMatthias Springer if (newType == op.getType()) 321ffdbecccSMatthias Springer return failure(); 32256d68e8dSMatthias Springer auto newOp = rewriter.create<AllocTensorOp>( 3233474d10eSMatthias Springer op.getLoc(), newType, newDynamicSizes, /*copy=*/Value()); 324ffdbecccSMatthias Springer rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 325ffdbecccSMatthias Springer return success(); 326ffdbecccSMatthias Springer } 327ffdbecccSMatthias Springer }; 328ffdbecccSMatthias Springer 329ffdbecccSMatthias Springer struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> { 330ffdbecccSMatthias Springer using OpRewritePattern<tensor::DimOp>::OpRewritePattern; 331ffdbecccSMatthias Springer 332ffdbecccSMatthias Springer LogicalResult matchAndRewrite(tensor::DimOp dimOp, 333ffdbecccSMatthias Springer PatternRewriter &rewriter) const override { 33422426110SRamkumar Ramachandra std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex(); 33504235d07SJacques Pienaar auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>(); 336ffdbecccSMatthias Springer if (!allocTensorOp || !maybeConstantIndex) 337ffdbecccSMatthias Springer return failure(); 3384bb9f918SJianbang Yang if (*maybeConstantIndex < 0 || 3394bb9f918SJianbang Yang *maybeConstantIndex >= allocTensorOp.getType().getRank()) 3404bb9f918SJianbang Yang return failure(); 341ec55f0bdSMatthias Springer if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex)) 342ffdbecccSMatthias Springer return failure(); 34356d68e8dSMatthias Springer rewriter.replaceOp( 34456d68e8dSMatthias Springer dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex)); 345ffdbecccSMatthias Springer return success(); 346ffdbecccSMatthias Springer } 347ffdbecccSMatthias Springer }; 348ffdbecccSMatthias Springer } // namespace 349ffdbecccSMatthias Springer 350ffdbecccSMatthias Springer void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, 351ffdbecccSMatthias Springer MLIRContext *ctx) { 352ffdbecccSMatthias Springer results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx); 353ffdbecccSMatthias Springer } 354ffdbecccSMatthias Springer 355ffdbecccSMatthias Springer LogicalResult AllocTensorOp::reifyResultShapes( 356ffdbecccSMatthias Springer OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 3572a5b13e7SMatthias Springer auto shapes = llvm::to_vector<4>( 3582a5b13e7SMatthias Springer llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()), 3592a5b13e7SMatthias Springer [&](int64_t dim) -> OpFoldResult { 360ec55f0bdSMatthias Springer if (isDynamicDim(dim)) 36156d68e8dSMatthias Springer return getDynamicSize(builder, dim); 3622a5b13e7SMatthias Springer return builder.getIndexAttr(getStaticSize(dim)); 363ffdbecccSMatthias Springer })); 364ffdbecccSMatthias Springer reifiedReturnShapes.emplace_back(std::move(shapes)); 365ffdbecccSMatthias Springer return success(); 366ffdbecccSMatthias Springer } 367ffdbecccSMatthias Springer 36856d68e8dSMatthias Springer ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) { 36956d68e8dSMatthias Springer SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands; 37056d68e8dSMatthias Springer if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) || 37156d68e8dSMatthias Springer parser.parseRParen()) 37256d68e8dSMatthias Springer return failure(); 37356d68e8dSMatthias Springer ParseResult copyKeyword = parser.parseOptionalKeyword("copy"); 37456d68e8dSMatthias Springer OpAsmParser::UnresolvedOperand copyOperand; 37556d68e8dSMatthias Springer if (copyKeyword.succeeded()) 37656d68e8dSMatthias Springer if (parser.parseLParen() || parser.parseOperand(copyOperand) || 37756d68e8dSMatthias Springer parser.parseRParen()) 37856d68e8dSMatthias Springer return failure(); 37926ef3868Sbixia1 ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint"); 38026ef3868Sbixia1 OpAsmParser::UnresolvedOperand sizeHintOperand; 38126ef3868Sbixia1 if (sizeHintKeyword.succeeded()) 38226ef3868Sbixia1 if (parser.parseEqual() || parser.parseOperand(sizeHintOperand)) 38326ef3868Sbixia1 return failure(); 38456d68e8dSMatthias Springer if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) 38556d68e8dSMatthias Springer return failure(); 38656d68e8dSMatthias Springer 38756d68e8dSMatthias Springer TensorType type; 38856d68e8dSMatthias Springer if (parser.parseCustomTypeWithFallback(type)) 38956d68e8dSMatthias Springer return failure(); 39056d68e8dSMatthias Springer result.addTypes(type); 39156d68e8dSMatthias Springer 39256d68e8dSMatthias Springer Type indexType = parser.getBuilder().getIndexType(); 39356d68e8dSMatthias Springer if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands)) 39456d68e8dSMatthias Springer return failure(); 39556d68e8dSMatthias Springer if (copyKeyword.succeeded()) 39656d68e8dSMatthias Springer if (parser.resolveOperand(copyOperand, type, result.operands)) 39756d68e8dSMatthias Springer return failure(); 39826ef3868Sbixia1 if (sizeHintKeyword.succeeded()) 39926ef3868Sbixia1 if (parser.resolveOperand(sizeHintOperand, indexType, result.operands)) 40026ef3868Sbixia1 return failure(); 40156d68e8dSMatthias Springer result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(), 40258a47508SJeff Niu parser.getBuilder().getDenseI32ArrayAttr( 40356d68e8dSMatthias Springer {static_cast<int32_t>(dynamicSizesOperands.size()), 40426ef3868Sbixia1 static_cast<int32_t>(copyKeyword.succeeded()), 40526ef3868Sbixia1 static_cast<int32_t>(sizeHintKeyword.succeeded())})); 40656d68e8dSMatthias Springer return success(); 40756d68e8dSMatthias Springer } 40856d68e8dSMatthias Springer 40956d68e8dSMatthias Springer void AllocTensorOp::print(OpAsmPrinter &p) { 41099260e95SMatthias Springer p << "(" << getDynamicSizes() << ")"; 41199260e95SMatthias Springer if (getCopy()) 41299260e95SMatthias Springer p << " copy(" << getCopy() << ")"; 41326ef3868Sbixia1 if (getSizeHint()) 41426ef3868Sbixia1 p << " size_hint=" << getSizeHint(); 41556d68e8dSMatthias Springer p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{ 41656d68e8dSMatthias Springer AllocTensorOp::getOperandSegmentSizeAttr()}); 41756d68e8dSMatthias Springer p << " : "; 41899260e95SMatthias Springer auto type = getResult().getType(); 419c1fa60b4STres Popp if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type)) 42056d68e8dSMatthias Springer p.printStrippedAttrOrType(validType); 42156d68e8dSMatthias Springer else 42256d68e8dSMatthias Springer p << type; 42356d68e8dSMatthias Springer } 42456d68e8dSMatthias Springer 42556d68e8dSMatthias Springer Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) { 42656d68e8dSMatthias Springer assert(isDynamicDim(idx) && "expected dynamic dim"); 42799260e95SMatthias Springer if (getCopy()) 42899260e95SMatthias Springer return b.create<tensor::DimOp>(getLoc(), getCopy(), idx); 42956d68e8dSMatthias Springer return getOperand(getIndexOfDynamicSize(idx)); 43056d68e8dSMatthias Springer } 43156d68e8dSMatthias Springer 432ffdbecccSMatthias Springer //===----------------------------------------------------------------------===// 43357470abcSAlexander Belyaev // CloneOp 43457470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 43557470abcSAlexander Belyaev 4367df76121SMarkus Böck OpFoldResult CloneOp::fold(FoldAdaptor adaptor) { 43757470abcSAlexander Belyaev return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value(); 43857470abcSAlexander Belyaev } 43957470abcSAlexander Belyaev 44057470abcSAlexander Belyaev namespace { 44157470abcSAlexander Belyaev 44257470abcSAlexander Belyaev /// Merge the clone and its source (by converting the clone to a cast) when 44357470abcSAlexander Belyaev /// possible. 44457470abcSAlexander Belyaev struct SimplifyClones : public OpRewritePattern<CloneOp> { 44557470abcSAlexander Belyaev using OpRewritePattern<CloneOp>::OpRewritePattern; 44657470abcSAlexander Belyaev 44757470abcSAlexander Belyaev LogicalResult matchAndRewrite(CloneOp cloneOp, 44857470abcSAlexander Belyaev PatternRewriter &rewriter) const override { 44957470abcSAlexander Belyaev if (cloneOp.use_empty()) { 45057470abcSAlexander Belyaev rewriter.eraseOp(cloneOp); 45157470abcSAlexander Belyaev return success(); 45257470abcSAlexander Belyaev } 45357470abcSAlexander Belyaev 45499260e95SMatthias Springer Value source = cloneOp.getInput(); 455eaa4b6cfSdonald chen if (source.getType() != cloneOp.getType() && 456eaa4b6cfSdonald chen !memref::CastOp::areCastCompatible({source.getType()}, 457eaa4b6cfSdonald chen {cloneOp.getType()})) 458eaa4b6cfSdonald chen return failure(); 459eaa4b6cfSdonald chen 460894e8a54Sroot // Aims to find the dealloc op for the canonical source 461894e8a54Sroot // which otherwise could prevent removal of unnecessary allocs. 462894e8a54Sroot Value canonicalSource = source; 463894e8a54Sroot while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>( 464894e8a54Sroot canonicalSource.getDefiningOp())) 465894e8a54Sroot canonicalSource = iface.getViewSource(); 46657470abcSAlexander Belyaev 4670a81ace0SKazu Hirata std::optional<Operation *> maybeCloneDeallocOp = 46899260e95SMatthias Springer memref::findDealloc(cloneOp.getOutput()); 46957470abcSAlexander Belyaev // Skip if either of them has > 1 deallocate operations. 470491d2701SKazu Hirata if (!maybeCloneDeallocOp.has_value()) 47157470abcSAlexander Belyaev return failure(); 4720a81ace0SKazu Hirata std::optional<Operation *> maybeSourceDeallocOp = 473894e8a54Sroot memref::findDealloc(canonicalSource); 474491d2701SKazu Hirata if (!maybeSourceDeallocOp.has_value()) 47557470abcSAlexander Belyaev return failure(); 47657470abcSAlexander Belyaev Operation *cloneDeallocOp = *maybeCloneDeallocOp; 47757470abcSAlexander Belyaev Operation *sourceDeallocOp = *maybeSourceDeallocOp; 47857470abcSAlexander Belyaev 47957470abcSAlexander Belyaev // If both are deallocated in the same block, their in-block lifetimes 48057470abcSAlexander Belyaev // might not fully overlap, so we cannot decide which one to drop. 48157470abcSAlexander Belyaev if (cloneDeallocOp && sourceDeallocOp && 48257470abcSAlexander Belyaev cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock()) 48357470abcSAlexander Belyaev return failure(); 48457470abcSAlexander Belyaev 48557470abcSAlexander Belyaev Block *currentBlock = cloneOp->getBlock(); 48657470abcSAlexander Belyaev Operation *redundantDealloc = nullptr; 48757470abcSAlexander Belyaev if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) { 48857470abcSAlexander Belyaev redundantDealloc = cloneDeallocOp; 48957470abcSAlexander Belyaev } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) { 49057470abcSAlexander Belyaev redundantDealloc = sourceDeallocOp; 49157470abcSAlexander Belyaev } 49257470abcSAlexander Belyaev 49357470abcSAlexander Belyaev if (!redundantDealloc) 49457470abcSAlexander Belyaev return failure(); 49557470abcSAlexander Belyaev 49657470abcSAlexander Belyaev // Safety check that there are no other deallocations inbetween 49757470abcSAlexander Belyaev // cloneOp and redundantDealloc, as otherwise we might deallocate an alias 49857470abcSAlexander Belyaev // of source before the uses of the clone. With alias information, we could 49957470abcSAlexander Belyaev // restrict this to only fail of the dealloc's operand is an alias 50057470abcSAlexander Belyaev // of the source. 50157470abcSAlexander Belyaev for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc; 50257470abcSAlexander Belyaev pos = pos->getNextNode()) { 503bcd14b09SKohei Yamaguchi // Bail if we run out of operations while looking for a deallocation op. 504bcd14b09SKohei Yamaguchi if (!pos) 505bcd14b09SKohei Yamaguchi return failure(); 50657470abcSAlexander Belyaev auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos); 50757470abcSAlexander Belyaev if (!effectInterface) 50857470abcSAlexander Belyaev continue; 50957470abcSAlexander Belyaev if (effectInterface.hasEffect<MemoryEffects::Free>()) 51057470abcSAlexander Belyaev return failure(); 51157470abcSAlexander Belyaev } 51257470abcSAlexander Belyaev 51368f91cd2SMatthias Springer if (source.getType() != cloneOp.getType()) 51468f91cd2SMatthias Springer source = rewriter.create<memref::CastOp>(cloneOp.getLoc(), 51568f91cd2SMatthias Springer cloneOp.getType(), source); 51668f91cd2SMatthias Springer rewriter.replaceOp(cloneOp, source); 51757470abcSAlexander Belyaev rewriter.eraseOp(redundantDealloc); 51857470abcSAlexander Belyaev return success(); 51957470abcSAlexander Belyaev } 52057470abcSAlexander Belyaev }; 52157470abcSAlexander Belyaev 522be0a7e9fSMehdi Amini } // namespace 52357470abcSAlexander Belyaev 5249f85c198SRiver Riddle void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results, 52557470abcSAlexander Belyaev MLIRContext *context) { 526b4e0507cSTres Popp results.add<SimplifyClones>(context); 52757470abcSAlexander Belyaev } 52857470abcSAlexander Belyaev 52957470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 53027a431f5SMatthias Springer // DeallocTensorOp 53127a431f5SMatthias Springer //===----------------------------------------------------------------------===// 53227a431f5SMatthias Springer 53327a431f5SMatthias Springer LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, 53427a431f5SMatthias Springer const BufferizationOptions &options) { 53527a431f5SMatthias Springer FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options); 53627a431f5SMatthias Springer if (failed(buffer)) 53727a431f5SMatthias Springer return failure(); 538caa2a4aeSMatthias Springer rewriter.create<memref::DeallocOp>(getLoc(), *buffer); 53927a431f5SMatthias Springer rewriter.eraseOp(getOperation()); 54027a431f5SMatthias Springer return success(); 54127a431f5SMatthias Springer } 54227a431f5SMatthias Springer 54327a431f5SMatthias Springer //===----------------------------------------------------------------------===// 54491464e1dSMatthias Springer // MaterializeInDestinationOp 54591464e1dSMatthias Springer //===----------------------------------------------------------------------===// 54691464e1dSMatthias Springer 54791464e1dSMatthias Springer bool MaterializeInDestinationOp::bufferizesToMemoryRead( 54891464e1dSMatthias Springer OpOperand &opOperand, const AnalysisState &state) { 54955585043SMatthias Springer return opOperand == getSourceMutable(); 55091464e1dSMatthias Springer } 55191464e1dSMatthias Springer 55291464e1dSMatthias Springer bool MaterializeInDestinationOp::bufferizesToMemoryWrite( 55391464e1dSMatthias Springer OpOperand &opOperand, const AnalysisState &state) { 55455585043SMatthias Springer if (opOperand == getDestMutable()) { 5550fcaca2fSMatthias Springer assert(isa<TensorType>(getDest().getType()) && "expected tensor type"); 5560fcaca2fSMatthias Springer return true; 5570fcaca2fSMatthias Springer } 5580fcaca2fSMatthias Springer return false; 55991464e1dSMatthias Springer } 56091464e1dSMatthias Springer 5618ee38f3bSMatthias Springer bool MaterializeInDestinationOp::mustBufferizeInPlace( 5628ee38f3bSMatthias Springer OpOperand &opOperand, const AnalysisState &state) { 5638ee38f3bSMatthias Springer // The source is only read and not written, so it always bufferizes in-place 5648ee38f3bSMatthias Springer // by default. The destination is written and is forced to bufferize in-place 5658ee38f3bSMatthias Springer // (if it is a tensor). 5668ee38f3bSMatthias Springer return true; 5678ee38f3bSMatthias Springer } 5688ee38f3bSMatthias Springer 56991464e1dSMatthias Springer AliasingValueList 57091464e1dSMatthias Springer MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand, 57191464e1dSMatthias Springer const AnalysisState &state) { 57255585043SMatthias Springer if (opOperand == getDestMutable()) { 5730fcaca2fSMatthias Springer assert(isa<TensorType>(getDest().getType()) && "expected tensor type"); 57491464e1dSMatthias Springer return {{getOperation()->getResult(0), BufferRelation::Equivalent}}; 5750fcaca2fSMatthias Springer } 57691464e1dSMatthias Springer return {}; 57791464e1dSMatthias Springer } 57891464e1dSMatthias Springer 57991464e1dSMatthias Springer LogicalResult 58091464e1dSMatthias Springer MaterializeInDestinationOp::bufferize(RewriterBase &rewriter, 58191464e1dSMatthias Springer const BufferizationOptions &options) { 5820fcaca2fSMatthias Springer bool tensorDest = isa<TensorType>(getDest().getType()); 5830fcaca2fSMatthias Springer Value buffer; 5840fcaca2fSMatthias Springer if (tensorDest) { 5850fcaca2fSMatthias Springer FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options); 5860fcaca2fSMatthias Springer if (failed(maybeBuffer)) 58791464e1dSMatthias Springer return failure(); 5880fcaca2fSMatthias Springer buffer = *maybeBuffer; 5890fcaca2fSMatthias Springer } else { 5900fcaca2fSMatthias Springer assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type"); 5910fcaca2fSMatthias Springer buffer = getDest(); 5920fcaca2fSMatthias Springer } 593437c6217SMatthias Springer auto srcBuffer = getBuffer(rewriter, getSource(), options); 594437c6217SMatthias Springer if (failed(srcBuffer)) 595437c6217SMatthias Springer return failure(); 596437c6217SMatthias Springer if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer))) 597437c6217SMatthias Springer return failure(); 5980fcaca2fSMatthias Springer replaceOpWithBufferizedValues(rewriter, getOperation(), 5990fcaca2fSMatthias Springer tensorDest ? ValueRange(buffer) : ValueRange()); 60091464e1dSMatthias Springer return success(); 60191464e1dSMatthias Springer } 60291464e1dSMatthias Springer 60364839fbdSMatthias Springer bool MaterializeInDestinationOp::bufferizesToElementwiseAccess( 60464839fbdSMatthias Springer const AnalysisState &state, ArrayRef<OpOperand *> opOperands) { 60564839fbdSMatthias Springer // As elements are copied from the "source" buffer to the "dest" buffer, 60664839fbdSMatthias Springer // already copied elements are not read a second time. 60764839fbdSMatthias Springer return true; 60864839fbdSMatthias Springer } 60964839fbdSMatthias Springer 61091464e1dSMatthias Springer LogicalResult MaterializeInDestinationOp::reifyResultShapes( 61191464e1dSMatthias Springer OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 6120fcaca2fSMatthias Springer if (getOperation()->getNumResults() == 1) { 6130fcaca2fSMatthias Springer assert(isa<TensorType>(getDest().getType()) && "expected tensor type"); 6140fcaca2fSMatthias Springer reifiedReturnShapes.resize(1, 6150fcaca2fSMatthias Springer SmallVector<OpFoldResult>(getType().getRank())); 6160fcaca2fSMatthias Springer reifiedReturnShapes[0] = 6170fcaca2fSMatthias Springer tensor::getMixedSizes(builder, getLoc(), getDest()); 6180fcaca2fSMatthias Springer } 61991464e1dSMatthias Springer return success(); 62091464e1dSMatthias Springer } 62191464e1dSMatthias Springer 62264839fbdSMatthias Springer Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder, 62364839fbdSMatthias Springer Location loc) { 6240fcaca2fSMatthias Springer if (isa<TensorType>(getDest().getType())) { 62564839fbdSMatthias Springer // The subset is the entire destination tensor. 62664839fbdSMatthias Springer return getDest(); 62764839fbdSMatthias Springer } 62864839fbdSMatthias Springer 6296d88ac11SMatthias Springer // The "restrict" attribute is transferred from this op to the newly created 6306d88ac11SMatthias Springer // to_tensor op. If this op does not the "restrict" attribute, the subset 6316d88ac11SMatthias Springer // extraction cannot be built because there is no guarantee that there is no 6326d88ac11SMatthias Springer // pre-existing "restrict" to_tensor op with the same/an aliasing destination. 6336d88ac11SMatthias Springer if (!getRestrict()) 6346d88ac11SMatthias Springer return {}; 6356d88ac11SMatthias Springer 6360fcaca2fSMatthias Springer // Build a bufferization.to_tensor op. 6370fcaca2fSMatthias Springer assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type"); 6380fcaca2fSMatthias Springer assert(getRestrict() && 6390fcaca2fSMatthias Springer "expected that ops with memrefs dest have 'restrict'"); 6406d88ac11SMatthias Springer setRestrict(false); 6416d88ac11SMatthias Springer return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true, 6420fcaca2fSMatthias Springer getWritable()); 6430fcaca2fSMatthias Springer } 6440fcaca2fSMatthias Springer 64564839fbdSMatthias Springer bool MaterializeInDestinationOp::isEquivalentSubset( 64664839fbdSMatthias Springer Value candidate, function_ref<bool(Value, Value)> equivalenceFn) { 64764839fbdSMatthias Springer return equivalenceFn(getDest(), candidate); 64864839fbdSMatthias Springer } 64964839fbdSMatthias Springer 65064839fbdSMatthias Springer SmallVector<Value> 65164839fbdSMatthias Springer MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() { 65264839fbdSMatthias Springer return {getDest()}; 65364839fbdSMatthias Springer } 65464839fbdSMatthias Springer 65564839fbdSMatthias Springer OpOperand &MaterializeInDestinationOp::getSourceOperand() { 65664839fbdSMatthias Springer return getOperation()->getOpOperand(0) /*source*/; 65764839fbdSMatthias Springer } 65864839fbdSMatthias Springer 6591abd8d1aSMatthias Springer bool MaterializeInDestinationOp::operatesOnEquivalentSubset( 6601abd8d1aSMatthias Springer SubsetOpInterface subsetOp, 6611abd8d1aSMatthias Springer function_ref<bool(Value, Value)> equivalenceFn) { 6621abd8d1aSMatthias Springer return false; 6631abd8d1aSMatthias Springer } 6641abd8d1aSMatthias Springer 6651abd8d1aSMatthias Springer bool MaterializeInDestinationOp::operatesOnDisjointSubset( 6661abd8d1aSMatthias Springer SubsetOpInterface subsetOp, 6671abd8d1aSMatthias Springer function_ref<bool(Value, Value)> equivalenceFn) { 6681abd8d1aSMatthias Springer return false; 6691abd8d1aSMatthias Springer } 6701abd8d1aSMatthias Springer 6710fcaca2fSMatthias Springer LogicalResult MaterializeInDestinationOp::verify() { 6720fcaca2fSMatthias Springer if (!isa<TensorType, BaseMemRefType>(getDest().getType())) 6730fcaca2fSMatthias Springer return emitOpError("'dest' must be a tensor or a memref"); 6740fcaca2fSMatthias Springer if (auto destType = dyn_cast<TensorType>(getDest().getType())) { 6750fcaca2fSMatthias Springer if (getOperation()->getNumResults() != 1) 6760fcaca2fSMatthias Springer return emitOpError("tensor 'dest' implies exactly one tensor result"); 6770fcaca2fSMatthias Springer if (destType != getResult().getType()) 6780fcaca2fSMatthias Springer return emitOpError("result and 'dest' types must match"); 6790fcaca2fSMatthias Springer } 6800fcaca2fSMatthias Springer if (isa<BaseMemRefType>(getDest().getType()) && 6810fcaca2fSMatthias Springer getOperation()->getNumResults() != 0) 6820fcaca2fSMatthias Springer return emitOpError("memref 'dest' implies zero results"); 6836d88ac11SMatthias Springer if (getRestrict() && !isa<BaseMemRefType>(getDest().getType())) 6846d88ac11SMatthias Springer return emitOpError("'restrict' is valid only for memref destinations"); 6850fcaca2fSMatthias Springer if (getWritable() != isa<BaseMemRefType>(getDest().getType())) 6860fcaca2fSMatthias Springer return emitOpError("'writable' must be specified if and only if the " 6870fcaca2fSMatthias Springer "destination is of memref type"); 6889d4b20a4SMatthias Springer TensorType srcType = getSource().getType(); 6899d4b20a4SMatthias Springer ShapedType destType = cast<ShapedType>(getDest().getType()); 6909d4b20a4SMatthias Springer if (srcType.hasRank() != destType.hasRank()) 6919d4b20a4SMatthias Springer return emitOpError("source/destination shapes are incompatible"); 6929d4b20a4SMatthias Springer if (srcType.hasRank()) { 6939d4b20a4SMatthias Springer if (srcType.getRank() != destType.getRank()) 6949d4b20a4SMatthias Springer return emitOpError("rank mismatch between source and destination shape"); 6959d4b20a4SMatthias Springer for (auto [src, dest] : 6969d4b20a4SMatthias Springer llvm::zip(srcType.getShape(), destType.getShape())) { 6979d4b20a4SMatthias Springer if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) { 6989d4b20a4SMatthias Springer // Cannot verify dynamic dimension size. Assume that that they match at 6999d4b20a4SMatthias Springer // runtime. 7009d4b20a4SMatthias Springer continue; 7019d4b20a4SMatthias Springer } 7029d4b20a4SMatthias Springer if (src != dest) 7039d4b20a4SMatthias Springer return emitOpError("source/destination shapes are incompatible"); 7049d4b20a4SMatthias Springer } 7059d4b20a4SMatthias Springer } 7060fcaca2fSMatthias Springer return success(); 7070fcaca2fSMatthias Springer } 7080fcaca2fSMatthias Springer 7090fcaca2fSMatthias Springer void MaterializeInDestinationOp::build(OpBuilder &builder, 7100fcaca2fSMatthias Springer OperationState &state, Value source, 7110fcaca2fSMatthias Springer Value dest) { 712437c6217SMatthias Springer auto destTensorType = dyn_cast<TensorType>(dest.getType()); 713437c6217SMatthias Springer build(builder, state, /*result=*/destTensorType ? destTensorType : Type(), 714437c6217SMatthias Springer source, dest); 7150fcaca2fSMatthias Springer } 7160fcaca2fSMatthias Springer 7170fcaca2fSMatthias Springer bool MaterializeInDestinationOp::isWritable(Value value, 7180fcaca2fSMatthias Springer const AnalysisState &state) { 7190fcaca2fSMatthias Springer return isa<TensorType>(getDest().getType()) ? true : getWritable(); 7200fcaca2fSMatthias Springer } 7210fcaca2fSMatthias Springer 7220fcaca2fSMatthias Springer MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() { 7230fcaca2fSMatthias Springer return getDestMutable(); 7240fcaca2fSMatthias Springer } 7250fcaca2fSMatthias Springer 7260fcaca2fSMatthias Springer void MaterializeInDestinationOp::getEffects( 7270fcaca2fSMatthias Springer SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 7280fcaca2fSMatthias Springer &effects) { 7290fcaca2fSMatthias Springer if (isa<BaseMemRefType>(getDest().getType())) 7302c1ae801Sdonald chen effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(), 7310fcaca2fSMatthias Springer SideEffects::DefaultResource::get()); 7320fcaca2fSMatthias Springer } 7330fcaca2fSMatthias Springer 73491464e1dSMatthias Springer //===----------------------------------------------------------------------===// 73557470abcSAlexander Belyaev // ToTensorOp 73657470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 73757470abcSAlexander Belyaev 7388f7e7400SMatthias Springer bool ToTensorOp::isWritable(Value value, const AnalysisState &state) { 7398f7e7400SMatthias Springer return getWritable(); 7408f7e7400SMatthias Springer } 7418f7e7400SMatthias Springer 7427df76121SMarkus Böck OpFoldResult ToTensorOp::fold(FoldAdaptor) { 74399260e95SMatthias Springer if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>()) 74457470abcSAlexander Belyaev // Approximate alias analysis by conservatively folding only when no there 74557470abcSAlexander Belyaev // is no interleaved operation. 74657470abcSAlexander Belyaev if (toMemref->getBlock() == this->getOperation()->getBlock() && 74757470abcSAlexander Belyaev toMemref->getNextNode() == this->getOperation()) 74899260e95SMatthias Springer return toMemref.getTensor(); 74957470abcSAlexander Belyaev return {}; 75057470abcSAlexander Belyaev } 75157470abcSAlexander Belyaev 75257470abcSAlexander Belyaev namespace { 75357470abcSAlexander Belyaev struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> { 75457470abcSAlexander Belyaev using OpRewritePattern<tensor::DimOp>::OpRewritePattern; 75557470abcSAlexander Belyaev 75657470abcSAlexander Belyaev LogicalResult matchAndRewrite(tensor::DimOp dimOp, 75757470abcSAlexander Belyaev PatternRewriter &rewriter) const override { 75804235d07SJacques Pienaar auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>(); 75957470abcSAlexander Belyaev if (!memrefToTensorOp) 76057470abcSAlexander Belyaev return failure(); 76157470abcSAlexander Belyaev 76299260e95SMatthias Springer rewriter.replaceOpWithNewOp<memref::DimOp>( 76304235d07SJacques Pienaar dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex()); 76457470abcSAlexander Belyaev return success(); 76557470abcSAlexander Belyaev } 76657470abcSAlexander Belyaev }; 76757470abcSAlexander Belyaev } // namespace 76857470abcSAlexander Belyaev 76957470abcSAlexander Belyaev void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, 77057470abcSAlexander Belyaev MLIRContext *context) { 771fc9b37ddSMatthias Springer results.add<DimOfToTensorFolder>(context); 77257470abcSAlexander Belyaev } 77357470abcSAlexander Belyaev 77457470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 77557470abcSAlexander Belyaev // ToMemrefOp 77657470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 77757470abcSAlexander Belyaev 7787df76121SMarkus Böck OpFoldResult ToMemrefOp::fold(FoldAdaptor) { 77999260e95SMatthias Springer if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>()) 78099260e95SMatthias Springer if (memrefToTensor.getMemref().getType() == getType()) 78199260e95SMatthias Springer return memrefToTensor.getMemref(); 78257470abcSAlexander Belyaev return {}; 78357470abcSAlexander Belyaev } 78457470abcSAlexander Belyaev 78557470abcSAlexander Belyaev namespace { 78657470abcSAlexander Belyaev 78757470abcSAlexander Belyaev /// Replace tensor.cast + to_memref by to_memref + memref.cast. 78857470abcSAlexander Belyaev struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> { 78957470abcSAlexander Belyaev using OpRewritePattern<ToMemrefOp>::OpRewritePattern; 79057470abcSAlexander Belyaev 79157470abcSAlexander Belyaev LogicalResult matchAndRewrite(ToMemrefOp toMemref, 79257470abcSAlexander Belyaev PatternRewriter &rewriter) const final { 79357470abcSAlexander Belyaev auto tensorCastOperand = 79457470abcSAlexander Belyaev toMemref.getOperand().getDefiningOp<tensor::CastOp>(); 79557470abcSAlexander Belyaev if (!tensorCastOperand) 79657470abcSAlexander Belyaev return failure(); 797c1fa60b4STres Popp auto srcTensorType = llvm::dyn_cast<RankedTensorType>( 798c1fa60b4STres Popp tensorCastOperand.getOperand().getType()); 79957470abcSAlexander Belyaev if (!srcTensorType) 80057470abcSAlexander Belyaev return failure(); 80157470abcSAlexander Belyaev auto memrefType = MemRefType::get(srcTensorType.getShape(), 80257470abcSAlexander Belyaev srcTensorType.getElementType()); 80357470abcSAlexander Belyaev Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType, 80457470abcSAlexander Belyaev tensorCastOperand.getOperand()); 80557470abcSAlexander Belyaev rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(), 80657470abcSAlexander Belyaev memref); 80757470abcSAlexander Belyaev return success(); 80857470abcSAlexander Belyaev } 80957470abcSAlexander Belyaev }; 81057470abcSAlexander Belyaev 811cb471241SMatthias Springer /// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a 812cb471241SMatthias Springer /// cast if necessary. 813cb471241SMatthias Springer struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> { 814b00ee46bSMatthias Springer using OpRewritePattern<ToMemrefOp>::OpRewritePattern; 815b00ee46bSMatthias Springer 816b00ee46bSMatthias Springer LogicalResult matchAndRewrite(ToMemrefOp toMemref, 817b00ee46bSMatthias Springer PatternRewriter &rewriter) const final { 818c515c780SMatthias Gehre BufferizationOptions options; 819c515c780SMatthias Gehre options.bufferAlignment = 0; 820c515c780SMatthias Gehre return foldToMemrefToTensorPair(rewriter, toMemref, options); 821b00ee46bSMatthias Springer } 82257470abcSAlexander Belyaev }; 82357470abcSAlexander Belyaev 82457470abcSAlexander Belyaev /// Fold a load on a to_memref operation into an tensor.extract on the 82557470abcSAlexander Belyaev /// corresponding tensor. 82657470abcSAlexander Belyaev struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> { 82757470abcSAlexander Belyaev using OpRewritePattern<memref::LoadOp>::OpRewritePattern; 82857470abcSAlexander Belyaev 82957470abcSAlexander Belyaev LogicalResult matchAndRewrite(memref::LoadOp load, 83057470abcSAlexander Belyaev PatternRewriter &rewriter) const override { 831136d746eSJacques Pienaar auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>(); 83257470abcSAlexander Belyaev if (!toMemref) 83357470abcSAlexander Belyaev return failure(); 83457470abcSAlexander Belyaev 83599260e95SMatthias Springer rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(), 836136d746eSJacques Pienaar load.getIndices()); 83757470abcSAlexander Belyaev return success(); 83857470abcSAlexander Belyaev } 83957470abcSAlexander Belyaev }; 84057470abcSAlexander Belyaev 84157470abcSAlexander Belyaev /// Fold dim of a to_memref into the dim of the tensor. 84257470abcSAlexander Belyaev struct DimOfCastOp : public OpRewritePattern<memref::DimOp> { 84357470abcSAlexander Belyaev using OpRewritePattern<memref::DimOp>::OpRewritePattern; 84457470abcSAlexander Belyaev 84557470abcSAlexander Belyaev LogicalResult matchAndRewrite(memref::DimOp dimOp, 84657470abcSAlexander Belyaev PatternRewriter &rewriter) const override { 847136d746eSJacques Pienaar auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>(); 84857470abcSAlexander Belyaev if (!castOp) 84957470abcSAlexander Belyaev return failure(); 85057470abcSAlexander Belyaev Value newSource = castOp.getOperand(); 851136d746eSJacques Pienaar rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, 852136d746eSJacques Pienaar dimOp.getIndex()); 85357470abcSAlexander Belyaev return success(); 85457470abcSAlexander Belyaev } 85557470abcSAlexander Belyaev }; 85657470abcSAlexander Belyaev 85757470abcSAlexander Belyaev } // namespace 85857470abcSAlexander Belyaev 85957470abcSAlexander Belyaev void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results, 86057470abcSAlexander Belyaev MLIRContext *context) { 861cb471241SMatthias Springer results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, 862cb471241SMatthias Springer ToMemrefToTensorFolding>(context); 86357470abcSAlexander Belyaev } 86457470abcSAlexander Belyaev 865b00ee46bSMatthias Springer LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter, 866b55d55ecSMatthias Springer const BufferizationOptions &options) { 867b00ee46bSMatthias Springer // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary. 868c515c780SMatthias Gehre (void)foldToMemrefToTensorPair(rewriter, *this, options); 8690b293bf0SMatthias Springer // Note: The return value of `bufferize` indicates whether there was an error 8700b293bf0SMatthias Springer // or not. (And not whether the pattern matched or not.) 8710b293bf0SMatthias Springer return success(); 872b00ee46bSMatthias Springer } 873b00ee46bSMatthias Springer 874e8bcc37fSRamkumar Ramachandra std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, 875e8bcc37fSRamkumar Ramachandra Value alloc) { 87657470abcSAlexander Belyaev return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc) 87757470abcSAlexander Belyaev .getOperation(); 87857470abcSAlexander Belyaev } 87957470abcSAlexander Belyaev 880e8bcc37fSRamkumar Ramachandra std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) { 88157470abcSAlexander Belyaev return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult(); 88257470abcSAlexander Belyaev } 88357470abcSAlexander Belyaev 88457470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 885d5825621SMartin Erhart // DeallocOp 886d5825621SMartin Erhart //===----------------------------------------------------------------------===// 887d5825621SMartin Erhart 888d5825621SMartin Erhart LogicalResult DeallocOp::inferReturnTypes( 889d5825621SMartin Erhart MLIRContext *context, std::optional<::mlir::Location> location, 890d5825621SMartin Erhart ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, 891d5825621SMartin Erhart RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) { 892d5825621SMartin Erhart DeallocOpAdaptor adaptor(operands, attributes, properties, regions); 8934bde084fSMartin Erhart inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(), 8944bde084fSMartin Erhart IntegerType::get(context, 1)); 895d5825621SMartin Erhart return success(); 896d5825621SMartin Erhart } 897d5825621SMartin Erhart 898d5825621SMartin Erhart LogicalResult DeallocOp::verify() { 899d5825621SMartin Erhart if (getMemrefs().size() != getConditions().size()) 900d5825621SMartin Erhart return emitOpError( 901d5825621SMartin Erhart "must have the same number of conditions as memrefs to deallocate"); 9020ef990d5SMatthias Springer if (getRetained().size() != getUpdatedConditions().size()) 9030ef990d5SMatthias Springer return emitOpError("must have the same number of updated conditions " 9040ef990d5SMatthias Springer "(results) as retained operands"); 905d5825621SMartin Erhart return success(); 906d5825621SMartin Erhart } 907d5825621SMartin Erhart 9084bde084fSMartin Erhart static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, 9095c7d97beSMartin Erhart ValueRange memrefs, 9105c7d97beSMartin Erhart ValueRange conditions, 9114bde084fSMartin Erhart PatternRewriter &rewriter) { 9125c7d97beSMartin Erhart if (deallocOp.getMemrefs() == memrefs && 9135c7d97beSMartin Erhart deallocOp.getConditions() == conditions) 9144bde084fSMartin Erhart return failure(); 9154bde084fSMartin Erhart 9165fcf907bSMatthias Springer rewriter.modifyOpInPlace(deallocOp, [&]() { 9174bde084fSMartin Erhart deallocOp.getMemrefsMutable().assign(memrefs); 9184bde084fSMartin Erhart deallocOp.getConditionsMutable().assign(conditions); 9194bde084fSMartin Erhart }); 9204bde084fSMartin Erhart return success(); 9214bde084fSMartin Erhart } 9224bde084fSMartin Erhart 92317aaa651SMartin Erhart namespace { 92417aaa651SMartin Erhart 9254bde084fSMartin Erhart /// Remove duplicate values in the list of memrefs to be deallocated. We need to 9264bde084fSMartin Erhart /// make sure the corresponding condition value is updated accordingly since 9274bde084fSMartin Erhart /// their two conditions might not cover the same set of cases. In that case, we 9284bde084fSMartin Erhart /// have to combine them (by computing the disjunction of them). 92917aaa651SMartin Erhart /// Example: 93017aaa651SMartin Erhart /// ```mlir 9314bde084fSMartin Erhart /// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2) 93217aaa651SMartin Erhart /// ``` 93317aaa651SMartin Erhart /// is canonicalized to 93417aaa651SMartin Erhart /// ```mlir 93517aaa651SMartin Erhart /// %0 = arith.ori %arg1, %arg2 : i1 9364bde084fSMartin Erhart /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0) 93717aaa651SMartin Erhart /// ``` 9384bde084fSMartin Erhart struct DeallocRemoveDuplicateDeallocMemrefs 9394bde084fSMartin Erhart : public OpRewritePattern<DeallocOp> { 94017aaa651SMartin Erhart using OpRewritePattern<DeallocOp>::OpRewritePattern; 94117aaa651SMartin Erhart 94217aaa651SMartin Erhart LogicalResult matchAndRewrite(DeallocOp deallocOp, 94317aaa651SMartin Erhart PatternRewriter &rewriter) const override { 94417aaa651SMartin Erhart // Unique memrefs to be deallocated. 94517aaa651SMartin Erhart DenseMap<Value, unsigned> memrefToCondition; 9464bde084fSMartin Erhart SmallVector<Value> newMemrefs, newConditions; 947b0688ed0SMartin Erhart for (auto [i, memref, cond] : 948b0688ed0SMartin Erhart llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) { 94917aaa651SMartin Erhart if (memrefToCondition.count(memref)) { 95017aaa651SMartin Erhart // If the dealloc conditions don't match, we need to make sure that the 95117aaa651SMartin Erhart // dealloc happens on the union of cases. 95217aaa651SMartin Erhart Value &newCond = newConditions[memrefToCondition[memref]]; 95317aaa651SMartin Erhart if (newCond != cond) 95417aaa651SMartin Erhart newCond = 95517aaa651SMartin Erhart rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond); 95617aaa651SMartin Erhart } else { 95717aaa651SMartin Erhart memrefToCondition.insert({memref, newConditions.size()}); 95817aaa651SMartin Erhart newMemrefs.push_back(memref); 95917aaa651SMartin Erhart newConditions.push_back(cond); 96017aaa651SMartin Erhart } 96117aaa651SMartin Erhart } 96217aaa651SMartin Erhart 96317aaa651SMartin Erhart // Return failure if we don't change anything such that we don't run into an 96417aaa651SMartin Erhart // infinite loop of pattern applications. 9654bde084fSMartin Erhart return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, 9664bde084fSMartin Erhart rewriter); 9674bde084fSMartin Erhart } 9684bde084fSMartin Erhart }; 9694bde084fSMartin Erhart 9704bde084fSMartin Erhart /// Remove duplicate values in the list of retained memrefs. We need to make 9714bde084fSMartin Erhart /// sure the corresponding result condition value is replaced properly. 9724bde084fSMartin Erhart /// Example: 9734bde084fSMartin Erhart /// ```mlir 9744bde084fSMartin Erhart /// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...) 9754bde084fSMartin Erhart /// ``` 9764bde084fSMartin Erhart /// is canonicalized to 9774bde084fSMartin Erhart /// ```mlir 9784bde084fSMartin Erhart /// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>) 9794bde084fSMartin Erhart /// ``` 9804bde084fSMartin Erhart struct DeallocRemoveDuplicateRetainedMemrefs 9814bde084fSMartin Erhart : public OpRewritePattern<DeallocOp> { 9824bde084fSMartin Erhart using OpRewritePattern<DeallocOp>::OpRewritePattern; 9834bde084fSMartin Erhart 9844bde084fSMartin Erhart LogicalResult matchAndRewrite(DeallocOp deallocOp, 9854bde084fSMartin Erhart PatternRewriter &rewriter) const override { 9864bde084fSMartin Erhart // Unique retained values 9874bde084fSMartin Erhart DenseMap<Value, unsigned> seen; 9884bde084fSMartin Erhart SmallVector<Value> newRetained; 9894bde084fSMartin Erhart SmallVector<unsigned> resultReplacementIdx; 9904bde084fSMartin Erhart unsigned i = 0; 9914bde084fSMartin Erhart for (auto retained : deallocOp.getRetained()) { 9924bde084fSMartin Erhart if (seen.count(retained)) { 9934bde084fSMartin Erhart resultReplacementIdx.push_back(seen[retained]); 9944bde084fSMartin Erhart continue; 9954bde084fSMartin Erhart } 9964bde084fSMartin Erhart 9974bde084fSMartin Erhart seen[retained] = i; 9984bde084fSMartin Erhart newRetained.push_back(retained); 9994bde084fSMartin Erhart resultReplacementIdx.push_back(i++); 10004bde084fSMartin Erhart } 10014bde084fSMartin Erhart 10024bde084fSMartin Erhart // Return failure if we don't change anything such that we don't run into an 10034bde084fSMartin Erhart // infinite loop of pattern applications. 10044bde084fSMartin Erhart if (newRetained.size() == deallocOp.getRetained().size()) 100517aaa651SMartin Erhart return failure(); 100617aaa651SMartin Erhart 100717aaa651SMartin Erhart // We need to create a new op because the number of results is always the 100817aaa651SMartin Erhart // same as the number of condition operands. 10094bde084fSMartin Erhart auto newDeallocOp = 10104bde084fSMartin Erhart rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(), 10114bde084fSMartin Erhart deallocOp.getConditions(), newRetained); 10124bde084fSMartin Erhart SmallVector<Value> replacements( 10134bde084fSMartin Erhart llvm::map_range(resultReplacementIdx, [&](unsigned idx) { 10144bde084fSMartin Erhart return newDeallocOp.getUpdatedConditions()[idx]; 10154bde084fSMartin Erhart })); 10164bde084fSMartin Erhart rewriter.replaceOp(deallocOp, replacements); 101717aaa651SMartin Erhart return success(); 101817aaa651SMartin Erhart } 101917aaa651SMartin Erhart }; 102017aaa651SMartin Erhart 10214bde084fSMartin Erhart /// Erase deallocation operations where the variadic list of memrefs to 10224bde084fSMartin Erhart /// deallocate is empty. Example: 10234bde084fSMartin Erhart /// ```mlir 10244bde084fSMartin Erhart /// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>) 1025b0688ed0SMartin Erhart /// ``` 1026b0688ed0SMartin Erhart struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> { 1027b0688ed0SMartin Erhart using OpRewritePattern<DeallocOp>::OpRewritePattern; 1028b0688ed0SMartin Erhart 1029b0688ed0SMartin Erhart LogicalResult matchAndRewrite(DeallocOp deallocOp, 1030b0688ed0SMartin Erhart PatternRewriter &rewriter) const override { 1031b0688ed0SMartin Erhart if (deallocOp.getMemrefs().empty()) { 10324bde084fSMartin Erhart Value constFalse = rewriter.create<arith::ConstantOp>( 10334bde084fSMartin Erhart deallocOp.getLoc(), rewriter.getBoolAttr(false)); 10344bde084fSMartin Erhart rewriter.replaceOp( 10354bde084fSMartin Erhart deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(), 10364bde084fSMartin Erhart constFalse)); 1037b0688ed0SMartin Erhart return success(); 1038b0688ed0SMartin Erhart } 1039b0688ed0SMartin Erhart return failure(); 1040b0688ed0SMartin Erhart } 1041b0688ed0SMartin Erhart }; 1042b0688ed0SMartin Erhart 1043d26eb822SMartin Erhart /// Removes memrefs from the deallocation list if their associated condition is 1044d26eb822SMartin Erhart /// always 'false'. 1045d26eb822SMartin Erhart /// 1046d26eb822SMartin Erhart /// Example: 1047d26eb822SMartin Erhart /// ``` 10484bde084fSMartin Erhart /// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) 1049d26eb822SMartin Erhart /// if (%arg2, %false) 1050d26eb822SMartin Erhart /// ``` 1051d26eb822SMartin Erhart /// becomes 1052d26eb822SMartin Erhart /// ``` 10534bde084fSMartin Erhart /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2) 1054d26eb822SMartin Erhart /// ``` 1055d26eb822SMartin Erhart struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> { 1056d26eb822SMartin Erhart using OpRewritePattern<DeallocOp>::OpRewritePattern; 1057d26eb822SMartin Erhart 1058d26eb822SMartin Erhart LogicalResult matchAndRewrite(DeallocOp deallocOp, 1059d26eb822SMartin Erhart PatternRewriter &rewriter) const override { 1060d26eb822SMartin Erhart SmallVector<Value> newMemrefs, newConditions; 10614bde084fSMartin Erhart for (auto [memref, cond] : 10624bde084fSMartin Erhart llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { 10634bde084fSMartin Erhart if (!matchPattern(cond, m_Zero())) { 1064d26eb822SMartin Erhart newMemrefs.push_back(memref); 1065d26eb822SMartin Erhart newConditions.push_back(cond); 10664bde084fSMartin Erhart } 1067d26eb822SMartin Erhart } 1068d26eb822SMartin Erhart 10694bde084fSMartin Erhart return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, 10704bde084fSMartin Erhart rewriter); 1071d26eb822SMartin Erhart } 1072d26eb822SMartin Erhart }; 1073d26eb822SMartin Erhart 10745c7d97beSMartin Erhart /// The `memref.extract_strided_metadata` is often inserted to get the base 10755c7d97beSMartin Erhart /// memref if the operand is not already guaranteed to be the result of a memref 10765c7d97beSMartin Erhart /// allocation operation. This canonicalization pattern removes this extraction 10775c7d97beSMartin Erhart /// operation if the operand is now produced by an allocation operation (e.g., 10785c7d97beSMartin Erhart /// due to other canonicalizations simplifying the IR). 10795c7d97beSMartin Erhart /// 10805c7d97beSMartin Erhart /// Example: 10815c7d97beSMartin Erhart /// ```mlir 10825c7d97beSMartin Erhart /// %alloc = memref.alloc() : memref<2xi32> 10835c7d97beSMartin Erhart /// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata 10845c7d97beSMartin Erhart /// %alloc : memref<2xi32> -> memref<i32>, index, index, index 10855c7d97beSMartin Erhart /// bufferization.dealloc (%base_memref : memref<i32>) if (%cond) 10865c7d97beSMartin Erhart /// ``` 10875c7d97beSMartin Erhart /// is canonicalized to 10885c7d97beSMartin Erhart /// ```mlir 10895c7d97beSMartin Erhart /// %alloc = memref.alloc() : memref<2xi32> 10905c7d97beSMartin Erhart /// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond) 10915c7d97beSMartin Erhart /// ``` 10925c7d97beSMartin Erhart struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> { 10935c7d97beSMartin Erhart using OpRewritePattern<DeallocOp>::OpRewritePattern; 10945c7d97beSMartin Erhart 10955c7d97beSMartin Erhart LogicalResult matchAndRewrite(DeallocOp deallocOp, 10965c7d97beSMartin Erhart PatternRewriter &rewriter) const override { 10975c7d97beSMartin Erhart SmallVector<Value> newMemrefs( 10985c7d97beSMartin Erhart llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) { 10995c7d97beSMartin Erhart auto extractStridedOp = 11005c7d97beSMartin Erhart memref.getDefiningOp<memref::ExtractStridedMetadataOp>(); 11015c7d97beSMartin Erhart if (!extractStridedOp) 11025c7d97beSMartin Erhart return memref; 11035c7d97beSMartin Erhart Value allocMemref = extractStridedOp.getOperand(); 11045c7d97beSMartin Erhart auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>(); 11055c7d97beSMartin Erhart if (!allocOp) 11065c7d97beSMartin Erhart return memref; 11075c7d97beSMartin Erhart if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref)) 11085c7d97beSMartin Erhart return allocMemref; 11095c7d97beSMartin Erhart return memref; 11105c7d97beSMartin Erhart })); 11115c7d97beSMartin Erhart 11125c7d97beSMartin Erhart return updateDeallocIfChanged(deallocOp, newMemrefs, 11135c7d97beSMartin Erhart deallocOp.getConditions(), rewriter); 11145c7d97beSMartin Erhart } 11155c7d97beSMartin Erhart }; 11165c7d97beSMartin Erhart 1117778494aeSMartin Erhart /// Removes pairs of `bufferization.dealloc` and alloc operations if there is no 1118778494aeSMartin Erhart /// other user of the allocated value and the allocating operation can be safely 1119778494aeSMartin Erhart /// removed. If the same value is present multiple times, this pattern relies on 1120778494aeSMartin Erhart /// other canonicalization patterns to remove the duplicate first. 1121778494aeSMartin Erhart /// 1122778494aeSMartin Erhart /// Example: 1123778494aeSMartin Erhart /// ```mlir 1124778494aeSMartin Erhart /// %alloc = memref.alloc() : memref<2xi32> 1125778494aeSMartin Erhart /// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true) 1126778494aeSMartin Erhart /// ``` 1127778494aeSMartin Erhart /// is canonicalized to 1128778494aeSMartin Erhart /// ```mlir 1129778494aeSMartin Erhart /// bufferization.dealloc (%arg0 : ...) if (%true) 1130778494aeSMartin Erhart /// ``` 1131778494aeSMartin Erhart struct RemoveAllocDeallocPairWhenNoOtherUsers 1132778494aeSMartin Erhart : public OpRewritePattern<DeallocOp> { 1133778494aeSMartin Erhart using OpRewritePattern<DeallocOp>::OpRewritePattern; 1134778494aeSMartin Erhart 1135778494aeSMartin Erhart LogicalResult matchAndRewrite(DeallocOp deallocOp, 1136778494aeSMartin Erhart PatternRewriter &rewriter) const override { 1137778494aeSMartin Erhart SmallVector<Value> newMemrefs, newConditions; 1138778494aeSMartin Erhart SmallVector<Operation *> toDelete; 1139778494aeSMartin Erhart for (auto [memref, cond] : 1140778494aeSMartin Erhart llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { 1141778494aeSMartin Erhart if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) { 1142778494aeSMartin Erhart // Check that it is indeed an allocate effect, that the op has no other 1143778494aeSMartin Erhart // side effects (which would not allow us to remove the op), and that 1144778494aeSMartin Erhart // there are no other users. 1145778494aeSMartin Erhart if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) && 1146778494aeSMartin Erhart hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) && 1147778494aeSMartin Erhart memref.hasOneUse()) { 1148778494aeSMartin Erhart toDelete.push_back(allocOp); 1149778494aeSMartin Erhart continue; 1150778494aeSMartin Erhart } 1151778494aeSMartin Erhart } 1152778494aeSMartin Erhart 1153778494aeSMartin Erhart newMemrefs.push_back(memref); 1154778494aeSMartin Erhart newConditions.push_back(cond); 1155778494aeSMartin Erhart } 1156778494aeSMartin Erhart 1157778494aeSMartin Erhart if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, 1158778494aeSMartin Erhart rewriter))) 1159778494aeSMartin Erhart return failure(); 1160778494aeSMartin Erhart 1161778494aeSMartin Erhart for (Operation *op : toDelete) 1162778494aeSMartin Erhart rewriter.eraseOp(op); 1163778494aeSMartin Erhart 1164778494aeSMartin Erhart return success(); 1165778494aeSMartin Erhart } 1166778494aeSMartin Erhart }; 1167778494aeSMartin Erhart 116817aaa651SMartin Erhart } // anonymous namespace 116917aaa651SMartin Erhart 117017aaa651SMartin Erhart void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results, 117117aaa651SMartin Erhart MLIRContext *context) { 1172fff18305SMartin Erhart populateDeallocOpCanonicalizationPatterns(results, context); 1173fff18305SMartin Erhart } 1174fff18305SMartin Erhart 1175fff18305SMartin Erhart void bufferization::populateDeallocOpCanonicalizationPatterns( 1176fff18305SMartin Erhart RewritePatternSet &patterns, MLIRContext *context) { 1177fff18305SMartin Erhart patterns.add<DeallocRemoveDuplicateDeallocMemrefs, 117887f2dee4SMartin Erhart DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc, 1179778494aeSMartin Erhart EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc, 1180778494aeSMartin Erhart RemoveAllocDeallocPairWhenNoOtherUsers>(context); 118117aaa651SMartin Erhart } 118217aaa651SMartin Erhart 1183d5825621SMartin Erhart //===----------------------------------------------------------------------===// 118457470abcSAlexander Belyaev // TableGen'd op method definitions 118557470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 118657470abcSAlexander Belyaev 118757470abcSAlexander Belyaev #define GET_OP_CLASSES 118857470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc" 1189