xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp (revision 49df12c01e99af6e091fedc123f775580064740a)
1 //===- LowerDeallocations.cpp - Bufferization Deallocs to MemRef pass -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert `bufferization.dealloc` operations
10 // to the MemRef dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 
24 namespace mlir {
25 namespace bufferization {
26 #define GEN_PASS_DEF_LOWERDEALLOCATIONS
27 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
28 } // namespace bufferization
29 } // namespace mlir
30 
31 using namespace mlir;
32 
33 namespace {
34 /// The DeallocOpConversion transforms all bufferization dealloc operations into
35 /// memref dealloc operations potentially guarded by scf if operations.
36 /// Additionally, memref extract_aligned_pointer_as_index and arith operations
37 /// are inserted to compute the guard conditions. We distinguish multiple cases
38 /// to provide an overall more efficient lowering. In the general case, a helper
39 /// func is created to avoid quadratic code size explosion (relative to the
40 /// number of operands of the dealloc operation). For examples of each case,
41 /// refer to the documentation of the member functions of this class.
42 class DeallocOpConversion
43     : public OpConversionPattern<bufferization::DeallocOp> {
44 
45   /// Lower a simple case without any retained values and a single memref to
46   /// avoiding the helper function. Ideally, static analysis can provide enough
47   /// aliasing information to split the dealloc operations up into this simple
48   /// case as much as possible before running this pass.
49   ///
50   /// Example:
51   /// ```
52   /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
53   /// ```
54   /// is lowered to
55   /// ```
56   /// scf.if %arg1 {
57   ///   memref.dealloc %arg0 : memref<2xf32>
58   /// }
59   /// ```
60   LogicalResult
61   rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
62                                ConversionPatternRewriter &rewriter) const {
63     assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
64     assert(adaptor.getRetained().empty() && "expected no retained memrefs");
65 
66     rewriter.replaceOpWithNewOp<scf::IfOp>(
67         op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) {
68           builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
69           builder.create<scf::YieldOp>(loc);
70         });
71     return success();
72   }
73 
74   /// A special case lowering for the deallocation operation with exactly one
75   /// memref, but arbitrary number of retained values. This avoids the helper
76   /// function that the general case needs and thus also avoids storing indices
77   /// to specifically allocated memrefs. The size of the code produced by this
78   /// lowering is linear to the number of retained values.
79   ///
80   /// Example:
81   /// ```mlir
82   /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
83   //                        retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
84   /// return %0#0, %0#1 : i1, i1
85   /// ```
86   /// ```mlir
87   /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m
88   /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
89   /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
90   /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
91   /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
92   /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
93   /// %should_dealloc = arith.andi %not_retained, %cond : i1
94   /// scf.if %should_dealloc {
95   ///   memref.dealloc %m : memref<2xf32>
96   /// }
97   /// %true = arith.constant true
98   /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
99   /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1
100   /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
101   /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1
102   /// return %r0_ownership, %r1_ownership : i1, i1
103   /// ```
104   LogicalResult rewriteOneMemrefMultipleRetainCase(
105       bufferization::DeallocOp op, OpAdaptor adaptor,
106       ConversionPatternRewriter &rewriter) const {
107     assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
108 
109     // Compute the base pointer indices, compare all retained indices to the
110     // memref index to check if they alias.
111     SmallVector<Value> doesNotAliasList;
112     Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
113         op->getLoc(), adaptor.getMemrefs()[0]);
114     for (Value retained : adaptor.getRetained()) {
115       Value retainedAsIdx =
116           rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
117                                                                   retained);
118       Value doesNotAlias = rewriter.create<arith::CmpIOp>(
119           op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
120       doesNotAliasList.push_back(doesNotAlias);
121     }
122 
123     // AND-reduce the list of booleans from above.
124     Value prev = doesNotAliasList.front();
125     for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front())
126       prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);
127 
128     // Also consider the condition given by the dealloc operation and perform a
129     // conditional deallocation guarded by that value.
130     Value shouldDealloc = rewriter.create<arith::AndIOp>(
131         op->getLoc(), prev, adaptor.getConditions()[0]);
132 
133     rewriter.create<scf::IfOp>(
134         op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
135           builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
136           builder.create<scf::YieldOp>(loc);
137         });
138 
139     // Compute the replacement values for the dealloc operation results. This
140     // inserts an already canonicalized form of
141     // `select(does_alias_with_memref(r), memref_cond, false)` for each retained
142     // value r.
143     SmallVector<Value> replacements;
144     Value trueVal = rewriter.create<arith::ConstantOp>(
145         op->getLoc(), rewriter.getBoolAttr(true));
146     for (Value doesNotAlias : doesNotAliasList) {
147       Value aliases =
148           rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
149       Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases,
150                                                     adaptor.getConditions()[0]);
151       replacements.push_back(result);
152     }
153 
154     rewriter.replaceOp(op, replacements);
155 
156     return success();
157   }
158 
159   /// Lowering that supports all features the dealloc operation has to offer. It
160   /// computes the base pointer of each memref (as an index), stores it in a
161   /// new memref helper structure and passes it to the helper function generated
162   /// in 'buildDeallocationHelperFunction'. The results are stored in two lists
163   /// (represented as memrefs) of booleans passed as arguments. The first list
164   /// stores whether the corresponding condition should be deallocated, the
165   /// second list stores the ownership of the retained values which can be used
166   /// to replace the result values of the `bufferization.dealloc` operation.
167   ///
168   /// Example:
169   /// ```
170   /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>)
171   ///                           if (%cond0, %cond1)
172   ///                       retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
173   /// ```
174   /// lowers to (simplified):
175   /// ```
176   /// %c0 = arith.constant 0 : index
177   /// %c1 = arith.constant 1 : index
178   /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex>
179   /// %cond_list = memref.alloc() : memref<2xi1>
180   /// %retain_base_pointer_list = memref.alloc() : memref<2xindex>
181   /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0
182   /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0]
183   /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1
184   /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1]
185   /// memref.store %cond0, %cond_list[%c0]
186   /// memref.store %cond1, %cond_list[%c1]
187   /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
188   /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0]
189   /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
190   /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1]
191   /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list :
192   ///    memref<2xindex> to memref<?xindex>
193   /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1>
194   /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list :
195   ///    memref<2xindex> to memref<?xindex>
196   /// %dealloc_cond_out = memref.alloc() : memref<2xi1>
197   /// %ownership_out = memref.alloc() : memref<2xi1>
198   /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out :
199   ///    memref<2xi1> to memref<?xi1>
200   /// %dyn_ownership_out = memref.cast %ownership_out :
201   ///    memref<2xi1> to memref<?xi1>
202   /// call @dealloc_helper(%dyn_dealloc_base_pointer_list,
203   ///                      %dyn_retain_base_pointer_list,
204   ///                      %dyn_cond_list,
205   ///                      %dyn_dealloc_cond_out,
206   ///                      %dyn_ownership_out) : (...)
207   /// %m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1>
208   /// scf.if %m0_dealloc_cond {
209   ///   memref.dealloc %m0 : memref<2xf32>
210   /// }
211   /// %m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1>
212   /// scf.if %m1_dealloc_cond {
213   ///   memref.dealloc %m1 : memref<5xf32>
214   /// }
215   /// %r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1>
216   /// %r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1>
217   /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex>
218   /// memref.dealloc %retain_base_pointer_list : memref<2xindex>
219   /// memref.dealloc %cond_list : memref<2xi1>
220   /// memref.dealloc %dealloc_cond_out : memref<2xi1>
221   /// memref.dealloc %ownership_out : memref<2xi1>
222   /// // replace %0#0 with %r0_ownership
223   /// // replace %0#1 with %r1_ownership
224   /// ```
225   LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
226                                    OpAdaptor adaptor,
227                                    ConversionPatternRewriter &rewriter) const {
228     // Allocate two memrefs holding the base pointer indices of the list of
229     // memrefs to be deallocated and the ones to be retained. These can then be
230     // passed to the helper function and the for-loops can iterate over them.
231     // Without storing them to memrefs, we could not use for-loops but only a
232     // completely unrolled version of it, potentially leading to code-size
233     // blow-up.
234     Value toDeallocMemref = rewriter.create<memref::AllocOp>(
235         op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
236                                      rewriter.getIndexType()));
237     Value conditionMemref = rewriter.create<memref::AllocOp>(
238         op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()},
239                                      rewriter.getI1Type()));
240     Value toRetainMemref = rewriter.create<memref::AllocOp>(
241         op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
242                                      rewriter.getIndexType()));
243 
244     auto getConstValue = [&](uint64_t value) -> Value {
245       return rewriter.create<arith::ConstantOp>(op.getLoc(),
246                                                 rewriter.getIndexAttr(value));
247     };
248 
249     // Extract the base pointers of the memrefs as indices to check for aliasing
250     // at runtime.
251     for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) {
252       Value memrefAsIdx =
253           rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
254                                                                   toDealloc);
255       rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
256                                        toDeallocMemref, getConstValue(i));
257     }
258 
259     for (auto [i, cond] : llvm::enumerate(adaptor.getConditions()))
260       rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
261                                        getConstValue(i));
262 
263     for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
264       Value memrefAsIdx =
265           rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
266                                                                   toRetain);
267       rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref,
268                                        getConstValue(i));
269     }
270 
271     // Cast the allocated memrefs to dynamic shape because we want only one
272     // helper function no matter how many operands the bufferization.dealloc
273     // has.
274     Value castedDeallocMemref = rewriter.create<memref::CastOp>(
275         op->getLoc(),
276         MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
277         toDeallocMemref);
278     Value castedCondsMemref = rewriter.create<memref::CastOp>(
279         op->getLoc(),
280         MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
281         conditionMemref);
282     Value castedRetainMemref = rewriter.create<memref::CastOp>(
283         op->getLoc(),
284         MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
285         toRetainMemref);
286 
287     Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
288         op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
289                                      rewriter.getI1Type()));
290     Value retainCondsMemref = rewriter.create<memref::AllocOp>(
291         op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
292                                      rewriter.getI1Type()));
293 
294     Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>(
295         op->getLoc(),
296         MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
297         deallocCondsMemref);
298     Value castedRetainCondsMemref = rewriter.create<memref::CastOp>(
299         op->getLoc(),
300         MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
301         retainCondsMemref);
302 
303     Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
304     rewriter.create<func::CallOp>(
305         op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
306         SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
307                            castedCondsMemref, castedDeallocCondsMemref,
308                            castedRetainCondsMemref});
309 
310     for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
311       Value idxValue = getConstValue(i);
312       Value shouldDealloc = rewriter.create<memref::LoadOp>(
313           op.getLoc(), deallocCondsMemref, idxValue);
314       rewriter.create<scf::IfOp>(
315           op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
316             builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
317             builder.create<scf::YieldOp>(loc);
318           });
319     }
320 
321     SmallVector<Value> replacements;
322     for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
323       Value idxValue = getConstValue(i);
324       Value ownership = rewriter.create<memref::LoadOp>(
325           op.getLoc(), retainCondsMemref, idxValue);
326       replacements.push_back(ownership);
327     }
328 
329     // Deallocate above allocated memrefs again to avoid memory leaks.
330     // Deallocation will not be run on code after this stage.
331     rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
332     rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
333     rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref);
334     rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
335     rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);
336 
337     rewriter.replaceOp(op, replacements);
338     return success();
339   }
340 
341 public:
342   DeallocOpConversion(
343       MLIRContext *context,
344       const bufferization::DeallocHelperMap &deallocHelperFuncMap)
345       : OpConversionPattern<bufferization::DeallocOp>(context),
346         deallocHelperFuncMap(deallocHelperFuncMap) {}
347 
348   LogicalResult
349   matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
350                   ConversionPatternRewriter &rewriter) const override {
351     // Lower the trivial case.
352     if (adaptor.getMemrefs().empty()) {
353       Value falseVal = rewriter.create<arith::ConstantOp>(
354           op.getLoc(), rewriter.getBoolAttr(false));
355       rewriter.replaceOp(
356           op, SmallVector<Value>(adaptor.getRetained().size(), falseVal));
357       return success();
358     }
359 
360     if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
361       return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
362 
363     if (adaptor.getMemrefs().size() == 1)
364       return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
365 
366     Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
367     if (!deallocHelperFuncMap.contains(symtableOp))
368       return op->emitError(
369           "library function required for generic lowering, but cannot be "
370           "automatically inserted when operating on functions");
371 
372     return rewriteGeneralCase(op, adaptor, rewriter);
373   }
374 
375 private:
376   const bufferization::DeallocHelperMap &deallocHelperFuncMap;
377 };
378 } // namespace
379 
380 namespace {
381 struct LowerDeallocationsPass
382     : public bufferization::impl::LowerDeallocationsBase<
383           LowerDeallocationsPass> {
384   void runOnOperation() override {
385     if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
386       emitError(getOperation()->getLoc(),
387                 "root operation must be a builtin.module or a function");
388       signalPassFailure();
389       return;
390     }
391 
392     bufferization::DeallocHelperMap deallocHelperFuncMap;
393     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
394       OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
395 
396       // Build dealloc helper function if there are deallocs.
397       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
398         Operation *symtableOp =
399             deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
400         if (deallocOp.getMemrefs().size() > 1 &&
401             !deallocHelperFuncMap.contains(symtableOp)) {
402           SymbolTable symbolTable(symtableOp);
403           func::FuncOp helperFuncOp =
404               bufferization::buildDeallocationLibraryFunction(
405                   builder, getOperation()->getLoc(), symbolTable);
406           deallocHelperFuncMap[symtableOp] = helperFuncOp;
407         }
408       });
409     }
410 
411     RewritePatternSet patterns(&getContext());
412     bufferization::populateBufferizationDeallocLoweringPattern(
413         patterns, deallocHelperFuncMap);
414 
415     ConversionTarget target(getContext());
416     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
417                            scf::SCFDialect, func::FuncDialect>();
418     target.addIllegalOp<bufferization::DeallocOp>();
419 
420     if (failed(applyPartialConversion(getOperation(), target,
421                                       std::move(patterns))))
422       signalPassFailure();
423   }
424 };
425 } // namespace
426 
427 func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
428     OpBuilder &builder, Location loc, SymbolTable &symbolTable) {
429   Type indexMemrefType =
430       MemRefType::get({ShapedType::kDynamic}, builder.getIndexType());
431   Type boolMemrefType =
432       MemRefType::get({ShapedType::kDynamic}, builder.getI1Type());
433   SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType,
434                              boolMemrefType, boolMemrefType};
435   builder.clearInsertionPoint();
436 
437   // Generate the func operation itself.
438   auto helperFuncOp = func::FuncOp::create(
439       loc, "dealloc_helper", builder.getFunctionType(argTypes, {}));
440   helperFuncOp.setVisibility(SymbolTable::Visibility::Private);
441   symbolTable.insert(helperFuncOp);
442   auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
443   block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));
444 
445   builder.setInsertionPointToStart(&block);
446   Value toDeallocMemref = helperFuncOp.getArguments()[0];
447   Value toRetainMemref = helperFuncOp.getArguments()[1];
448   Value conditionMemref = helperFuncOp.getArguments()[2];
449   Value deallocCondsMemref = helperFuncOp.getArguments()[3];
450   Value retainCondsMemref = helperFuncOp.getArguments()[4];
451 
452   // Insert some prerequisites.
453   Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
454   Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
455   Value trueValue =
456       builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
457   Value falseValue =
458       builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
459   Value toDeallocSize = builder.create<memref::DimOp>(loc, toDeallocMemref, c0);
460   Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);
461 
462   builder.create<scf::ForOp>(
463       loc, c0, toRetainSize, c1, std::nullopt,
464       [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
465         builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i);
466         builder.create<scf::YieldOp>(loc);
467       });
468 
469   builder.create<scf::ForOp>(
470       loc, c0, toDeallocSize, c1, std::nullopt,
471       [&](OpBuilder &builder, Location loc, Value outerIter,
472           ValueRange iterArgs) {
473         Value toDealloc =
474             builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
475         Value cond =
476             builder.create<memref::LoadOp>(loc, conditionMemref, outerIter);
477 
478         // Build the first for loop that computes aliasing with retained
479         // memrefs.
480         Value noRetainAlias =
481             builder
482                 .create<scf::ForOp>(
483                     loc, c0, toRetainSize, c1, trueValue,
484                     [&](OpBuilder &builder, Location loc, Value i,
485                         ValueRange iterArgs) {
486                       Value retainValue = builder.create<memref::LoadOp>(
487                           loc, toRetainMemref, i);
488                       Value doesAlias = builder.create<arith::CmpIOp>(
489                           loc, arith::CmpIPredicate::eq, retainValue,
490                           toDealloc);
491                       builder.create<scf::IfOp>(
492                           loc, doesAlias,
493                           [&](OpBuilder &builder, Location loc) {
494                             Value retainCondValue =
495                                 builder.create<memref::LoadOp>(
496                                     loc, retainCondsMemref, i);
497                             Value aggregatedRetainCond =
498                                 builder.create<arith::OrIOp>(
499                                     loc, retainCondValue, cond);
500                             builder.create<memref::StoreOp>(
501                                 loc, aggregatedRetainCond, retainCondsMemref,
502                                 i);
503                             builder.create<scf::YieldOp>(loc);
504                           });
505                       Value doesntAlias = builder.create<arith::CmpIOp>(
506                           loc, arith::CmpIPredicate::ne, retainValue,
507                           toDealloc);
508                       Value yieldValue = builder.create<arith::AndIOp>(
509                           loc, iterArgs[0], doesntAlias);
510                       builder.create<scf::YieldOp>(loc, yieldValue);
511                     })
512                 .getResult(0);
513 
514         // Build the second for loop that adds aliasing with previously
515         // deallocated memrefs.
516         Value noAlias =
517             builder
518                 .create<scf::ForOp>(
519                     loc, c0, outerIter, c1, noRetainAlias,
520                     [&](OpBuilder &builder, Location loc, Value i,
521                         ValueRange iterArgs) {
522                       Value prevDeallocValue = builder.create<memref::LoadOp>(
523                           loc, toDeallocMemref, i);
524                       Value doesntAlias = builder.create<arith::CmpIOp>(
525                           loc, arith::CmpIPredicate::ne, prevDeallocValue,
526                           toDealloc);
527                       Value yieldValue = builder.create<arith::AndIOp>(
528                           loc, iterArgs[0], doesntAlias);
529                       builder.create<scf::YieldOp>(loc, yieldValue);
530                     })
531                 .getResult(0);
532 
533         Value shouldDealoc = builder.create<arith::AndIOp>(loc, noAlias, cond);
534         builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref,
535                                         outerIter);
536         builder.create<scf::YieldOp>(loc);
537       });
538 
539   builder.create<func::ReturnOp>(loc);
540   return helperFuncOp;
541 }
542 
543 void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
544     RewritePatternSet &patterns,
545     const bufferization::DeallocHelperMap &deallocHelperFuncMap) {
546   patterns.add<DeallocOpConversion>(patterns.getContext(),
547                                     deallocHelperFuncMap);
548 }
549 
550 std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {
551   return std::make_unique<LowerDeallocationsPass>();
552 }
553