xref: /llvm-project/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
157470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
257470abcSAlexander Belyaev //
357470abcSAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
457470abcSAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
557470abcSAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
657470abcSAlexander Belyaev //
757470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
857470abcSAlexander Belyaev 
9abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
10ffdbecccSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1157470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
129a3d60e0SAart Bik #include "mlir/Dialect/Func/IR/FuncOps.h"
13eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
149a3d60e0SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15eda6f907SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h"
16ec55f0bdSMatthias Springer #include "mlir/IR/Matchers.h"
17a1fe1f5fSKazu Hirata #include <optional>
1857470abcSAlexander Belyaev 
1957470abcSAlexander Belyaev using namespace mlir;
2057470abcSAlexander Belyaev using namespace mlir::bufferization;
2157470abcSAlexander Belyaev 
2257470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
23fa7c8cb4SMatthias Springer // Helper functions
24fa7c8cb4SMatthias Springer //===----------------------------------------------------------------------===//
25fa7c8cb4SMatthias Springer 
26c515c780SMatthias Gehre FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
27c515c780SMatthias Gehre     OpBuilder &b, Value value, MemRefType destType,
28c515c780SMatthias Gehre     const BufferizationOptions &options) {
29c1fa60b4STres Popp   auto srcType = llvm::cast<MemRefType>(value.getType());
30fa7c8cb4SMatthias Springer 
31fa7c8cb4SMatthias Springer   // Element type, rank and memory space must match.
32fa7c8cb4SMatthias Springer   if (srcType.getElementType() != destType.getElementType())
33fa7c8cb4SMatthias Springer     return failure();
345d04f0c9SMatthias Springer   if (srcType.getMemorySpace() != destType.getMemorySpace())
35fa7c8cb4SMatthias Springer     return failure();
36fa7c8cb4SMatthias Springer   if (srcType.getRank() != destType.getRank())
37fa7c8cb4SMatthias Springer     return failure();
38fa7c8cb4SMatthias Springer 
39fa7c8cb4SMatthias Springer   // In case the affine maps are different, we may need to use a copy if we go
40fa7c8cb4SMatthias Springer   // from dynamic to static offset or stride (the canonicalization cannot know
41fa7c8cb4SMatthias Springer   // at this point that it is really cast compatible).
42fa7c8cb4SMatthias Springer   auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
43fa7c8cb4SMatthias Springer     int64_t sourceOffset, targetOffset;
44fa7c8cb4SMatthias Springer     SmallVector<int64_t, 4> sourceStrides, targetStrides;
45*6aaa8f25SMatthias Springer     if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
46*6aaa8f25SMatthias Springer         failed(target.getStridesAndOffset(targetStrides, targetOffset)))
47fa7c8cb4SMatthias Springer       return false;
48fa7c8cb4SMatthias Springer     auto dynamicToStatic = [](int64_t a, int64_t b) {
49399638f9SAliia Khasanova       return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
50fa7c8cb4SMatthias Springer     };
51fa7c8cb4SMatthias Springer     if (dynamicToStatic(sourceOffset, targetOffset))
52fa7c8cb4SMatthias Springer       return false;
53fa7c8cb4SMatthias Springer     for (auto it : zip(sourceStrides, targetStrides))
54fa7c8cb4SMatthias Springer       if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
55fa7c8cb4SMatthias Springer         return false;
56fa7c8cb4SMatthias Springer     return true;
57fa7c8cb4SMatthias Springer   };
58fa7c8cb4SMatthias Springer 
59fa7c8cb4SMatthias Springer   // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
60fa7c8cb4SMatthias Springer   // ensure that we only generate casts that always succeed at runtime, we check
61fa7c8cb4SMatthias Springer   // a fix extra conditions in `isGuaranteedCastCompatible`.
62fa7c8cb4SMatthias Springer   if (memref::CastOp::areCastCompatible(srcType, destType) &&
63fa7c8cb4SMatthias Springer       isGuaranteedCastCompatible(srcType, destType)) {
64fa7c8cb4SMatthias Springer     Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
65fa7c8cb4SMatthias Springer     return casted;
66fa7c8cb4SMatthias Springer   }
67fa7c8cb4SMatthias Springer 
68fa7c8cb4SMatthias Springer   auto loc = value.getLoc();
69fa7c8cb4SMatthias Springer   SmallVector<Value, 4> dynamicOperands;
70fa7c8cb4SMatthias Springer   for (int i = 0; i < destType.getRank(); ++i) {
71399638f9SAliia Khasanova     if (destType.getShape()[i] != ShapedType::kDynamic)
72fa7c8cb4SMatthias Springer       continue;
73b23c8225SMatthias Springer     Value size = b.create<memref::DimOp>(loc, value, i);
74fa7c8cb4SMatthias Springer     dynamicOperands.push_back(size);
75fa7c8cb4SMatthias Springer   }
76c515c780SMatthias Gehre 
77c515c780SMatthias Gehre   FailureOr<Value> copy =
78c515c780SMatthias Gehre       options.createAlloc(b, loc, destType, dynamicOperands);
79c515c780SMatthias Gehre   if (failed(copy))
80c515c780SMatthias Gehre     return failure();
81c515c780SMatthias Gehre   if (failed(options.createMemCpy(b, loc, value, *copy)))
82c515c780SMatthias Gehre     return failure();
83fa7c8cb4SMatthias Springer   return copy;
84fa7c8cb4SMatthias Springer }
85fa7c8cb4SMatthias Springer 
86d820acddSMatthias Springer /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
87d820acddSMatthias Springer /// to_memref op are different, a memref.cast is needed.
88c515c780SMatthias Gehre LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
89c515c780SMatthias Gehre     RewriterBase &rewriter, ToMemrefOp toMemref,
90c515c780SMatthias Gehre     const BufferizationOptions &options) {
9199260e95SMatthias Springer   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
92d820acddSMatthias Springer   if (!memrefToTensor)
93d820acddSMatthias Springer     return failure();
94d820acddSMatthias Springer 
9599260e95SMatthias Springer   Type srcType = memrefToTensor.getMemref().getType();
96d820acddSMatthias Springer   Type destType = toMemref.getType();
97d820acddSMatthias Springer 
98d820acddSMatthias Springer   // Directly rewrite if the type did not change.
99d820acddSMatthias Springer   if (srcType == destType) {
10099260e95SMatthias Springer     rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
101d820acddSMatthias Springer     return success();
102d820acddSMatthias Springer   }
103d820acddSMatthias Springer 
104c1fa60b4STres Popp   auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
105c1fa60b4STres Popp   auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
106c1fa60b4STres Popp   auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
107d820acddSMatthias Springer 
108d820acddSMatthias Springer   // Ranked memref -> Ranked memref cast.
109d820acddSMatthias Springer   if (rankedSrcType && rankedDestType) {
110d820acddSMatthias Springer     FailureOr<Value> replacement = castOrReallocMemRefValue(
111c515c780SMatthias Gehre         rewriter, memrefToTensor.getMemref(), rankedDestType, options);
112d820acddSMatthias Springer     if (failed(replacement))
113d820acddSMatthias Springer       return failure();
114d820acddSMatthias Springer 
115d820acddSMatthias Springer     rewriter.replaceOp(toMemref, *replacement);
116d820acddSMatthias Springer     return success();
117d820acddSMatthias Springer   }
118d820acddSMatthias Springer 
119d820acddSMatthias Springer   // Unranked memref -> Ranked memref cast: May require a copy.
120d820acddSMatthias Springer   // TODO: Not implemented at the moment.
121d820acddSMatthias Springer   if (unrankedSrcType && rankedDestType)
122d820acddSMatthias Springer     return failure();
123d820acddSMatthias Springer 
124d820acddSMatthias Springer   // Unranked memref -> unranked memref cast
125d820acddSMatthias Springer   // Ranked memref -> unranked memref cast: No copy needed.
126d820acddSMatthias Springer   assert(memref::CastOp::areCastCompatible(srcType, destType) &&
127d820acddSMatthias Springer          "expected that types are cast compatible");
128d820acddSMatthias Springer   rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
12999260e95SMatthias Springer                                               memrefToTensor.getMemref());
130d820acddSMatthias Springer   return success();
131d820acddSMatthias Springer }
132d820acddSMatthias Springer 
133b3ebe3beSMatthias Springer void mlir::bufferization::populateDynamicDimSizes(
134b3ebe3beSMatthias Springer     OpBuilder &b, Location loc, Value shapedValue,
135b3ebe3beSMatthias Springer     SmallVector<Value> &dynamicDims) {
136c1fa60b4STres Popp   auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
137b3ebe3beSMatthias Springer   for (int64_t i = 0; i < shapedType.getRank(); ++i) {
138b3ebe3beSMatthias Springer     if (shapedType.isDynamicDim(i)) {
139c1fa60b4STres Popp       if (llvm::isa<MemRefType>(shapedType)) {
140b3ebe3beSMatthias Springer         dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
141b3ebe3beSMatthias Springer       } else {
142c1fa60b4STres Popp         assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
143b3ebe3beSMatthias Springer         dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
144b3ebe3beSMatthias Springer       }
145b3ebe3beSMatthias Springer     }
146b3ebe3beSMatthias Springer   }
147b3ebe3beSMatthias Springer }
148b3ebe3beSMatthias Springer 
149fa7c8cb4SMatthias Springer //===----------------------------------------------------------------------===//
150ffdbecccSMatthias Springer // AllocTensorOp
151ffdbecccSMatthias Springer //===----------------------------------------------------------------------===//
152ffdbecccSMatthias Springer 
153ffdbecccSMatthias Springer LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
154b55d55ecSMatthias Springer                                        const BufferizationOptions &options) {
155b3ebe3beSMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
156b3ebe3beSMatthias Springer   Location loc = getLoc();
157ffdbecccSMatthias Springer 
158b3ebe3beSMatthias Springer   // Nothing to do for dead AllocTensorOps.
159b3ebe3beSMatthias Springer   if (getOperation()->getUses().empty()) {
160b3ebe3beSMatthias Springer     rewriter.eraseOp(getOperation());
161b3ebe3beSMatthias Springer     return success();
162b3ebe3beSMatthias Springer   }
163b3ebe3beSMatthias Springer 
164c06f01ffSMatthias Springer   // Get "copy" buffer.
165b3ebe3beSMatthias Springer   Value copyBuffer;
1665d50f51cSMatthias Springer   if (getCopy()) {
1675d50f51cSMatthias Springer     FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
1685d50f51cSMatthias Springer     if (failed(maybeCopyBuffer))
1695d50f51cSMatthias Springer       return failure();
1705d50f51cSMatthias Springer     copyBuffer = *maybeCopyBuffer;
1715d50f51cSMatthias Springer   }
172c06f01ffSMatthias Springer 
173c06f01ffSMatthias Springer   // Create memory allocation.
174123c4b02SMatthias Springer   auto allocType = bufferization::getBufferType(getResult(), options);
175111c9196SMatthias Springer   if (failed(allocType))
176111c9196SMatthias Springer     return failure();
17799260e95SMatthias Springer   SmallVector<Value> dynamicDims = getDynamicSizes();
17899260e95SMatthias Springer   if (getCopy()) {
179b3ebe3beSMatthias Springer     assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
180b3ebe3beSMatthias Springer     populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
181b3ebe3beSMatthias Springer   }
182111c9196SMatthias Springer   FailureOr<Value> alloc = options.createAlloc(
18368f58812STres Popp       rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
184ffdbecccSMatthias Springer   if (failed(alloc))
185ffdbecccSMatthias Springer     return failure();
186b3ebe3beSMatthias Springer 
187b3ebe3beSMatthias Springer   // Create memory copy (if any).
18899260e95SMatthias Springer   if (getCopy()) {
189b55d55ecSMatthias Springer     if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
19056d68e8dSMatthias Springer       return failure();
19156d68e8dSMatthias Springer   }
192b3ebe3beSMatthias Springer 
193b3ebe3beSMatthias Springer   // Replace op.
194ffdbecccSMatthias Springer   replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
195b3ebe3beSMatthias Springer 
196ffdbecccSMatthias Springer   return success();
197ffdbecccSMatthias Springer }
198ffdbecccSMatthias Springer 
19934d65e81SMatthias Springer bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
20056d68e8dSMatthias Springer                                                   const AnalysisState &state) {
20156d68e8dSMatthias Springer   // AllocTensorOps do not write unless they have a `copy` value.
20299260e95SMatthias Springer   return static_cast<bool>(getCopy());
20356d68e8dSMatthias Springer }
20456d68e8dSMatthias Springer 
20556d68e8dSMatthias Springer bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
20656d68e8dSMatthias Springer                                            const AnalysisState &state) {
20756d68e8dSMatthias Springer   assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
20856d68e8dSMatthias Springer          "expected copy operand");
20956d68e8dSMatthias Springer   return true;
21056d68e8dSMatthias Springer }
21156d68e8dSMatthias Springer 
21256d68e8dSMatthias Springer bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
21356d68e8dSMatthias Springer                                             const AnalysisState &state) {
21456d68e8dSMatthias Springer   assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
21556d68e8dSMatthias Springer          "expected copy operand");
21656d68e8dSMatthias Springer   return false;
21756d68e8dSMatthias Springer }
21856d68e8dSMatthias Springer 
219a02ad6c1SMatthias Springer AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
22056d68e8dSMatthias Springer                                                    const AnalysisState &state) {
22156d68e8dSMatthias Springer   // This is a new allocation. It does not alias with any other buffer.
22256d68e8dSMatthias Springer   return {};
22356d68e8dSMatthias Springer }
22456d68e8dSMatthias Springer 
225878950b8SMatthias Springer FailureOr<BaseMemRefType>
226878950b8SMatthias Springer AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
227878950b8SMatthias Springer                              SmallVector<Value> &invocationStack) {
228111c9196SMatthias Springer   assert(value == getResult() && "invalid value");
229111c9196SMatthias Springer 
230111c9196SMatthias Springer   // Compute memory space of this allocation.
2319bb63374SLei Zhang   Attribute memorySpace;
232111c9196SMatthias Springer   if (getMemorySpace().has_value()) {
233111c9196SMatthias Springer     memorySpace = *getMemorySpace();
234111c9196SMatthias Springer   } else if (getCopy()) {
235123c4b02SMatthias Springer     auto copyBufferType =
236878950b8SMatthias Springer         bufferization::getBufferType(getCopy(), options, invocationStack);
237111c9196SMatthias Springer     if (failed(copyBufferType))
238111c9196SMatthias Springer       return failure();
2399bb63374SLei Zhang     memorySpace = copyBufferType->getMemorySpace();
240067d2779Sian Bearman   } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
241067d2779Sian Bearman     memorySpace = *ms;
242111c9196SMatthias Springer   } else {
243111c9196SMatthias Springer     return getOperation()->emitError("could not infer memory space");
244111c9196SMatthias Springer   }
245111c9196SMatthias Springer 
246111c9196SMatthias Springer   return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
247111c9196SMatthias Springer }
248111c9196SMatthias Springer 
249ffdbecccSMatthias Springer LogicalResult AllocTensorOp::verify() {
25099260e95SMatthias Springer   if (getCopy() && !getDynamicSizes().empty())
25156d68e8dSMatthias Springer     return emitError("dynamic sizes not needed when copying a tensor");
252bfde1783SAndrzej Warzyński   if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
253ec55f0bdSMatthias Springer     return emitError("expected ")
254ec55f0bdSMatthias Springer            << getType().getNumDynamicDims() << " dynamic sizes";
25599260e95SMatthias Springer   if (getCopy() && getCopy().getType() != getType())
25656d68e8dSMatthias Springer     return emitError("expected that `copy` and return type match");
257ffdbecccSMatthias Springer   return success();
258ffdbecccSMatthias Springer }
259ffdbecccSMatthias Springer 
26056d68e8dSMatthias Springer void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
26156d68e8dSMatthias Springer                           RankedTensorType type, ValueRange dynamicSizes) {
262c06f01ffSMatthias Springer   build(builder, result, type, dynamicSizes, /*copy=*/Value(),
26326ef3868Sbixia1         /*size_hint=*/Value(),
2640d0a94a7SMatthias Springer         /*memory_space=*/IntegerAttr());
265c06f01ffSMatthias Springer }
266c06f01ffSMatthias Springer 
267c06f01ffSMatthias Springer void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
268c06f01ffSMatthias Springer                           RankedTensorType type, ValueRange dynamicSizes,
269c06f01ffSMatthias Springer                           Value copy) {
27026ef3868Sbixia1   build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
2710d0a94a7SMatthias Springer         /*memory_space=*/IntegerAttr());
27256d68e8dSMatthias Springer }
27356d68e8dSMatthias Springer 
27426ef3868Sbixia1 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
27526ef3868Sbixia1                           TensorType type, ValueRange dynamicSizes, Value copy,
27626ef3868Sbixia1                           IntegerAttr memorySpace) {
27726ef3868Sbixia1   build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
27826ef3868Sbixia1         memorySpace);
27926ef3868Sbixia1 }
28026ef3868Sbixia1 
281ffdbecccSMatthias Springer namespace {
282ffdbecccSMatthias Springer /// Change the type of the result of a `bufferization.alloc_tensor` by making
283ffdbecccSMatthias Springer /// the result type statically sized along dimension that in the original
284ffdbecccSMatthias Springer /// operation where defined as dynamic, but the size was defined using a
285ffdbecccSMatthias Springer /// `constant` op. For example:
286ffdbecccSMatthias Springer ///
287ffdbecccSMatthias Springer ///  %c5 = arith.constant 5: index
288ec55f0bdSMatthias Springer ///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
289ffdbecccSMatthias Springer ///
290ffdbecccSMatthias Springer ///  to
291ffdbecccSMatthias Springer ///
292ec55f0bdSMatthias Springer ///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
293ffdbecccSMatthias Springer struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
294ffdbecccSMatthias Springer   using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
295ffdbecccSMatthias Springer 
296ffdbecccSMatthias Springer   LogicalResult matchAndRewrite(AllocTensorOp op,
297ffdbecccSMatthias Springer                                 PatternRewriter &rewriter) const override {
29899260e95SMatthias Springer     if (op.getCopy())
29956d68e8dSMatthias Springer       return failure();
300ec55f0bdSMatthias Springer     SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
301ec55f0bdSMatthias Springer     SmallVector<Value> newDynamicSizes;
302ec55f0bdSMatthias Springer     unsigned int dynValCounter = 0;
303ec55f0bdSMatthias Springer     for (int64_t i = 0; i < op.getType().getRank(); ++i) {
304ec55f0bdSMatthias Springer       if (!op.isDynamicDim(i))
305ffdbecccSMatthias Springer         continue;
30699260e95SMatthias Springer       Value value = op.getDynamicSizes()[dynValCounter++];
307ec55f0bdSMatthias Springer       APInt intVal;
308ec55f0bdSMatthias Springer       if (matchPattern(value, m_ConstantInt(&intVal))) {
3095a71f7a4SMehdi Amini         int64_t dim = intVal.getSExtValue();
3105a71f7a4SMehdi Amini         if (dim >= 0)
311ec55f0bdSMatthias Springer           newShape[i] = intVal.getSExtValue();
3125a71f7a4SMehdi Amini         else
3135a71f7a4SMehdi Amini           newDynamicSizes.push_back(value);
314ec55f0bdSMatthias Springer       } else {
315ec55f0bdSMatthias Springer         newDynamicSizes.push_back(value);
316ffdbecccSMatthias Springer       }
317ffdbecccSMatthias Springer     }
318ec55f0bdSMatthias Springer     RankedTensorType newType = RankedTensorType::get(
319ec55f0bdSMatthias Springer         newShape, op.getType().getElementType(), op.getType().getEncoding());
320ffdbecccSMatthias Springer     if (newType == op.getType())
321ffdbecccSMatthias Springer       return failure();
32256d68e8dSMatthias Springer     auto newOp = rewriter.create<AllocTensorOp>(
3233474d10eSMatthias Springer         op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
324ffdbecccSMatthias Springer     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
325ffdbecccSMatthias Springer     return success();
326ffdbecccSMatthias Springer   }
327ffdbecccSMatthias Springer };
328ffdbecccSMatthias Springer 
329ffdbecccSMatthias Springer struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
330ffdbecccSMatthias Springer   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
331ffdbecccSMatthias Springer 
332ffdbecccSMatthias Springer   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
333ffdbecccSMatthias Springer                                 PatternRewriter &rewriter) const override {
33422426110SRamkumar Ramachandra     std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
33504235d07SJacques Pienaar     auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
336ffdbecccSMatthias Springer     if (!allocTensorOp || !maybeConstantIndex)
337ffdbecccSMatthias Springer       return failure();
3384bb9f918SJianbang Yang     if (*maybeConstantIndex < 0 ||
3394bb9f918SJianbang Yang         *maybeConstantIndex >= allocTensorOp.getType().getRank())
3404bb9f918SJianbang Yang       return failure();
341ec55f0bdSMatthias Springer     if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
342ffdbecccSMatthias Springer       return failure();
34356d68e8dSMatthias Springer     rewriter.replaceOp(
34456d68e8dSMatthias Springer         dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
345ffdbecccSMatthias Springer     return success();
346ffdbecccSMatthias Springer   }
347ffdbecccSMatthias Springer };
348ffdbecccSMatthias Springer } // namespace
349ffdbecccSMatthias Springer 
350ffdbecccSMatthias Springer void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
351ffdbecccSMatthias Springer                                                 MLIRContext *ctx) {
352ffdbecccSMatthias Springer   results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
353ffdbecccSMatthias Springer }
354ffdbecccSMatthias Springer 
355ffdbecccSMatthias Springer LogicalResult AllocTensorOp::reifyResultShapes(
356ffdbecccSMatthias Springer     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3572a5b13e7SMatthias Springer   auto shapes = llvm::to_vector<4>(
3582a5b13e7SMatthias Springer       llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
3592a5b13e7SMatthias Springer                       [&](int64_t dim) -> OpFoldResult {
360ec55f0bdSMatthias Springer                         if (isDynamicDim(dim))
36156d68e8dSMatthias Springer                           return getDynamicSize(builder, dim);
3622a5b13e7SMatthias Springer                         return builder.getIndexAttr(getStaticSize(dim));
363ffdbecccSMatthias Springer                       }));
364ffdbecccSMatthias Springer   reifiedReturnShapes.emplace_back(std::move(shapes));
365ffdbecccSMatthias Springer   return success();
366ffdbecccSMatthias Springer }
367ffdbecccSMatthias Springer 
36856d68e8dSMatthias Springer ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
36956d68e8dSMatthias Springer   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
37056d68e8dSMatthias Springer   if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
37156d68e8dSMatthias Springer       parser.parseRParen())
37256d68e8dSMatthias Springer     return failure();
37356d68e8dSMatthias Springer   ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
37456d68e8dSMatthias Springer   OpAsmParser::UnresolvedOperand copyOperand;
37556d68e8dSMatthias Springer   if (copyKeyword.succeeded())
37656d68e8dSMatthias Springer     if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
37756d68e8dSMatthias Springer         parser.parseRParen())
37856d68e8dSMatthias Springer       return failure();
37926ef3868Sbixia1   ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
38026ef3868Sbixia1   OpAsmParser::UnresolvedOperand sizeHintOperand;
38126ef3868Sbixia1   if (sizeHintKeyword.succeeded())
38226ef3868Sbixia1     if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
38326ef3868Sbixia1       return failure();
38456d68e8dSMatthias Springer   if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
38556d68e8dSMatthias Springer     return failure();
38656d68e8dSMatthias Springer 
38756d68e8dSMatthias Springer   TensorType type;
38856d68e8dSMatthias Springer   if (parser.parseCustomTypeWithFallback(type))
38956d68e8dSMatthias Springer     return failure();
39056d68e8dSMatthias Springer   result.addTypes(type);
39156d68e8dSMatthias Springer 
39256d68e8dSMatthias Springer   Type indexType = parser.getBuilder().getIndexType();
39356d68e8dSMatthias Springer   if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
39456d68e8dSMatthias Springer     return failure();
39556d68e8dSMatthias Springer   if (copyKeyword.succeeded())
39656d68e8dSMatthias Springer     if (parser.resolveOperand(copyOperand, type, result.operands))
39756d68e8dSMatthias Springer       return failure();
39826ef3868Sbixia1   if (sizeHintKeyword.succeeded())
39926ef3868Sbixia1     if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
40026ef3868Sbixia1       return failure();
40156d68e8dSMatthias Springer   result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
40258a47508SJeff Niu                       parser.getBuilder().getDenseI32ArrayAttr(
40356d68e8dSMatthias Springer                           {static_cast<int32_t>(dynamicSizesOperands.size()),
40426ef3868Sbixia1                            static_cast<int32_t>(copyKeyword.succeeded()),
40526ef3868Sbixia1                            static_cast<int32_t>(sizeHintKeyword.succeeded())}));
40656d68e8dSMatthias Springer   return success();
40756d68e8dSMatthias Springer }
40856d68e8dSMatthias Springer 
40956d68e8dSMatthias Springer void AllocTensorOp::print(OpAsmPrinter &p) {
41099260e95SMatthias Springer   p << "(" << getDynamicSizes() << ")";
41199260e95SMatthias Springer   if (getCopy())
41299260e95SMatthias Springer     p << " copy(" << getCopy() << ")";
41326ef3868Sbixia1   if (getSizeHint())
41426ef3868Sbixia1     p << " size_hint=" << getSizeHint();
41556d68e8dSMatthias Springer   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
41656d68e8dSMatthias Springer                               AllocTensorOp::getOperandSegmentSizeAttr()});
41756d68e8dSMatthias Springer   p << " : ";
41899260e95SMatthias Springer   auto type = getResult().getType();
419c1fa60b4STres Popp   if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
42056d68e8dSMatthias Springer     p.printStrippedAttrOrType(validType);
42156d68e8dSMatthias Springer   else
42256d68e8dSMatthias Springer     p << type;
42356d68e8dSMatthias Springer }
42456d68e8dSMatthias Springer 
42556d68e8dSMatthias Springer Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
42656d68e8dSMatthias Springer   assert(isDynamicDim(idx) && "expected dynamic dim");
42799260e95SMatthias Springer   if (getCopy())
42899260e95SMatthias Springer     return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
42956d68e8dSMatthias Springer   return getOperand(getIndexOfDynamicSize(idx));
43056d68e8dSMatthias Springer }
43156d68e8dSMatthias Springer 
432ffdbecccSMatthias Springer //===----------------------------------------------------------------------===//
43357470abcSAlexander Belyaev // CloneOp
43457470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
43557470abcSAlexander Belyaev 
4367df76121SMarkus Böck OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
43757470abcSAlexander Belyaev   return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
43857470abcSAlexander Belyaev }
43957470abcSAlexander Belyaev 
44057470abcSAlexander Belyaev namespace {
44157470abcSAlexander Belyaev 
44257470abcSAlexander Belyaev /// Merge the clone and its source (by converting the clone to a cast) when
44357470abcSAlexander Belyaev /// possible.
44457470abcSAlexander Belyaev struct SimplifyClones : public OpRewritePattern<CloneOp> {
44557470abcSAlexander Belyaev   using OpRewritePattern<CloneOp>::OpRewritePattern;
44657470abcSAlexander Belyaev 
44757470abcSAlexander Belyaev   LogicalResult matchAndRewrite(CloneOp cloneOp,
44857470abcSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
44957470abcSAlexander Belyaev     if (cloneOp.use_empty()) {
45057470abcSAlexander Belyaev       rewriter.eraseOp(cloneOp);
45157470abcSAlexander Belyaev       return success();
45257470abcSAlexander Belyaev     }
45357470abcSAlexander Belyaev 
45499260e95SMatthias Springer     Value source = cloneOp.getInput();
455eaa4b6cfSdonald chen     if (source.getType() != cloneOp.getType() &&
456eaa4b6cfSdonald chen         !memref::CastOp::areCastCompatible({source.getType()},
457eaa4b6cfSdonald chen                                            {cloneOp.getType()}))
458eaa4b6cfSdonald chen       return failure();
459eaa4b6cfSdonald chen 
460894e8a54Sroot     // Aims to find the dealloc op for the canonical source
461894e8a54Sroot     // which otherwise could prevent removal of unnecessary allocs.
462894e8a54Sroot     Value canonicalSource = source;
463894e8a54Sroot     while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
464894e8a54Sroot                canonicalSource.getDefiningOp()))
465894e8a54Sroot       canonicalSource = iface.getViewSource();
46657470abcSAlexander Belyaev 
4670a81ace0SKazu Hirata     std::optional<Operation *> maybeCloneDeallocOp =
46899260e95SMatthias Springer         memref::findDealloc(cloneOp.getOutput());
46957470abcSAlexander Belyaev     // Skip if either of them has > 1 deallocate operations.
470491d2701SKazu Hirata     if (!maybeCloneDeallocOp.has_value())
47157470abcSAlexander Belyaev       return failure();
4720a81ace0SKazu Hirata     std::optional<Operation *> maybeSourceDeallocOp =
473894e8a54Sroot         memref::findDealloc(canonicalSource);
474491d2701SKazu Hirata     if (!maybeSourceDeallocOp.has_value())
47557470abcSAlexander Belyaev       return failure();
47657470abcSAlexander Belyaev     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
47757470abcSAlexander Belyaev     Operation *sourceDeallocOp = *maybeSourceDeallocOp;
47857470abcSAlexander Belyaev 
47957470abcSAlexander Belyaev     // If both are deallocated in the same block, their in-block lifetimes
48057470abcSAlexander Belyaev     // might not fully overlap, so we cannot decide which one to drop.
48157470abcSAlexander Belyaev     if (cloneDeallocOp && sourceDeallocOp &&
48257470abcSAlexander Belyaev         cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
48357470abcSAlexander Belyaev       return failure();
48457470abcSAlexander Belyaev 
48557470abcSAlexander Belyaev     Block *currentBlock = cloneOp->getBlock();
48657470abcSAlexander Belyaev     Operation *redundantDealloc = nullptr;
48757470abcSAlexander Belyaev     if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
48857470abcSAlexander Belyaev       redundantDealloc = cloneDeallocOp;
48957470abcSAlexander Belyaev     } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
49057470abcSAlexander Belyaev       redundantDealloc = sourceDeallocOp;
49157470abcSAlexander Belyaev     }
49257470abcSAlexander Belyaev 
49357470abcSAlexander Belyaev     if (!redundantDealloc)
49457470abcSAlexander Belyaev       return failure();
49557470abcSAlexander Belyaev 
49657470abcSAlexander Belyaev     // Safety check that there are no other deallocations inbetween
49757470abcSAlexander Belyaev     // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
49857470abcSAlexander Belyaev     // of source before the uses of the clone. With alias information, we could
49957470abcSAlexander Belyaev     // restrict this to only fail of the dealloc's operand is an alias
50057470abcSAlexander Belyaev     // of the source.
50157470abcSAlexander Belyaev     for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
50257470abcSAlexander Belyaev          pos = pos->getNextNode()) {
503bcd14b09SKohei Yamaguchi       // Bail if we run out of operations while looking for a deallocation op.
504bcd14b09SKohei Yamaguchi       if (!pos)
505bcd14b09SKohei Yamaguchi         return failure();
50657470abcSAlexander Belyaev       auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
50757470abcSAlexander Belyaev       if (!effectInterface)
50857470abcSAlexander Belyaev         continue;
50957470abcSAlexander Belyaev       if (effectInterface.hasEffect<MemoryEffects::Free>())
51057470abcSAlexander Belyaev         return failure();
51157470abcSAlexander Belyaev     }
51257470abcSAlexander Belyaev 
51368f91cd2SMatthias Springer     if (source.getType() != cloneOp.getType())
51468f91cd2SMatthias Springer       source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
51568f91cd2SMatthias Springer                                                cloneOp.getType(), source);
51668f91cd2SMatthias Springer     rewriter.replaceOp(cloneOp, source);
51757470abcSAlexander Belyaev     rewriter.eraseOp(redundantDealloc);
51857470abcSAlexander Belyaev     return success();
51957470abcSAlexander Belyaev   }
52057470abcSAlexander Belyaev };
52157470abcSAlexander Belyaev 
522be0a7e9fSMehdi Amini } // namespace
52357470abcSAlexander Belyaev 
5249f85c198SRiver Riddle void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
52557470abcSAlexander Belyaev                                           MLIRContext *context) {
526b4e0507cSTres Popp   results.add<SimplifyClones>(context);
52757470abcSAlexander Belyaev }
52857470abcSAlexander Belyaev 
52957470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
53027a431f5SMatthias Springer // DeallocTensorOp
53127a431f5SMatthias Springer //===----------------------------------------------------------------------===//
53227a431f5SMatthias Springer 
53327a431f5SMatthias Springer LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
53427a431f5SMatthias Springer                                          const BufferizationOptions &options) {
53527a431f5SMatthias Springer   FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
53627a431f5SMatthias Springer   if (failed(buffer))
53727a431f5SMatthias Springer     return failure();
538caa2a4aeSMatthias Springer   rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
53927a431f5SMatthias Springer   rewriter.eraseOp(getOperation());
54027a431f5SMatthias Springer   return success();
54127a431f5SMatthias Springer }
54227a431f5SMatthias Springer 
54327a431f5SMatthias Springer //===----------------------------------------------------------------------===//
54491464e1dSMatthias Springer // MaterializeInDestinationOp
54591464e1dSMatthias Springer //===----------------------------------------------------------------------===//
54691464e1dSMatthias Springer 
54791464e1dSMatthias Springer bool MaterializeInDestinationOp::bufferizesToMemoryRead(
54891464e1dSMatthias Springer     OpOperand &opOperand, const AnalysisState &state) {
54955585043SMatthias Springer   return opOperand == getSourceMutable();
55091464e1dSMatthias Springer }
55191464e1dSMatthias Springer 
55291464e1dSMatthias Springer bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
55391464e1dSMatthias Springer     OpOperand &opOperand, const AnalysisState &state) {
55455585043SMatthias Springer   if (opOperand == getDestMutable()) {
5550fcaca2fSMatthias Springer     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
5560fcaca2fSMatthias Springer     return true;
5570fcaca2fSMatthias Springer   }
5580fcaca2fSMatthias Springer   return false;
55991464e1dSMatthias Springer }
56091464e1dSMatthias Springer 
5618ee38f3bSMatthias Springer bool MaterializeInDestinationOp::mustBufferizeInPlace(
5628ee38f3bSMatthias Springer     OpOperand &opOperand, const AnalysisState &state) {
5638ee38f3bSMatthias Springer   // The source is only read and not written, so it always bufferizes in-place
5648ee38f3bSMatthias Springer   // by default. The destination is written and is forced to bufferize in-place
5658ee38f3bSMatthias Springer   // (if it is a tensor).
5668ee38f3bSMatthias Springer   return true;
5678ee38f3bSMatthias Springer }
5688ee38f3bSMatthias Springer 
56991464e1dSMatthias Springer AliasingValueList
57091464e1dSMatthias Springer MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
57191464e1dSMatthias Springer                                               const AnalysisState &state) {
57255585043SMatthias Springer   if (opOperand == getDestMutable()) {
5730fcaca2fSMatthias Springer     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
57491464e1dSMatthias Springer     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
5750fcaca2fSMatthias Springer   }
57691464e1dSMatthias Springer   return {};
57791464e1dSMatthias Springer }
57891464e1dSMatthias Springer 
57991464e1dSMatthias Springer LogicalResult
58091464e1dSMatthias Springer MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
58191464e1dSMatthias Springer                                       const BufferizationOptions &options) {
5820fcaca2fSMatthias Springer   bool tensorDest = isa<TensorType>(getDest().getType());
5830fcaca2fSMatthias Springer   Value buffer;
5840fcaca2fSMatthias Springer   if (tensorDest) {
5850fcaca2fSMatthias Springer     FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
5860fcaca2fSMatthias Springer     if (failed(maybeBuffer))
58791464e1dSMatthias Springer       return failure();
5880fcaca2fSMatthias Springer     buffer = *maybeBuffer;
5890fcaca2fSMatthias Springer   } else {
5900fcaca2fSMatthias Springer     assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
5910fcaca2fSMatthias Springer     buffer = getDest();
5920fcaca2fSMatthias Springer   }
593437c6217SMatthias Springer   auto srcBuffer = getBuffer(rewriter, getSource(), options);
594437c6217SMatthias Springer   if (failed(srcBuffer))
595437c6217SMatthias Springer     return failure();
596437c6217SMatthias Springer   if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
597437c6217SMatthias Springer     return failure();
5980fcaca2fSMatthias Springer   replaceOpWithBufferizedValues(rewriter, getOperation(),
5990fcaca2fSMatthias Springer                                 tensorDest ? ValueRange(buffer) : ValueRange());
60091464e1dSMatthias Springer   return success();
60191464e1dSMatthias Springer }
60291464e1dSMatthias Springer 
60364839fbdSMatthias Springer bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
60464839fbdSMatthias Springer     const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
60564839fbdSMatthias Springer   // As elements are copied from the "source" buffer to the "dest" buffer,
60664839fbdSMatthias Springer   // already copied elements are not read a second time.
60764839fbdSMatthias Springer   return true;
60864839fbdSMatthias Springer }
60964839fbdSMatthias Springer 
61091464e1dSMatthias Springer LogicalResult MaterializeInDestinationOp::reifyResultShapes(
61191464e1dSMatthias Springer     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
6120fcaca2fSMatthias Springer   if (getOperation()->getNumResults() == 1) {
6130fcaca2fSMatthias Springer     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
6140fcaca2fSMatthias Springer     reifiedReturnShapes.resize(1,
6150fcaca2fSMatthias Springer                                SmallVector<OpFoldResult>(getType().getRank()));
6160fcaca2fSMatthias Springer     reifiedReturnShapes[0] =
6170fcaca2fSMatthias Springer         tensor::getMixedSizes(builder, getLoc(), getDest());
6180fcaca2fSMatthias Springer   }
61991464e1dSMatthias Springer   return success();
62091464e1dSMatthias Springer }
62191464e1dSMatthias Springer 
62264839fbdSMatthias Springer Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
62364839fbdSMatthias Springer                                                         Location loc) {
6240fcaca2fSMatthias Springer   if (isa<TensorType>(getDest().getType())) {
62564839fbdSMatthias Springer     // The subset is the entire destination tensor.
62664839fbdSMatthias Springer     return getDest();
62764839fbdSMatthias Springer   }
62864839fbdSMatthias Springer 
6296d88ac11SMatthias Springer   // The "restrict" attribute is transferred from this op to the newly created
6306d88ac11SMatthias Springer   // to_tensor op. If this op does not the "restrict" attribute, the subset
6316d88ac11SMatthias Springer   // extraction cannot be built because there is no guarantee that there is no
6326d88ac11SMatthias Springer   // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
6336d88ac11SMatthias Springer   if (!getRestrict())
6346d88ac11SMatthias Springer     return {};
6356d88ac11SMatthias Springer 
6360fcaca2fSMatthias Springer   // Build a bufferization.to_tensor op.
6370fcaca2fSMatthias Springer   assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
6380fcaca2fSMatthias Springer   assert(getRestrict() &&
6390fcaca2fSMatthias Springer          "expected that ops with memrefs dest have 'restrict'");
6406d88ac11SMatthias Springer   setRestrict(false);
6416d88ac11SMatthias Springer   return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
6420fcaca2fSMatthias Springer                                     getWritable());
6430fcaca2fSMatthias Springer }
6440fcaca2fSMatthias Springer 
64564839fbdSMatthias Springer bool MaterializeInDestinationOp::isEquivalentSubset(
64664839fbdSMatthias Springer     Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
64764839fbdSMatthias Springer   return equivalenceFn(getDest(), candidate);
64864839fbdSMatthias Springer }
64964839fbdSMatthias Springer 
65064839fbdSMatthias Springer SmallVector<Value>
65164839fbdSMatthias Springer MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
65264839fbdSMatthias Springer   return {getDest()};
65364839fbdSMatthias Springer }
65464839fbdSMatthias Springer 
65564839fbdSMatthias Springer OpOperand &MaterializeInDestinationOp::getSourceOperand() {
65664839fbdSMatthias Springer   return getOperation()->getOpOperand(0) /*source*/;
65764839fbdSMatthias Springer }
65864839fbdSMatthias Springer 
6591abd8d1aSMatthias Springer bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
6601abd8d1aSMatthias Springer     SubsetOpInterface subsetOp,
6611abd8d1aSMatthias Springer     function_ref<bool(Value, Value)> equivalenceFn) {
6621abd8d1aSMatthias Springer   return false;
6631abd8d1aSMatthias Springer }
6641abd8d1aSMatthias Springer 
6651abd8d1aSMatthias Springer bool MaterializeInDestinationOp::operatesOnDisjointSubset(
6661abd8d1aSMatthias Springer     SubsetOpInterface subsetOp,
6671abd8d1aSMatthias Springer     function_ref<bool(Value, Value)> equivalenceFn) {
6681abd8d1aSMatthias Springer   return false;
6691abd8d1aSMatthias Springer }
6701abd8d1aSMatthias Springer 
6710fcaca2fSMatthias Springer LogicalResult MaterializeInDestinationOp::verify() {
6720fcaca2fSMatthias Springer   if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
6730fcaca2fSMatthias Springer     return emitOpError("'dest' must be a tensor or a memref");
6740fcaca2fSMatthias Springer   if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
6750fcaca2fSMatthias Springer     if (getOperation()->getNumResults() != 1)
6760fcaca2fSMatthias Springer       return emitOpError("tensor 'dest' implies exactly one tensor result");
6770fcaca2fSMatthias Springer     if (destType != getResult().getType())
6780fcaca2fSMatthias Springer       return emitOpError("result and 'dest' types must match");
6790fcaca2fSMatthias Springer   }
6800fcaca2fSMatthias Springer   if (isa<BaseMemRefType>(getDest().getType()) &&
6810fcaca2fSMatthias Springer       getOperation()->getNumResults() != 0)
6820fcaca2fSMatthias Springer     return emitOpError("memref 'dest' implies zero results");
6836d88ac11SMatthias Springer   if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
6846d88ac11SMatthias Springer     return emitOpError("'restrict' is valid only for memref destinations");
6850fcaca2fSMatthias Springer   if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
6860fcaca2fSMatthias Springer     return emitOpError("'writable' must be specified if and only if the "
6870fcaca2fSMatthias Springer                        "destination is of memref type");
6889d4b20a4SMatthias Springer   TensorType srcType = getSource().getType();
6899d4b20a4SMatthias Springer   ShapedType destType = cast<ShapedType>(getDest().getType());
6909d4b20a4SMatthias Springer   if (srcType.hasRank() != destType.hasRank())
6919d4b20a4SMatthias Springer     return emitOpError("source/destination shapes are incompatible");
6929d4b20a4SMatthias Springer   if (srcType.hasRank()) {
6939d4b20a4SMatthias Springer     if (srcType.getRank() != destType.getRank())
6949d4b20a4SMatthias Springer       return emitOpError("rank mismatch between source and destination shape");
6959d4b20a4SMatthias Springer     for (auto [src, dest] :
6969d4b20a4SMatthias Springer          llvm::zip(srcType.getShape(), destType.getShape())) {
6979d4b20a4SMatthias Springer       if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
6989d4b20a4SMatthias Springer         // Cannot verify dynamic dimension size. Assume that that they match at
6999d4b20a4SMatthias Springer         // runtime.
7009d4b20a4SMatthias Springer         continue;
7019d4b20a4SMatthias Springer       }
7029d4b20a4SMatthias Springer       if (src != dest)
7039d4b20a4SMatthias Springer         return emitOpError("source/destination shapes are incompatible");
7049d4b20a4SMatthias Springer     }
7059d4b20a4SMatthias Springer   }
7060fcaca2fSMatthias Springer   return success();
7070fcaca2fSMatthias Springer }
7080fcaca2fSMatthias Springer 
7090fcaca2fSMatthias Springer void MaterializeInDestinationOp::build(OpBuilder &builder,
7100fcaca2fSMatthias Springer                                        OperationState &state, Value source,
7110fcaca2fSMatthias Springer                                        Value dest) {
712437c6217SMatthias Springer   auto destTensorType = dyn_cast<TensorType>(dest.getType());
713437c6217SMatthias Springer   build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
714437c6217SMatthias Springer         source, dest);
7150fcaca2fSMatthias Springer }
7160fcaca2fSMatthias Springer 
7170fcaca2fSMatthias Springer bool MaterializeInDestinationOp::isWritable(Value value,
7180fcaca2fSMatthias Springer                                             const AnalysisState &state) {
7190fcaca2fSMatthias Springer   return isa<TensorType>(getDest().getType()) ? true : getWritable();
7200fcaca2fSMatthias Springer }
7210fcaca2fSMatthias Springer 
7220fcaca2fSMatthias Springer MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
7230fcaca2fSMatthias Springer   return getDestMutable();
7240fcaca2fSMatthias Springer }
7250fcaca2fSMatthias Springer 
7260fcaca2fSMatthias Springer void MaterializeInDestinationOp::getEffects(
7270fcaca2fSMatthias Springer     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
7280fcaca2fSMatthias Springer         &effects) {
7290fcaca2fSMatthias Springer   if (isa<BaseMemRefType>(getDest().getType()))
7302c1ae801Sdonald chen     effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
7310fcaca2fSMatthias Springer                          SideEffects::DefaultResource::get());
7320fcaca2fSMatthias Springer }
7330fcaca2fSMatthias Springer 
73491464e1dSMatthias Springer //===----------------------------------------------------------------------===//
73557470abcSAlexander Belyaev // ToTensorOp
73657470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
73757470abcSAlexander Belyaev 
7388f7e7400SMatthias Springer bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
7398f7e7400SMatthias Springer   return getWritable();
7408f7e7400SMatthias Springer }
7418f7e7400SMatthias Springer 
7427df76121SMarkus Böck OpFoldResult ToTensorOp::fold(FoldAdaptor) {
74399260e95SMatthias Springer   if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
74457470abcSAlexander Belyaev     // Approximate alias analysis by conservatively folding only when no there
74557470abcSAlexander Belyaev     // is no interleaved operation.
74657470abcSAlexander Belyaev     if (toMemref->getBlock() == this->getOperation()->getBlock() &&
74757470abcSAlexander Belyaev         toMemref->getNextNode() == this->getOperation())
74899260e95SMatthias Springer       return toMemref.getTensor();
74957470abcSAlexander Belyaev   return {};
75057470abcSAlexander Belyaev }
75157470abcSAlexander Belyaev 
75257470abcSAlexander Belyaev namespace {
75357470abcSAlexander Belyaev struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
75457470abcSAlexander Belyaev   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
75557470abcSAlexander Belyaev 
75657470abcSAlexander Belyaev   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
75757470abcSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
75804235d07SJacques Pienaar     auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
75957470abcSAlexander Belyaev     if (!memrefToTensorOp)
76057470abcSAlexander Belyaev       return failure();
76157470abcSAlexander Belyaev 
76299260e95SMatthias Springer     rewriter.replaceOpWithNewOp<memref::DimOp>(
76304235d07SJacques Pienaar         dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
76457470abcSAlexander Belyaev     return success();
76557470abcSAlexander Belyaev   }
76657470abcSAlexander Belyaev };
76757470abcSAlexander Belyaev } // namespace
76857470abcSAlexander Belyaev 
76957470abcSAlexander Belyaev void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
77057470abcSAlexander Belyaev                                              MLIRContext *context) {
771fc9b37ddSMatthias Springer   results.add<DimOfToTensorFolder>(context);
77257470abcSAlexander Belyaev }
77357470abcSAlexander Belyaev 
77457470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
77557470abcSAlexander Belyaev // ToMemrefOp
77657470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
77757470abcSAlexander Belyaev 
7787df76121SMarkus Böck OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
77999260e95SMatthias Springer   if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
78099260e95SMatthias Springer     if (memrefToTensor.getMemref().getType() == getType())
78199260e95SMatthias Springer       return memrefToTensor.getMemref();
78257470abcSAlexander Belyaev   return {};
78357470abcSAlexander Belyaev }
78457470abcSAlexander Belyaev 
78557470abcSAlexander Belyaev namespace {
78657470abcSAlexander Belyaev 
78757470abcSAlexander Belyaev /// Replace tensor.cast + to_memref by to_memref + memref.cast.
78857470abcSAlexander Belyaev struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
78957470abcSAlexander Belyaev   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
79057470abcSAlexander Belyaev 
79157470abcSAlexander Belyaev   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
79257470abcSAlexander Belyaev                                 PatternRewriter &rewriter) const final {
79357470abcSAlexander Belyaev     auto tensorCastOperand =
79457470abcSAlexander Belyaev         toMemref.getOperand().getDefiningOp<tensor::CastOp>();
79557470abcSAlexander Belyaev     if (!tensorCastOperand)
79657470abcSAlexander Belyaev       return failure();
797c1fa60b4STres Popp     auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
798c1fa60b4STres Popp         tensorCastOperand.getOperand().getType());
79957470abcSAlexander Belyaev     if (!srcTensorType)
80057470abcSAlexander Belyaev       return failure();
80157470abcSAlexander Belyaev     auto memrefType = MemRefType::get(srcTensorType.getShape(),
80257470abcSAlexander Belyaev                                       srcTensorType.getElementType());
80357470abcSAlexander Belyaev     Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
80457470abcSAlexander Belyaev                                                tensorCastOperand.getOperand());
80557470abcSAlexander Belyaev     rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
80657470abcSAlexander Belyaev                                                 memref);
80757470abcSAlexander Belyaev     return success();
80857470abcSAlexander Belyaev   }
80957470abcSAlexander Belyaev };
81057470abcSAlexander Belyaev 
811cb471241SMatthias Springer /// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
812cb471241SMatthias Springer /// cast if necessary.
813cb471241SMatthias Springer struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
814b00ee46bSMatthias Springer   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
815b00ee46bSMatthias Springer 
816b00ee46bSMatthias Springer   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
817b00ee46bSMatthias Springer                                 PatternRewriter &rewriter) const final {
818c515c780SMatthias Gehre     BufferizationOptions options;
819c515c780SMatthias Gehre     options.bufferAlignment = 0;
820c515c780SMatthias Gehre     return foldToMemrefToTensorPair(rewriter, toMemref, options);
821b00ee46bSMatthias Springer   }
82257470abcSAlexander Belyaev };
82357470abcSAlexander Belyaev 
82457470abcSAlexander Belyaev /// Fold a load on a to_memref operation into an tensor.extract on the
82557470abcSAlexander Belyaev /// corresponding tensor.
82657470abcSAlexander Belyaev struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
82757470abcSAlexander Belyaev   using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
82857470abcSAlexander Belyaev 
82957470abcSAlexander Belyaev   LogicalResult matchAndRewrite(memref::LoadOp load,
83057470abcSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
831136d746eSJacques Pienaar     auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
83257470abcSAlexander Belyaev     if (!toMemref)
83357470abcSAlexander Belyaev       return failure();
83457470abcSAlexander Belyaev 
83599260e95SMatthias Springer     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
836136d746eSJacques Pienaar                                                    load.getIndices());
83757470abcSAlexander Belyaev     return success();
83857470abcSAlexander Belyaev   }
83957470abcSAlexander Belyaev };
84057470abcSAlexander Belyaev 
84157470abcSAlexander Belyaev /// Fold dim of a to_memref into the dim of the tensor.
84257470abcSAlexander Belyaev struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
84357470abcSAlexander Belyaev   using OpRewritePattern<memref::DimOp>::OpRewritePattern;
84457470abcSAlexander Belyaev 
84557470abcSAlexander Belyaev   LogicalResult matchAndRewrite(memref::DimOp dimOp,
84657470abcSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
847136d746eSJacques Pienaar     auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
84857470abcSAlexander Belyaev     if (!castOp)
84957470abcSAlexander Belyaev       return failure();
85057470abcSAlexander Belyaev     Value newSource = castOp.getOperand();
851136d746eSJacques Pienaar     rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
852136d746eSJacques Pienaar                                                dimOp.getIndex());
85357470abcSAlexander Belyaev     return success();
85457470abcSAlexander Belyaev   }
85557470abcSAlexander Belyaev };
85657470abcSAlexander Belyaev 
85757470abcSAlexander Belyaev } // namespace
85857470abcSAlexander Belyaev 
85957470abcSAlexander Belyaev void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
86057470abcSAlexander Belyaev                                              MLIRContext *context) {
861cb471241SMatthias Springer   results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
862cb471241SMatthias Springer               ToMemrefToTensorFolding>(context);
86357470abcSAlexander Belyaev }
86457470abcSAlexander Belyaev 
865b00ee46bSMatthias Springer LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
866b55d55ecSMatthias Springer                                     const BufferizationOptions &options) {
867b00ee46bSMatthias Springer   // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
868c515c780SMatthias Gehre   (void)foldToMemrefToTensorPair(rewriter, *this, options);
8690b293bf0SMatthias Springer   // Note: The return value of `bufferize` indicates whether there was an error
8700b293bf0SMatthias Springer   // or not. (And not whether the pattern matched or not.)
8710b293bf0SMatthias Springer   return success();
872b00ee46bSMatthias Springer }
873b00ee46bSMatthias Springer 
874e8bcc37fSRamkumar Ramachandra std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
875e8bcc37fSRamkumar Ramachandra                                                  Value alloc) {
87657470abcSAlexander Belyaev   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
87757470abcSAlexander Belyaev       .getOperation();
87857470abcSAlexander Belyaev }
87957470abcSAlexander Belyaev 
880e8bcc37fSRamkumar Ramachandra std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
88157470abcSAlexander Belyaev   return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
88257470abcSAlexander Belyaev }
88357470abcSAlexander Belyaev 
88457470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
885d5825621SMartin Erhart // DeallocOp
886d5825621SMartin Erhart //===----------------------------------------------------------------------===//
887d5825621SMartin Erhart 
888d5825621SMartin Erhart LogicalResult DeallocOp::inferReturnTypes(
889d5825621SMartin Erhart     MLIRContext *context, std::optional<::mlir::Location> location,
890d5825621SMartin Erhart     ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
891d5825621SMartin Erhart     RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
892d5825621SMartin Erhart   DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
8934bde084fSMartin Erhart   inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
8944bde084fSMartin Erhart                                           IntegerType::get(context, 1));
895d5825621SMartin Erhart   return success();
896d5825621SMartin Erhart }
897d5825621SMartin Erhart 
898d5825621SMartin Erhart LogicalResult DeallocOp::verify() {
899d5825621SMartin Erhart   if (getMemrefs().size() != getConditions().size())
900d5825621SMartin Erhart     return emitOpError(
901d5825621SMartin Erhart         "must have the same number of conditions as memrefs to deallocate");
9020ef990d5SMatthias Springer   if (getRetained().size() != getUpdatedConditions().size())
9030ef990d5SMatthias Springer     return emitOpError("must have the same number of updated conditions "
9040ef990d5SMatthias Springer                        "(results) as retained operands");
905d5825621SMartin Erhart   return success();
906d5825621SMartin Erhart }
907d5825621SMartin Erhart 
9084bde084fSMartin Erhart static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
9095c7d97beSMartin Erhart                                             ValueRange memrefs,
9105c7d97beSMartin Erhart                                             ValueRange conditions,
9114bde084fSMartin Erhart                                             PatternRewriter &rewriter) {
9125c7d97beSMartin Erhart   if (deallocOp.getMemrefs() == memrefs &&
9135c7d97beSMartin Erhart       deallocOp.getConditions() == conditions)
9144bde084fSMartin Erhart     return failure();
9154bde084fSMartin Erhart 
9165fcf907bSMatthias Springer   rewriter.modifyOpInPlace(deallocOp, [&]() {
9174bde084fSMartin Erhart     deallocOp.getMemrefsMutable().assign(memrefs);
9184bde084fSMartin Erhart     deallocOp.getConditionsMutable().assign(conditions);
9194bde084fSMartin Erhart   });
9204bde084fSMartin Erhart   return success();
9214bde084fSMartin Erhart }
9224bde084fSMartin Erhart 
92317aaa651SMartin Erhart namespace {
92417aaa651SMartin Erhart 
9254bde084fSMartin Erhart /// Remove duplicate values in the list of memrefs to be deallocated. We need to
9264bde084fSMartin Erhart /// make sure the corresponding condition value is updated accordingly since
9274bde084fSMartin Erhart /// their two conditions might not cover the same set of cases. In that case, we
9284bde084fSMartin Erhart /// have to combine them (by computing the disjunction of them).
92917aaa651SMartin Erhart /// Example:
93017aaa651SMartin Erhart /// ```mlir
9314bde084fSMartin Erhart /// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
93217aaa651SMartin Erhart /// ```
93317aaa651SMartin Erhart /// is canonicalized to
93417aaa651SMartin Erhart /// ```mlir
93517aaa651SMartin Erhart /// %0 = arith.ori %arg1, %arg2 : i1
9364bde084fSMartin Erhart /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
93717aaa651SMartin Erhart /// ```
9384bde084fSMartin Erhart struct DeallocRemoveDuplicateDeallocMemrefs
9394bde084fSMartin Erhart     : public OpRewritePattern<DeallocOp> {
94017aaa651SMartin Erhart   using OpRewritePattern<DeallocOp>::OpRewritePattern;
94117aaa651SMartin Erhart 
94217aaa651SMartin Erhart   LogicalResult matchAndRewrite(DeallocOp deallocOp,
94317aaa651SMartin Erhart                                 PatternRewriter &rewriter) const override {
94417aaa651SMartin Erhart     // Unique memrefs to be deallocated.
94517aaa651SMartin Erhart     DenseMap<Value, unsigned> memrefToCondition;
9464bde084fSMartin Erhart     SmallVector<Value> newMemrefs, newConditions;
947b0688ed0SMartin Erhart     for (auto [i, memref, cond] :
948b0688ed0SMartin Erhart          llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
94917aaa651SMartin Erhart       if (memrefToCondition.count(memref)) {
95017aaa651SMartin Erhart         // If the dealloc conditions don't match, we need to make sure that the
95117aaa651SMartin Erhart         // dealloc happens on the union of cases.
95217aaa651SMartin Erhart         Value &newCond = newConditions[memrefToCondition[memref]];
95317aaa651SMartin Erhart         if (newCond != cond)
95417aaa651SMartin Erhart           newCond =
95517aaa651SMartin Erhart               rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
95617aaa651SMartin Erhart       } else {
95717aaa651SMartin Erhart         memrefToCondition.insert({memref, newConditions.size()});
95817aaa651SMartin Erhart         newMemrefs.push_back(memref);
95917aaa651SMartin Erhart         newConditions.push_back(cond);
96017aaa651SMartin Erhart       }
96117aaa651SMartin Erhart     }
96217aaa651SMartin Erhart 
96317aaa651SMartin Erhart     // Return failure if we don't change anything such that we don't run into an
96417aaa651SMartin Erhart     // infinite loop of pattern applications.
9654bde084fSMartin Erhart     return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
9664bde084fSMartin Erhart                                   rewriter);
9674bde084fSMartin Erhart   }
9684bde084fSMartin Erhart };
9694bde084fSMartin Erhart 
9704bde084fSMartin Erhart /// Remove duplicate values in the list of retained memrefs. We need to make
9714bde084fSMartin Erhart /// sure the corresponding result condition value is replaced properly.
9724bde084fSMartin Erhart /// Example:
9734bde084fSMartin Erhart /// ```mlir
9744bde084fSMartin Erhart /// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
9754bde084fSMartin Erhart /// ```
9764bde084fSMartin Erhart /// is canonicalized to
9774bde084fSMartin Erhart /// ```mlir
9784bde084fSMartin Erhart /// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
9794bde084fSMartin Erhart /// ```
9804bde084fSMartin Erhart struct DeallocRemoveDuplicateRetainedMemrefs
9814bde084fSMartin Erhart     : public OpRewritePattern<DeallocOp> {
9824bde084fSMartin Erhart   using OpRewritePattern<DeallocOp>::OpRewritePattern;
9834bde084fSMartin Erhart 
9844bde084fSMartin Erhart   LogicalResult matchAndRewrite(DeallocOp deallocOp,
9854bde084fSMartin Erhart                                 PatternRewriter &rewriter) const override {
9864bde084fSMartin Erhart     // Unique retained values
9874bde084fSMartin Erhart     DenseMap<Value, unsigned> seen;
9884bde084fSMartin Erhart     SmallVector<Value> newRetained;
9894bde084fSMartin Erhart     SmallVector<unsigned> resultReplacementIdx;
9904bde084fSMartin Erhart     unsigned i = 0;
9914bde084fSMartin Erhart     for (auto retained : deallocOp.getRetained()) {
9924bde084fSMartin Erhart       if (seen.count(retained)) {
9934bde084fSMartin Erhart         resultReplacementIdx.push_back(seen[retained]);
9944bde084fSMartin Erhart         continue;
9954bde084fSMartin Erhart       }
9964bde084fSMartin Erhart 
9974bde084fSMartin Erhart       seen[retained] = i;
9984bde084fSMartin Erhart       newRetained.push_back(retained);
9994bde084fSMartin Erhart       resultReplacementIdx.push_back(i++);
10004bde084fSMartin Erhart     }
10014bde084fSMartin Erhart 
10024bde084fSMartin Erhart     // Return failure if we don't change anything such that we don't run into an
10034bde084fSMartin Erhart     // infinite loop of pattern applications.
10044bde084fSMartin Erhart     if (newRetained.size() == deallocOp.getRetained().size())
100517aaa651SMartin Erhart       return failure();
100617aaa651SMartin Erhart 
100717aaa651SMartin Erhart     // We need to create a new op because the number of results is always the
100817aaa651SMartin Erhart     // same as the number of condition operands.
10094bde084fSMartin Erhart     auto newDeallocOp =
10104bde084fSMartin Erhart         rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
10114bde084fSMartin Erhart                                    deallocOp.getConditions(), newRetained);
10124bde084fSMartin Erhart     SmallVector<Value> replacements(
10134bde084fSMartin Erhart         llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
10144bde084fSMartin Erhart           return newDeallocOp.getUpdatedConditions()[idx];
10154bde084fSMartin Erhart         }));
10164bde084fSMartin Erhart     rewriter.replaceOp(deallocOp, replacements);
101717aaa651SMartin Erhart     return success();
101817aaa651SMartin Erhart   }
101917aaa651SMartin Erhart };
102017aaa651SMartin Erhart 
10214bde084fSMartin Erhart /// Erase deallocation operations where the variadic list of memrefs to
10224bde084fSMartin Erhart /// deallocate is empty. Example:
10234bde084fSMartin Erhart /// ```mlir
10244bde084fSMartin Erhart /// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
1025b0688ed0SMartin Erhart /// ```
1026b0688ed0SMartin Erhart struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
1027b0688ed0SMartin Erhart   using OpRewritePattern<DeallocOp>::OpRewritePattern;
1028b0688ed0SMartin Erhart 
1029b0688ed0SMartin Erhart   LogicalResult matchAndRewrite(DeallocOp deallocOp,
1030b0688ed0SMartin Erhart                                 PatternRewriter &rewriter) const override {
1031b0688ed0SMartin Erhart     if (deallocOp.getMemrefs().empty()) {
10324bde084fSMartin Erhart       Value constFalse = rewriter.create<arith::ConstantOp>(
10334bde084fSMartin Erhart           deallocOp.getLoc(), rewriter.getBoolAttr(false));
10344bde084fSMartin Erhart       rewriter.replaceOp(
10354bde084fSMartin Erhart           deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
10364bde084fSMartin Erhart                                         constFalse));
1037b0688ed0SMartin Erhart       return success();
1038b0688ed0SMartin Erhart     }
1039b0688ed0SMartin Erhart     return failure();
1040b0688ed0SMartin Erhart   }
1041b0688ed0SMartin Erhart };
1042b0688ed0SMartin Erhart 
1043d26eb822SMartin Erhart /// Removes memrefs from the deallocation list if their associated condition is
1044d26eb822SMartin Erhart /// always 'false'.
1045d26eb822SMartin Erhart ///
1046d26eb822SMartin Erhart /// Example:
1047d26eb822SMartin Erhart /// ```
10484bde084fSMartin Erhart /// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
1049d26eb822SMartin Erhart ///                           if (%arg2, %false)
1050d26eb822SMartin Erhart /// ```
1051d26eb822SMartin Erhart /// becomes
1052d26eb822SMartin Erhart /// ```
10534bde084fSMartin Erhart /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
1054d26eb822SMartin Erhart /// ```
1055d26eb822SMartin Erhart struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
1056d26eb822SMartin Erhart   using OpRewritePattern<DeallocOp>::OpRewritePattern;
1057d26eb822SMartin Erhart 
1058d26eb822SMartin Erhart   LogicalResult matchAndRewrite(DeallocOp deallocOp,
1059d26eb822SMartin Erhart                                 PatternRewriter &rewriter) const override {
1060d26eb822SMartin Erhart     SmallVector<Value> newMemrefs, newConditions;
10614bde084fSMartin Erhart     for (auto [memref, cond] :
10624bde084fSMartin Erhart          llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
10634bde084fSMartin Erhart       if (!matchPattern(cond, m_Zero())) {
1064d26eb822SMartin Erhart         newMemrefs.push_back(memref);
1065d26eb822SMartin Erhart         newConditions.push_back(cond);
10664bde084fSMartin Erhart       }
1067d26eb822SMartin Erhart     }
1068d26eb822SMartin Erhart 
10694bde084fSMartin Erhart     return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
10704bde084fSMartin Erhart                                   rewriter);
1071d26eb822SMartin Erhart   }
1072d26eb822SMartin Erhart };
1073d26eb822SMartin Erhart 
10745c7d97beSMartin Erhart /// The `memref.extract_strided_metadata` is often inserted to get the base
10755c7d97beSMartin Erhart /// memref if the operand is not already guaranteed to be the result of a memref
10765c7d97beSMartin Erhart /// allocation operation. This canonicalization pattern removes this extraction
10775c7d97beSMartin Erhart /// operation if the operand is now produced by an allocation operation (e.g.,
10785c7d97beSMartin Erhart /// due to other canonicalizations simplifying the IR).
10795c7d97beSMartin Erhart ///
10805c7d97beSMartin Erhart /// Example:
10815c7d97beSMartin Erhart /// ```mlir
10825c7d97beSMartin Erhart /// %alloc = memref.alloc() : memref<2xi32>
10835c7d97beSMartin Erhart /// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
10845c7d97beSMartin Erhart ///   %alloc : memref<2xi32> -> memref<i32>, index, index, index
10855c7d97beSMartin Erhart /// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
10865c7d97beSMartin Erhart /// ```
10875c7d97beSMartin Erhart /// is canonicalized to
10885c7d97beSMartin Erhart /// ```mlir
10895c7d97beSMartin Erhart /// %alloc = memref.alloc() : memref<2xi32>
10905c7d97beSMartin Erhart /// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
10915c7d97beSMartin Erhart /// ```
10925c7d97beSMartin Erhart struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
10935c7d97beSMartin Erhart   using OpRewritePattern<DeallocOp>::OpRewritePattern;
10945c7d97beSMartin Erhart 
10955c7d97beSMartin Erhart   LogicalResult matchAndRewrite(DeallocOp deallocOp,
10965c7d97beSMartin Erhart                                 PatternRewriter &rewriter) const override {
10975c7d97beSMartin Erhart     SmallVector<Value> newMemrefs(
10985c7d97beSMartin Erhart         llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
10995c7d97beSMartin Erhart           auto extractStridedOp =
11005c7d97beSMartin Erhart               memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
11015c7d97beSMartin Erhart           if (!extractStridedOp)
11025c7d97beSMartin Erhart             return memref;
11035c7d97beSMartin Erhart           Value allocMemref = extractStridedOp.getOperand();
11045c7d97beSMartin Erhart           auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
11055c7d97beSMartin Erhart           if (!allocOp)
11065c7d97beSMartin Erhart             return memref;
11075c7d97beSMartin Erhart           if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
11085c7d97beSMartin Erhart             return allocMemref;
11095c7d97beSMartin Erhart           return memref;
11105c7d97beSMartin Erhart         }));
11115c7d97beSMartin Erhart 
11125c7d97beSMartin Erhart     return updateDeallocIfChanged(deallocOp, newMemrefs,
11135c7d97beSMartin Erhart                                   deallocOp.getConditions(), rewriter);
11145c7d97beSMartin Erhart   }
11155c7d97beSMartin Erhart };
11165c7d97beSMartin Erhart 
1117778494aeSMartin Erhart /// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
1118778494aeSMartin Erhart /// other user of the allocated value and the allocating operation can be safely
1119778494aeSMartin Erhart /// removed. If the same value is present multiple times, this pattern relies on
1120778494aeSMartin Erhart /// other canonicalization patterns to remove the duplicate first.
1121778494aeSMartin Erhart ///
1122778494aeSMartin Erhart /// Example:
1123778494aeSMartin Erhart /// ```mlir
1124778494aeSMartin Erhart /// %alloc = memref.alloc() : memref<2xi32>
1125778494aeSMartin Erhart /// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
1126778494aeSMartin Erhart /// ```
1127778494aeSMartin Erhart /// is canonicalized to
1128778494aeSMartin Erhart /// ```mlir
1129778494aeSMartin Erhart /// bufferization.dealloc (%arg0 : ...) if (%true)
1130778494aeSMartin Erhart /// ```
1131778494aeSMartin Erhart struct RemoveAllocDeallocPairWhenNoOtherUsers
1132778494aeSMartin Erhart     : public OpRewritePattern<DeallocOp> {
1133778494aeSMartin Erhart   using OpRewritePattern<DeallocOp>::OpRewritePattern;
1134778494aeSMartin Erhart 
1135778494aeSMartin Erhart   LogicalResult matchAndRewrite(DeallocOp deallocOp,
1136778494aeSMartin Erhart                                 PatternRewriter &rewriter) const override {
1137778494aeSMartin Erhart     SmallVector<Value> newMemrefs, newConditions;
1138778494aeSMartin Erhart     SmallVector<Operation *> toDelete;
1139778494aeSMartin Erhart     for (auto [memref, cond] :
1140778494aeSMartin Erhart          llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1141778494aeSMartin Erhart       if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
1142778494aeSMartin Erhart         // Check that it is indeed an allocate effect, that the op has no other
1143778494aeSMartin Erhart         // side effects (which would not allow us to remove the op), and that
1144778494aeSMartin Erhart         // there are no other users.
1145778494aeSMartin Erhart         if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
1146778494aeSMartin Erhart             hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
1147778494aeSMartin Erhart             memref.hasOneUse()) {
1148778494aeSMartin Erhart           toDelete.push_back(allocOp);
1149778494aeSMartin Erhart           continue;
1150778494aeSMartin Erhart         }
1151778494aeSMartin Erhart       }
1152778494aeSMartin Erhart 
1153778494aeSMartin Erhart       newMemrefs.push_back(memref);
1154778494aeSMartin Erhart       newConditions.push_back(cond);
1155778494aeSMartin Erhart     }
1156778494aeSMartin Erhart 
1157778494aeSMartin Erhart     if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1158778494aeSMartin Erhart                                       rewriter)))
1159778494aeSMartin Erhart       return failure();
1160778494aeSMartin Erhart 
1161778494aeSMartin Erhart     for (Operation *op : toDelete)
1162778494aeSMartin Erhart       rewriter.eraseOp(op);
1163778494aeSMartin Erhart 
1164778494aeSMartin Erhart     return success();
1165778494aeSMartin Erhart   }
1166778494aeSMartin Erhart };
1167778494aeSMartin Erhart 
116817aaa651SMartin Erhart } // anonymous namespace
116917aaa651SMartin Erhart 
117017aaa651SMartin Erhart void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
117117aaa651SMartin Erhart                                             MLIRContext *context) {
1172fff18305SMartin Erhart   populateDeallocOpCanonicalizationPatterns(results, context);
1173fff18305SMartin Erhart }
1174fff18305SMartin Erhart 
1175fff18305SMartin Erhart void bufferization::populateDeallocOpCanonicalizationPatterns(
1176fff18305SMartin Erhart     RewritePatternSet &patterns, MLIRContext *context) {
1177fff18305SMartin Erhart   patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
117887f2dee4SMartin Erhart                DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1179778494aeSMartin Erhart                EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1180778494aeSMartin Erhart                RemoveAllocDeallocPairWhenNoOtherUsers>(context);
118117aaa651SMartin Erhart }
118217aaa651SMartin Erhart 
1183d5825621SMartin Erhart //===----------------------------------------------------------------------===//
118457470abcSAlexander Belyaev // TableGen'd op method definitions
118557470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
118657470abcSAlexander Belyaev 
118757470abcSAlexander Belyaev #define GET_OP_CLASSES
118857470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
1189