xref: /llvm-project/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/Matchers.h"
17 #include <optional>
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 
22 //===----------------------------------------------------------------------===//
23 // Helper functions
24 //===----------------------------------------------------------------------===//
25 
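/// Try to cast the given ranked memref value to the given ranked memref type.
/// If the cast cannot be guaranteed to succeed at runtime, allocate a new
/// buffer of the target type and copy the contents of the source buffer into
/// it instead.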
26 FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
27     OpBuilder &b, Value value, MemRefType destType,
28     const BufferizationOptions &options) {
29   auto srcType = llvm::cast<MemRefType>(value.getType());
30 
31   // Element type, rank and memory space must match.
32   if (srcType.getElementType() != destType.getElementType())
33     return failure();
34   if (srcType.getMemorySpace() != destType.getMemorySpace())
35     return failure();
36   if (srcType.getRank() != destType.getRank())
37     return failure();
38 
39   // In case the affine maps are different, we may need to use a copy if we go
40   // from dynamic to static offset or stride (the canonicalization cannot know
41   // at this point that it is really cast compatible).
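  // For example, casting memref<?xf32, strided<[?], offset: ?>> to
  // memref<?xf32, strided<[1], offset: 0>> is only valid if the runtime stride
  // happens to be 1, so a copy is emitted instead of a cast in that case.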
42   auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
43     int64_t sourceOffset, targetOffset;
44     SmallVector<int64_t, 4> sourceStrides, targetStrides;
45     if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
46         failed(target.getStridesAndOffset(targetStrides, targetOffset)))
47       return false;
48     auto dynamicToStatic = [](int64_t a, int64_t b) {
49       return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
50     };
51     if (dynamicToStatic(sourceOffset, targetOffset))
52       return false;
53     for (auto it : zip(sourceStrides, targetStrides))
54       if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
55         return false;
56     return true;
57   };
58 
59   // Note: If `areCastCompatible` returns true, a cast is valid but may still
60   // fail at runtime. To ensure that we only generate casts that always succeed
61   // at runtime, we check a few extra conditions in `isGuaranteedCastCompatible`.
62   if (memref::CastOp::areCastCompatible(srcType, destType) &&
63       isGuaranteedCastCompatible(srcType, destType)) {
64     Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
65     return casted;
66   }
67 
68   auto loc = value.getLoc();
69   SmallVector<Value, 4> dynamicOperands;
70   for (int i = 0; i < destType.getRank(); ++i) {
71     if (destType.getShape()[i] != ShapedType::kDynamic)
72       continue;
73     Value size = b.create<memref::DimOp>(loc, value, i);
74     dynamicOperands.push_back(size);
75   }
76 
77   FailureOr<Value> copy =
78       options.createAlloc(b, loc, destType, dynamicOperands);
79   if (failed(copy))
80     return failure();
81   if (failed(options.createMemCpy(b, loc, value, *copy)))
82     return failure();
83   return copy;
84 }
85 
86 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
87 /// to_memref op are different, a memref.cast is needed.
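/// Example (schematic; to_memref/to_tensor types elided):
/// ```mlir
/// %t = bufferization.to_tensor %m : ...
/// %m2 = bufferization.to_memref %t : ...
/// ```
/// folds such that uses of %m2 are replaced with %m, possibly through a
/// memref.cast (or an allocation plus copy) when the two memref types differ.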
88 LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
89     RewriterBase &rewriter, ToMemrefOp toMemref,
90     const BufferizationOptions &options) {
91   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
92   if (!memrefToTensor)
93     return failure();
94 
95   Type srcType = memrefToTensor.getMemref().getType();
96   Type destType = toMemref.getType();
97 
98   // Directly rewrite if the type did not change.
99   if (srcType == destType) {
100     rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
101     return success();
102   }
103 
104   auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
105   auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
106   auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
107 
108   // Ranked memref -> Ranked memref cast.
109   if (rankedSrcType && rankedDestType) {
110     FailureOr<Value> replacement = castOrReallocMemRefValue(
111         rewriter, memrefToTensor.getMemref(), rankedDestType, options);
112     if (failed(replacement))
113       return failure();
114 
115     rewriter.replaceOp(toMemref, *replacement);
116     return success();
117   }
118 
119   // Unranked memref -> Ranked memref cast: May require a copy.
120   // TODO: Not implemented at the moment.
121   if (unrankedSrcType && rankedDestType)
122     return failure();
123 
124   // Unranked memref -> unranked memref cast
125   // Ranked memref -> unranked memref cast: No copy needed.
126   assert(memref::CastOp::areCastCompatible(srcType, destType) &&
127          "expected that types are cast compatible");
128   rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
129                                               memrefToTensor.getMemref());
130   return success();
131 }
132 
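/// Populate `dynamicDims` with the dynamic dimension sizes of the given shaped
/// value, which must be either a memref or a ranked tensor.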
133 void mlir::bufferization::populateDynamicDimSizes(
134     OpBuilder &b, Location loc, Value shapedValue,
135     SmallVector<Value> &dynamicDims) {
136   auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
137   for (int64_t i = 0; i < shapedType.getRank(); ++i) {
138     if (shapedType.isDynamicDim(i)) {
139       if (llvm::isa<MemRefType>(shapedType)) {
140         dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
141       } else {
142         assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
143         dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
144       }
145     }
146   }
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // AllocTensorOp
151 //===----------------------------------------------------------------------===//
152 
153 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
154                                        const BufferizationOptions &options) {
155   OpBuilder::InsertionGuard g(rewriter);
156   Location loc = getLoc();
157 
158   // Nothing to do for dead AllocTensorOps.
159   if (getOperation()->getUses().empty()) {
160     rewriter.eraseOp(getOperation());
161     return success();
162   }
163 
164   // Get "copy" buffer.
165   Value copyBuffer;
166   if (getCopy()) {
167     FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
168     if (failed(maybeCopyBuffer))
169       return failure();
170     copyBuffer = *maybeCopyBuffer;
171   }
172 
173   // Create memory allocation.
174   auto allocType = bufferization::getBufferType(getResult(), options);
175   if (failed(allocType))
176     return failure();
177   SmallVector<Value> dynamicDims = getDynamicSizes();
178   if (getCopy()) {
179     assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
180     populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
181   }
182   FailureOr<Value> alloc = options.createAlloc(
183       rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
184   if (failed(alloc))
185     return failure();
186 
187   // Create memory copy (if any).
188   if (getCopy()) {
189     if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
190       return failure();
191   }
192 
193   // Replace op.
194   replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
195 
196   return success();
197 }
198 
199 bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
200                                                   const AnalysisState &state) {
201   // AllocTensorOps do not write unless they have a `copy` value.
202   return static_cast<bool>(getCopy());
203 }
204 
205 bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
206                                            const AnalysisState &state) {
207   assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
208          "expected copy operand");
209   return true;
210 }
211 
212 bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
213                                             const AnalysisState &state) {
214   assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
215          "expected copy operand");
216   return false;
217 }
218 
219 AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
220                                                    const AnalysisState &state) {
221   // This is a new allocation. It does not alias with any other buffer.
222   return {};
223 }
224 
225 FailureOr<BaseMemRefType>
226 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
227                              SmallVector<Value> &invocationStack) {
228   assert(value == getResult() && "invalid value");
229 
230   // Compute memory space of this allocation.
231   Attribute memorySpace;
232   if (getMemorySpace().has_value()) {
233     memorySpace = *getMemorySpace();
234   } else if (getCopy()) {
235     auto copyBufferType =
236         bufferization::getBufferType(getCopy(), options, invocationStack);
237     if (failed(copyBufferType))
238       return failure();
239     memorySpace = copyBufferType->getMemorySpace();
240   } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
241     memorySpace = *ms;
242   } else {
243     return getOperation()->emitError("could not infer memory space");
244   }
245 
246   return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
247 }
248 
249 LogicalResult AllocTensorOp::verify() {
250   if (getCopy() && !getDynamicSizes().empty())
251     return emitError("dynamic sizes not needed when copying a tensor");
252   if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
253     return emitError("expected ")
254            << getType().getNumDynamicDims() << " dynamic sizes";
255   if (getCopy() && getCopy().getType() != getType())
256     return emitError("expected that `copy` and return type match");
257   return success();
258 }
259 
260 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
261                           RankedTensorType type, ValueRange dynamicSizes) {
262   build(builder, result, type, dynamicSizes, /*copy=*/Value(),
263         /*size_hint=*/Value(),
264         /*memory_space=*/IntegerAttr());
265 }
266 
267 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
268                           RankedTensorType type, ValueRange dynamicSizes,
269                           Value copy) {
270   build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
271         /*memory_space=*/IntegerAttr());
272 }
273 
274 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
275                           TensorType type, ValueRange dynamicSizes, Value copy,
276                           IntegerAttr memorySpace) {
277   build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
278         memorySpace);
279 }
280 
281 namespace {
282 /// Change the type of the result of a `bufferization.alloc_tensor` by making
283 /// the result type statically sized along dimensions that were defined as
284 /// dynamic in the original operation, but whose sizes were given by a
285 /// `constant` op. For example:
286 ///
287 ///  %c5 = arith.constant 5: index
288 ///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
289 ///
290 ///  to
291 ///
292 ///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
293 struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
294   using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
295 
296   LogicalResult matchAndRewrite(AllocTensorOp op,
297                                 PatternRewriter &rewriter) const override {
298     if (op.getCopy())
299       return failure();
300     SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
301     SmallVector<Value> newDynamicSizes;
302     unsigned int dynValCounter = 0;
303     for (int64_t i = 0; i < op.getType().getRank(); ++i) {
304       if (!op.isDynamicDim(i))
305         continue;
306       Value value = op.getDynamicSizes()[dynValCounter++];
307       APInt intVal;
308       if (matchPattern(value, m_ConstantInt(&intVal))) {
309         int64_t dim = intVal.getSExtValue();
310         if (dim >= 0)
311           newShape[i] = dim;
312         else
313           newDynamicSizes.push_back(value);
314       } else {
315         newDynamicSizes.push_back(value);
316       }
317     }
318     RankedTensorType newType = RankedTensorType::get(
319         newShape, op.getType().getElementType(), op.getType().getEncoding());
320     if (newType == op.getType())
321       return failure();
322     auto newOp = rewriter.create<AllocTensorOp>(
323         op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
324     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
325     return success();
326   }
327 };
328 
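/// Fold `tensor.dim` of an `alloc_tensor` with a constant dimension index to
/// the size the op provides for that dimension (its matching dynamic size
/// operand, or a `tensor.dim` of its `copy` operand).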
329 struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
330   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
331 
332   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
333                                 PatternRewriter &rewriter) const override {
334     std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
335     auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
336     if (!allocTensorOp || !maybeConstantIndex)
337       return failure();
338     if (*maybeConstantIndex < 0 ||
339         *maybeConstantIndex >= allocTensorOp.getType().getRank())
340       return failure();
341     if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
342       return failure();
343     rewriter.replaceOp(
344         dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
345     return success();
346   }
347 };
348 } // namespace
349 
350 void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
351                                                 MLIRContext *ctx) {
352   results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
353 }
354 
355 LogicalResult AllocTensorOp::reifyResultShapes(
356     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
357   auto shapes = llvm::to_vector<4>(
358       llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
359                       [&](int64_t dim) -> OpFoldResult {
360                         if (isDynamicDim(dim))
361                           return getDynamicSize(builder, dim);
362                         return builder.getIndexAttr(getStaticSize(dim));
363                       }));
364   reifiedReturnShapes.emplace_back(std::move(shapes));
365   return success();
366 }
367 
368 ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
369   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
370   if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
371       parser.parseRParen())
372     return failure();
373   ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
374   OpAsmParser::UnresolvedOperand copyOperand;
375   if (copyKeyword.succeeded())
376     if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
377         parser.parseRParen())
378       return failure();
379   ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
380   OpAsmParser::UnresolvedOperand sizeHintOperand;
381   if (sizeHintKeyword.succeeded())
382     if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
383       return failure();
384   if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
385     return failure();
386 
387   TensorType type;
388   if (parser.parseCustomTypeWithFallback(type))
389     return failure();
390   result.addTypes(type);
391 
392   Type indexType = parser.getBuilder().getIndexType();
393   if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
394     return failure();
395   if (copyKeyword.succeeded())
396     if (parser.resolveOperand(copyOperand, type, result.operands))
397       return failure();
398   if (sizeHintKeyword.succeeded())
399     if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
400       return failure();
401   result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
402                       parser.getBuilder().getDenseI32ArrayAttr(
403                           {static_cast<int32_t>(dynamicSizesOperands.size()),
404                            static_cast<int32_t>(copyKeyword.succeeded()),
405                            static_cast<int32_t>(sizeHintKeyword.succeeded())}));
406   return success();
407 }
408 
409 void AllocTensorOp::print(OpAsmPrinter &p) {
410   p << "(" << getDynamicSizes() << ")";
411   if (getCopy())
412     p << " copy(" << getCopy() << ")";
413   if (getSizeHint())
414     p << " size_hint=" << getSizeHint();
415   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
416                               AllocTensorOp::getOperandSegmentSizeAttr()});
417   p << " : ";
418   auto type = getResult().getType();
419   if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
420     p.printStrippedAttrOrType(validType);
421   else
422     p << type;
423 }
424 
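/// Return the dynamic size of dimension `idx`, either as the matching dynamic
/// size operand or, if a `copy` operand is present, as a newly created
/// `tensor.dim` of that operand.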
425 Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
426   assert(isDynamicDim(idx) && "expected dynamic dim");
427   if (getCopy())
428     return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
429   return getOperand(getIndexOfDynamicSize(idx));
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // CloneOp
434 //===----------------------------------------------------------------------===//
435 
436 OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
437   return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
438 }
439 
440 namespace {
441 
442 /// Merge the clone and its source (by converting the clone to a cast) when
443 /// possible.
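/// Example (schematic; %arg0 has no dealloc of its own):
/// ```mlir
/// %1 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
/// ... uses of %1 ...
/// memref.dealloc %1 : memref<2xf32>
/// ```
/// is simplified to direct uses of %arg0 (via a memref.cast if the types
/// differ), and the now-redundant dealloc is erased.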
444 struct SimplifyClones : public OpRewritePattern<CloneOp> {
445   using OpRewritePattern<CloneOp>::OpRewritePattern;
446 
447   LogicalResult matchAndRewrite(CloneOp cloneOp,
448                                 PatternRewriter &rewriter) const override {
449     if (cloneOp.use_empty()) {
450       rewriter.eraseOp(cloneOp);
451       return success();
452     }
453 
454     Value source = cloneOp.getInput();
455     if (source.getType() != cloneOp.getType() &&
456         !memref::CastOp::areCastCompatible({source.getType()},
457                                            {cloneOp.getType()}))
458       return failure();
459 
460     // Find the dealloc op for the canonical source (looking through view-like
461     // ops); otherwise, removal of unnecessary allocs could be prevented.
462     Value canonicalSource = source;
463     while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
464                canonicalSource.getDefiningOp()))
465       canonicalSource = iface.getViewSource();
466 
467     std::optional<Operation *> maybeCloneDeallocOp =
468         memref::findDealloc(cloneOp.getOutput());
469     // Skip if either of them has more than one deallocate operation.
470     if (!maybeCloneDeallocOp.has_value())
471       return failure();
472     std::optional<Operation *> maybeSourceDeallocOp =
473         memref::findDealloc(canonicalSource);
474     if (!maybeSourceDeallocOp.has_value())
475       return failure();
476     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
477     Operation *sourceDeallocOp = *maybeSourceDeallocOp;
478 
479     // If both are deallocated in the same block, their in-block lifetimes
480     // might not fully overlap, so we cannot decide which one to drop.
481     if (cloneDeallocOp && sourceDeallocOp &&
482         cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
483       return failure();
484 
485     Block *currentBlock = cloneOp->getBlock();
486     Operation *redundantDealloc = nullptr;
487     if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
488       redundantDealloc = cloneDeallocOp;
489     } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
490       redundantDealloc = sourceDeallocOp;
491     }
492 
493     if (!redundantDealloc)
494       return failure();
495 
496     // Safety check that there are no other deallocations in between cloneOp
497     // and redundantDealloc, as otherwise we might deallocate an alias of the
498     // source before the uses of the clone. With alias information, we could
499     // restrict this to only fail if the dealloc's operand is an alias of the
500     // source.
501     for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
502          pos = pos->getNextNode()) {
503       // Bail if we run out of operations while looking for a deallocation op.
504       if (!pos)
505         return failure();
506       auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
507       if (!effectInterface)
508         continue;
509       if (effectInterface.hasEffect<MemoryEffects::Free>())
510         return failure();
511     }
512 
513     if (source.getType() != cloneOp.getType())
514       source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
515                                                cloneOp.getType(), source);
516     rewriter.replaceOp(cloneOp, source);
517     rewriter.eraseOp(redundantDealloc);
518     return success();
519   }
520 };
521 
522 } // namespace
523 
524 void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
525                                           MLIRContext *context) {
526   results.add<SimplifyClones>(context);
527 }
528 
529 //===----------------------------------------------------------------------===//
530 // DeallocTensorOp
531 //===----------------------------------------------------------------------===//
532 
533 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
534                                          const BufferizationOptions &options) {
535   FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
536   if (failed(buffer))
537     return failure();
538   rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
539   rewriter.eraseOp(getOperation());
540   return success();
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // MaterializeInDestinationOp
545 //===----------------------------------------------------------------------===//
546 
547 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
548     OpOperand &opOperand, const AnalysisState &state) {
549   return opOperand == getSourceMutable();
550 }
551 
552 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
553     OpOperand &opOperand, const AnalysisState &state) {
554   if (opOperand == getDestMutable()) {
555     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
556     return true;
557   }
558   return false;
559 }
560 
561 bool MaterializeInDestinationOp::mustBufferizeInPlace(
562     OpOperand &opOperand, const AnalysisState &state) {
563   // The source is only read and not written, so it always bufferizes in-place
564   // by default. The destination is written and is forced to bufferize in-place
565   // (if it is a tensor).
566   return true;
567 }
568 
569 AliasingValueList
570 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
571                                               const AnalysisState &state) {
572   if (opOperand == getDestMutable()) {
573     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
574     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
575   }
576   return {};
577 }
578 
579 LogicalResult
580 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
581                                       const BufferizationOptions &options) {
582   bool tensorDest = isa<TensorType>(getDest().getType());
583   Value buffer;
584   if (tensorDest) {
585     FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
586     if (failed(maybeBuffer))
587       return failure();
588     buffer = *maybeBuffer;
589   } else {
590     assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
591     buffer = getDest();
592   }
593   auto srcBuffer = getBuffer(rewriter, getSource(), options);
594   if (failed(srcBuffer))
595     return failure();
596   if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
597     return failure();
598   replaceOpWithBufferizedValues(rewriter, getOperation(),
599                                 tensorDest ? ValueRange(buffer) : ValueRange());
600   return success();
601 }
602 
603 bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
604     const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
605   // As elements are copied from the "source" buffer to the "dest" buffer,
606   // already copied elements are not read a second time.
607   return true;
608 }
609 
610 LogicalResult MaterializeInDestinationOp::reifyResultShapes(
611     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
612   if (getOperation()->getNumResults() == 1) {
613     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
614     reifiedReturnShapes.resize(1,
615                                SmallVector<OpFoldResult>(getType().getRank()));
616     reifiedReturnShapes[0] =
617         tensor::getMixedSizes(builder, getLoc(), getDest());
618   }
619   return success();
620 }
621 
622 Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
623                                                         Location loc) {
624   if (isa<TensorType>(getDest().getType())) {
625     // The subset is the entire destination tensor.
626     return getDest();
627   }
628 
629   // The "restrict" attribute is transferred from this op to the newly created
630   // to_tensor op. If this op does not have the "restrict" attribute, the subset
631   // extraction cannot be built because there is no guarantee that there is no
632   // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
633   if (!getRestrict())
634     return {};
635 
636   // Build a bufferization.to_tensor op.
637   assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
638   assert(getRestrict() &&
639          "expected that ops with memrefs dest have 'restrict'");
640   setRestrict(false);
641   return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
642                                     getWritable());
643 }
644 
645 bool MaterializeInDestinationOp::isEquivalentSubset(
646     Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
647   return equivalenceFn(getDest(), candidate);
648 }
649 
650 SmallVector<Value>
651 MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
652   return {getDest()};
653 }
654 
655 OpOperand &MaterializeInDestinationOp::getSourceOperand() {
656   return getOperation()->getOpOperand(0) /*source*/;
657 }
658 
659 bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
660     SubsetOpInterface subsetOp,
661     function_ref<bool(Value, Value)> equivalenceFn) {
662   return false;
663 }
664 
665 bool MaterializeInDestinationOp::operatesOnDisjointSubset(
666     SubsetOpInterface subsetOp,
667     function_ref<bool(Value, Value)> equivalenceFn) {
668   return false;
669 }
670 
671 LogicalResult MaterializeInDestinationOp::verify() {
672   if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
673     return emitOpError("'dest' must be a tensor or a memref");
674   if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
675     if (getOperation()->getNumResults() != 1)
676       return emitOpError("tensor 'dest' implies exactly one tensor result");
677     if (destType != getResult().getType())
678       return emitOpError("result and 'dest' types must match");
679   }
680   if (isa<BaseMemRefType>(getDest().getType()) &&
681       getOperation()->getNumResults() != 0)
682     return emitOpError("memref 'dest' implies zero results");
683   if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
684     return emitOpError("'restrict' is valid only for memref destinations");
685   if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
686     return emitOpError("'writable' must be specified if and only if the "
687                        "destination is of memref type");
688   TensorType srcType = getSource().getType();
689   ShapedType destType = cast<ShapedType>(getDest().getType());
690   if (srcType.hasRank() != destType.hasRank())
691     return emitOpError("source/destination shapes are incompatible");
692   if (srcType.hasRank()) {
693     if (srcType.getRank() != destType.getRank())
694       return emitOpError("rank mismatch between source and destination shape");
695     for (auto [src, dest] :
696          llvm::zip(srcType.getShape(), destType.getShape())) {
697       if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
698         // Cannot verify dynamic dimension sizes. Assume that they match at
699         // runtime.
700         continue;
701       }
702       if (src != dest)
703         return emitOpError("source/destination shapes are incompatible");
704     }
705   }
706   return success();
707 }
708 
709 void MaterializeInDestinationOp::build(OpBuilder &builder,
710                                        OperationState &state, Value source,
711                                        Value dest) {
712   auto destTensorType = dyn_cast<TensorType>(dest.getType());
713   build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
714         source, dest);
715 }
716 
717 bool MaterializeInDestinationOp::isWritable(Value value,
718                                             const AnalysisState &state) {
719   return isa<TensorType>(getDest().getType()) ? true : getWritable();
720 }
721 
722 MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
723   return getDestMutable();
724 }
725 
726 void MaterializeInDestinationOp::getEffects(
727     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
728         &effects) {
729   if (isa<BaseMemRefType>(getDest().getType()))
730     effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
731                          SideEffects::DefaultResource::get());
732 }
733 
734 //===----------------------------------------------------------------------===//
735 // ToTensorOp
736 //===----------------------------------------------------------------------===//
737 
738 bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
739   return getWritable();
740 }
741 
742 OpFoldResult ToTensorOp::fold(FoldAdaptor) {
743   if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
744     // Approximate alias analysis by conservatively folding only when there is
745     // no interleaved operation.
746     if (toMemref->getBlock() == this->getOperation()->getBlock() &&
747         toMemref->getNextNode() == this->getOperation())
748       return toMemref.getTensor();
749   return {};
750 }
751 
752 namespace {
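/// Fold `tensor.dim` of a `to_tensor` op into `memref.dim` of the wrapped
/// memref.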
753 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
754   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
755 
756   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
757                                 PatternRewriter &rewriter) const override {
758     auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
759     if (!memrefToTensorOp)
760       return failure();
761 
762     rewriter.replaceOpWithNewOp<memref::DimOp>(
763         dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
764     return success();
765   }
766 };
767 } // namespace
768 
769 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
770                                              MLIRContext *context) {
771   results.add<DimOfToTensorFolder>(context);
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // ToMemrefOp
776 //===----------------------------------------------------------------------===//
777 
778 OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
779   if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
780     if (memrefToTensor.getMemref().getType() == getType())
781       return memrefToTensor.getMemref();
782   return {};
783 }
784 
785 namespace {
786 
787 /// Replace tensor.cast + to_memref by to_memref + memref.cast.
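/// Example (schematic; to_memref types elided):
/// ```mlir
/// %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
/// %1 = bufferization.to_memref %0 : ...
/// ```
/// is rewritten to
/// ```mlir
/// %0 = bufferization.to_memref %t : ...
/// %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```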
788 struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
789   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
790 
791   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
792                                 PatternRewriter &rewriter) const final {
793     auto tensorCastOperand =
794         toMemref.getOperand().getDefiningOp<tensor::CastOp>();
795     if (!tensorCastOperand)
796       return failure();
797     auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
798         tensorCastOperand.getOperand().getType());
799     if (!srcTensorType)
800       return failure();
801     auto memrefType = MemRefType::get(srcTensorType.getShape(),
802                                       srcTensorType.getElementType());
803     Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
804                                                tensorCastOperand.getOperand());
805     rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
806                                                 memref);
807     return success();
808   }
809 };
810 
811 /// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
812 /// cast if necessary.
813 struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
814   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
815 
816   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
817                                 PatternRewriter &rewriter) const final {
818     BufferizationOptions options;
819     options.bufferAlignment = 0;
820     return foldToMemrefToTensorPair(rewriter, toMemref, options);
821   }
822 };
823 
824 /// Fold a load on a to_memref operation into a tensor.extract on the
825 /// corresponding tensor.
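/// Example (schematic; to_memref type elided):
/// ```mlir
/// %m = bufferization.to_memref %t : ...
/// %v = memref.load %m[%i] : memref<?xf32>
/// ```
/// is rewritten to
/// ```mlir
/// %v = tensor.extract %t[%i] : tensor<?xf32>
/// ```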
826 struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
827   using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
828 
829   LogicalResult matchAndRewrite(memref::LoadOp load,
830                                 PatternRewriter &rewriter) const override {
831     auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
832     if (!toMemref)
833       return failure();
834 
835     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
836                                                    load.getIndices());
837     return success();
838   }
839 };
840 
841 /// Fold dim of a to_memref into the dim of the tensor.
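/// Example (schematic; to_memref type elided):
/// ```mlir
/// %m = bufferization.to_memref %t : ...
/// %d = memref.dim %m, %c0 : memref<?xf32>
/// ```
/// is rewritten to
/// ```mlir
/// %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```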
842 struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
843   using OpRewritePattern<memref::DimOp>::OpRewritePattern;
844 
845   LogicalResult matchAndRewrite(memref::DimOp dimOp,
846                                 PatternRewriter &rewriter) const override {
847     auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
848     if (!castOp)
849       return failure();
850     Value newSource = castOp.getOperand();
851     rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
852                                                dimOp.getIndex());
853     return success();
854   }
855 };
856 
857 } // namespace
858 
859 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
860                                              MLIRContext *context) {
861   results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
862               ToMemrefToTensorFolding>(context);
863 }
864 
865 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
866                                     const BufferizationOptions &options) {
867   // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
868   (void)foldToMemrefToTensorPair(rewriter, *this, options);
869   // Note: The return value of `bufferize` indicates whether there was an error
870   // or not. (And not whether the pattern matched or not.)
871   return success();
872 }
873 
874 std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
875                                                  Value alloc) {
876   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
877       .getOperation();
878 }
879 
880 std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
881   return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
882 }
883 
884 //===----------------------------------------------------------------------===//
885 // DeallocOp
886 //===----------------------------------------------------------------------===//
887 
888 LogicalResult DeallocOp::inferReturnTypes(
889     MLIRContext *context, std::optional<::mlir::Location> location,
890     ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
891     RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
892   DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
893   inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
894                                           IntegerType::get(context, 1));
895   return success();
896 }
897 
898 LogicalResult DeallocOp::verify() {
899   if (getMemrefs().size() != getConditions().size())
900     return emitOpError(
901         "must have the same number of conditions as memrefs to deallocate");
902   if (getRetained().size() != getUpdatedConditions().size())
903     return emitOpError("must have the same number of updated conditions "
904                        "(results) as retained operands");
905   return success();
906 }
907 
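/// Update the memref and condition operands of `deallocOp` in place. Returns
/// failure if the new operands are identical to the current ones, so that
/// callers do not keep re-applying patterns indefinitely.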
908 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
909                                             ValueRange memrefs,
910                                             ValueRange conditions,
911                                             PatternRewriter &rewriter) {
912   if (deallocOp.getMemrefs() == memrefs &&
913       deallocOp.getConditions() == conditions)
914     return failure();
915 
916   rewriter.modifyOpInPlace(deallocOp, [&]() {
917     deallocOp.getMemrefsMutable().assign(memrefs);
918     deallocOp.getConditionsMutable().assign(conditions);
919   });
920   return success();
921 }
922 
923 namespace {
924 
925 /// Remove duplicate values in the list of memrefs to be deallocated. We need to
926 /// make sure the corresponding condition value is updated accordingly since
927 /// the two conditions might not cover the same set of cases. In that case, we
928 /// have to combine them (by computing their disjunction).
929 /// Example:
930 /// ```mlir
931 /// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
932 /// ```
933 /// is canonicalized to
934 /// ```mlir
935 /// %0 = arith.ori %arg1, %arg2 : i1
936 /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
937 /// ```
938 struct DeallocRemoveDuplicateDeallocMemrefs
939     : public OpRewritePattern<DeallocOp> {
940   using OpRewritePattern<DeallocOp>::OpRewritePattern;
941 
942   LogicalResult matchAndRewrite(DeallocOp deallocOp,
943                                 PatternRewriter &rewriter) const override {
944     // Unique memrefs to be deallocated.
945     DenseMap<Value, unsigned> memrefToCondition;
946     SmallVector<Value> newMemrefs, newConditions;
947     for (auto [i, memref, cond] :
948          llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
949       if (memrefToCondition.count(memref)) {
950         // If the dealloc conditions don't match, we need to make sure that the
951         // dealloc happens on the union of cases.
952         Value &newCond = newConditions[memrefToCondition[memref]];
953         if (newCond != cond)
954           newCond =
955               rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
956       } else {
957         memrefToCondition.insert({memref, newConditions.size()});
958         newMemrefs.push_back(memref);
959         newConditions.push_back(cond);
960       }
961     }
962 
963     // Return failure if we don't change anything, so that we don't run into an
964     // infinite loop of pattern applications.
965     return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
966                                   rewriter);
967   }
968 };
969 
970 /// Remove duplicate values in the list of retained memrefs. We need to make
971 /// sure the corresponding result condition value is replaced properly.
972 /// Example:
973 /// ```mlir
974 /// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
975 /// ```
976 /// is canonicalized to
977 /// ```mlir
978 /// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
979 /// ```
980 struct DeallocRemoveDuplicateRetainedMemrefs
981     : public OpRewritePattern<DeallocOp> {
982   using OpRewritePattern<DeallocOp>::OpRewritePattern;
983 
984   LogicalResult matchAndRewrite(DeallocOp deallocOp,
985                                 PatternRewriter &rewriter) const override {
986     // Unique retained values
987     DenseMap<Value, unsigned> seen;
988     SmallVector<Value> newRetained;
989     SmallVector<unsigned> resultReplacementIdx;
990     unsigned i = 0;
991     for (auto retained : deallocOp.getRetained()) {
992       if (seen.count(retained)) {
993         resultReplacementIdx.push_back(seen[retained]);
994         continue;
995       }
996 
997       seen[retained] = i;
998       newRetained.push_back(retained);
999       resultReplacementIdx.push_back(i++);
1000     }
1001 
1002     // Return failure if we don't change anything, so that we don't run into an
1003     // infinite loop of pattern applications.
1004     if (newRetained.size() == deallocOp.getRetained().size())
1005       return failure();
1006 
1007     // We need to create a new op because the number of results is always the
1008     // same as the number of retained operands.
1009     auto newDeallocOp =
1010         rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
1011                                    deallocOp.getConditions(), newRetained);
1012     SmallVector<Value> replacements(
1013         llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
1014           return newDeallocOp.getUpdatedConditions()[idx];
1015         }));
1016     rewriter.replaceOp(deallocOp, replacements);
1017     return success();
1018   }
1019 };
1020 
1021 /// Erase deallocation operations where the variadic list of memrefs to
1022 /// deallocate is empty. Example:
1023 /// ```mlir
1024 /// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
1025 /// ```
1026 struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
1027   using OpRewritePattern<DeallocOp>::OpRewritePattern;
1028 
1029   LogicalResult matchAndRewrite(DeallocOp deallocOp,
1030                                 PatternRewriter &rewriter) const override {
1031     if (deallocOp.getMemrefs().empty()) {
1032       Value constFalse = rewriter.create<arith::ConstantOp>(
1033           deallocOp.getLoc(), rewriter.getBoolAttr(false));
1034       rewriter.replaceOp(
1035           deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
1036                                         constFalse));
1037       return success();
1038     }
1039     return failure();
1040   }
1041 };
1042 
1043 /// Removes memrefs from the deallocation list if their associated condition is
1044 /// always 'false'.
1045 ///
1046 /// Example:
1047 /// ```
1048 /// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
1049 ///                           if (%arg2, %false)
1050 /// ```
1051 /// becomes
1052 /// ```
1053 /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
1054 /// ```
1055 struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
1056   using OpRewritePattern<DeallocOp>::OpRewritePattern;
1057 
1058   LogicalResult matchAndRewrite(DeallocOp deallocOp,
1059                                 PatternRewriter &rewriter) const override {
1060     SmallVector<Value> newMemrefs, newConditions;
1061     for (auto [memref, cond] :
1062          llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1063       if (!matchPattern(cond, m_Zero())) {
1064         newMemrefs.push_back(memref);
1065         newConditions.push_back(cond);
1066       }
1067     }
1068 
1069     return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1070                                   rewriter);
1071   }
1072 };
1073 
1074 /// The `memref.extract_strided_metadata` op is often inserted to get the base
1075 /// memref if the operand is not already guaranteed to be the result of a memref
1076 /// allocation operation. This canonicalization pattern removes this extraction
1077 /// operation if the operand is now produced by an allocation operation (e.g.,
1078 /// due to other canonicalizations simplifying the IR).
1079 ///
1080 /// Example:
1081 /// ```mlir
1082 /// %alloc = memref.alloc() : memref<2xi32>
1083 /// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
1084 ///   %alloc : memref<2xi32> -> memref<i32>, index, index, index
1085 /// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
1086 /// ```
1087 /// is canonicalized to
1088 /// ```mlir
1089 /// %alloc = memref.alloc() : memref<2xi32>
1090 /// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
1091 /// ```
1092 struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
1093   using OpRewritePattern<DeallocOp>::OpRewritePattern;
1094 
1095   LogicalResult matchAndRewrite(DeallocOp deallocOp,
1096                                 PatternRewriter &rewriter) const override {
1097     SmallVector<Value> newMemrefs(
1098         llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
1099           auto extractStridedOp =
1100               memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
1101           if (!extractStridedOp)
1102             return memref;
1103           Value allocMemref = extractStridedOp.getOperand();
1104           auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
1105           if (!allocOp)
1106             return memref;
1107           if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
1108             return allocMemref;
1109           return memref;
1110         }));
1111 
1112     return updateDeallocIfChanged(deallocOp, newMemrefs,
1113                                   deallocOp.getConditions(), rewriter);
1114   }
1115 };
1116 
1117 /// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
1118 /// other user of the allocated value and the allocating operation can be safely
1119 /// removed. If the same value is present multiple times, this pattern relies on
1120 /// other canonicalization patterns to remove the duplicate first.
1121 ///
1122 /// Example:
1123 /// ```mlir
1124 /// %alloc = memref.alloc() : memref<2xi32>
1125 /// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
1126 /// ```
1127 /// is canonicalized to
1128 /// ```mlir
1129 /// bufferization.dealloc (%arg0 : ...) if (%true)
1130 /// ```
1131 struct RemoveAllocDeallocPairWhenNoOtherUsers
1132     : public OpRewritePattern<DeallocOp> {
1133   using OpRewritePattern<DeallocOp>::OpRewritePattern;
1134 
1135   LogicalResult matchAndRewrite(DeallocOp deallocOp,
1136                                 PatternRewriter &rewriter) const override {
1137     SmallVector<Value> newMemrefs, newConditions;
1138     SmallVector<Operation *> toDelete;
1139     for (auto [memref, cond] :
1140          llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1141       if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
1142         // Check that it is indeed an allocate effect, that the op has no other
1143         // side effects (which would not allow us to remove the op), and that
1144         // there are no other users.
1145         if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
1146             hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
1147             memref.hasOneUse()) {
1148           toDelete.push_back(allocOp);
1149           continue;
1150         }
1151       }
1152 
1153       newMemrefs.push_back(memref);
1154       newConditions.push_back(cond);
1155     }
1156 
1157     if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1158                                       rewriter)))
1159       return failure();
1160 
1161     for (Operation *op : toDelete)
1162       rewriter.eraseOp(op);
1163 
1164     return success();
1165   }
1166 };
1167 
1168 } // anonymous namespace
1169 
1170 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
1171                                             MLIRContext *context) {
1172   populateDeallocOpCanonicalizationPatterns(results, context);
1173 }
1174 
1175 void bufferization::populateDeallocOpCanonicalizationPatterns(
1176     RewritePatternSet &patterns, MLIRContext *context) {
1177   patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
1178                DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1179                EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1180                RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // TableGen'd op method definitions
1185 //===----------------------------------------------------------------------===//
1186 
1187 #define GET_OP_CLASSES
1188 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
1189