xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- BufferDeallocationSimplification.cpp -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for optimizing `bufferization.dealloc` operations
10 // that requires more analysis than what can be supported by regular
11 // canonicalization patterns.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
17 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 namespace mlir {
24 namespace bufferization {
25 #define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION
26 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
27 } // namespace bufferization
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::bufferization;
32 
33 //===----------------------------------------------------------------------===//
34 // Helpers
35 //===----------------------------------------------------------------------===//
36 
37 /// Given a memref value, return the "base" value by skipping over all
38 /// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
39 static Value getViewBase(Value value) {
40   while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
41     value = viewLikeOp.getViewSource();
42   return value;
43 }
44 
45 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
46                                             ValueRange memrefs,
47                                             ValueRange conditions,
48                                             PatternRewriter &rewriter) {
49   if (deallocOp.getMemrefs() == memrefs &&
50       deallocOp.getConditions() == conditions)
51     return failure();
52 
53   rewriter.modifyOpInPlace(deallocOp, [&]() {
54     deallocOp.getMemrefsMutable().assign(memrefs);
55     deallocOp.getConditionsMutable().assign(conditions);
56   });
57   return success();
58 }
59 
60 /// Return "true" if the given values are guaranteed to be different (and
61 /// non-aliasing) allocations based on the fact that one value is the result
62 /// of an allocation and the other value is a block argument of a parent block.
63 /// Note: This is a best-effort analysis that will eventually be replaced by a
64 /// proper "is same allocation" analysis. This function may return "false" even
65 /// though the two values are distinct allocations.
66 static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
67   Value v1Base = getViewBase(v1);
68   Value v2Base = getViewBase(v2);
69   auto areDistinct = [](Value v1, Value v2) {
70     if (Operation *op = v1.getDefiningOp())
71       if (hasEffect<MemoryEffects::Allocate>(op, v1))
72         if (auto bbArg = dyn_cast<BlockArgument>(v2))
73           if (bbArg.getOwner()->findAncestorOpInBlock(*op))
74             return true;
75     return false;
76   };
77   return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
78 }
79 
80 /// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
81 /// often a requirement of optimization patterns that there cannot be any
82 /// aliasing memref in order to perform the desired simplification.
83 static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
84                                      ValueRange otherList, Value memref) {
85   for (auto other : otherList) {
86     if (distinctAllocAndBlockArgument(other, memref))
87       continue;
88     std::optional<bool> analysisResult =
89         analysis.isSameAllocation(other, memref);
90     if (!analysisResult.has_value() || analysisResult == true)
91       return true;
92   }
93   return false;
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // Patterns
98 //===----------------------------------------------------------------------===//
99 
100 namespace {
101 
102 /// Remove values from the `memref` operand list that are also present in the
103 /// `retained` list (or a guaranteed alias of it) because they will never
104 /// actually be deallocated. However, we also need to be certain about which
105 /// other memrefs in the `retained` list can alias, i.e., there must not by any
106 /// may-aliasing memref. This is necessary because the `dealloc` operation is
107 /// defined to return one `i1` value per memref in the `retained` list which
108 /// represents the disjunction of the condition values corresponding to all
109 /// aliasing values in the `memref` list. In particular, this means that if
110 /// there is some value R in the `retained` list which aliases with a value M in
111 /// the `memref` list (but can only be staticaly determined to may-alias) and M
112 /// is also present in the `retained` list, then it would be illegal to remove M
113 /// because the result corresponding to R would be computed incorrectly
114 /// afterwards.  Because we require an alias analysis, this pattern cannot be
115 /// applied as a regular canonicalization pattern.
116 ///
117 /// Example:
118 /// ```mlir
119 /// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0)
120 ///                     retain (%m0, %r0, %r1 : ...)
121 /// ```
122 /// is canonicalized to
123 /// ```mlir
124 /// // bufferization.dealloc without memrefs and conditions returns %false for
125 /// // every retained value
126 /// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...)
127 /// %1 = arith.ori %0#0, %cond0 : i1
128 /// // replace %0#0 with %1
129 /// ```
130 /// given that `%r0` and `%r1` may not alias with `%m0`.
131 struct RemoveDeallocMemrefsContainedInRetained
132     : public OpRewritePattern<DeallocOp> {
133   RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
134                                           BufferOriginAnalysis &analysis)
135       : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
136 
137   /// The passed 'memref' must not have a may-alias relation to any retained
138   /// memref, and at least one must-alias relation. If there is no must-aliasing
139   /// memref in the retain list, we cannot simply remove the memref as there
140   /// could be situations in which it actually has to be deallocated. If it's
141   /// no-alias, then just proceed, if it's must-alias we need to update the
142   /// updated condition returned by the dealloc operation for that alias.
143   LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,
144                                 PatternRewriter &rewriter) const {
145     rewriter.setInsertionPointAfter(deallocOp);
146 
147     // Check that there is no may-aliasing memref and that at least one memref
148     // in the retain list aliases (because otherwise it might have to be
149     // deallocated in some situations and can thus not be dropped).
150     bool atLeastOneMustAlias = false;
151     for (Value retained : deallocOp.getRetained()) {
152       std::optional<bool> analysisResult =
153           analysis.isSameAllocation(retained, memref);
154       if (!analysisResult.has_value())
155         return failure();
156       if (analysisResult == true)
157         atLeastOneMustAlias = true;
158     }
159     if (!atLeastOneMustAlias)
160       return failure();
161 
162     // Insert arith.ori operations to update the corresponding dealloc result
163     // values to incorporate the condition of the must-aliasing memref such that
164     // we can remove that operand later on.
165     for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
166       Value updatedCondition = deallocOp.getUpdatedConditions()[i];
167       std::optional<bool> analysisResult =
168           analysis.isSameAllocation(retained, memref);
169       if (analysisResult == true) {
170         auto disjunction = rewriter.create<arith::OrIOp>(
171             deallocOp.getLoc(), updatedCondition, cond);
172         rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
173                                       disjunction);
174       }
175     }
176 
177     return success();
178   }
179 
180   LogicalResult matchAndRewrite(DeallocOp deallocOp,
181                                 PatternRewriter &rewriter) const override {
182     // There must not be any duplicates in the retain list anymore because we
183     // would miss updating one of the result values otherwise.
184     DenseSet<Value> retained(deallocOp.getRetained().begin(),
185                              deallocOp.getRetained().end());
186     if (retained.size() != deallocOp.getRetained().size())
187       return failure();
188 
189     SmallVector<Value> newMemrefs, newConditions;
190     for (auto [memref, cond] :
191          llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
192 
193       if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
194         continue;
195 
196       if (auto extractOp =
197               memref.getDefiningOp<memref::ExtractStridedMetadataOp>())
198         if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
199                                       rewriter)))
200           continue;
201 
202       newMemrefs.push_back(memref);
203       newConditions.push_back(cond);
204     }
205 
206     // Return failure if we don't change anything such that we don't run into an
207     // infinite loop of pattern applications.
208     return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
209                                   rewriter);
210   }
211 
212 private:
213   BufferOriginAnalysis &analysis;
214 };
215 
216 /// Remove memrefs from the `retained` list which are guaranteed to not alias
217 /// any memref in the `memrefs` list. The corresponding result value can be
218 /// replaced with `false` in that case according to the operation description.
219 ///
220 /// Example:
221 /// ```mlir
222 /// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond)
223 ///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
224 /// return %0#0, %0#1
225 /// ```
226 /// can be canonicalized to the following given that `%r0` and `%r1` do not
227 /// alias `%m`:
228 /// ```mlir
229 /// bufferization.dealloc (%m : memref<2xi32>) if (%cond)
230 /// return %false, %false
231 /// ```
232 struct RemoveRetainedMemrefsGuaranteedToNotAlias
233     : public OpRewritePattern<DeallocOp> {
234   RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
235                                             BufferOriginAnalysis &analysis)
236       : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
237 
238   LogicalResult matchAndRewrite(DeallocOp deallocOp,
239                                 PatternRewriter &rewriter) const override {
240     SmallVector<Value> newRetainedMemrefs, replacements;
241 
242     for (auto retainedMemref : deallocOp.getRetained()) {
243       if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
244                                    retainedMemref)) {
245         newRetainedMemrefs.push_back(retainedMemref);
246         replacements.push_back({});
247         continue;
248       }
249 
250       replacements.push_back(rewriter.create<arith::ConstantOp>(
251           deallocOp.getLoc(), rewriter.getBoolAttr(false)));
252     }
253 
254     if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
255       return failure();
256 
257     auto newDeallocOp = rewriter.create<DeallocOp>(
258         deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(),
259         newRetainedMemrefs);
260     int i = 0;
261     for (auto &repl : replacements) {
262       if (!repl)
263         repl = newDeallocOp.getUpdatedConditions()[i++];
264     }
265 
266     rewriter.replaceOp(deallocOp, replacements);
267     return success();
268   }
269 
270 private:
271   BufferOriginAnalysis &analysis;
272 };
273 
274 /// Split off memrefs to separate dealloc operations to reduce the number of
275 /// runtime checks required and enable further canonicalization of the new and
276 /// simpler dealloc operations. A memref can be split off if it is guaranteed to
277 /// not alias with any other memref in the `memref` operand list.  The results
278 /// of the old and the new dealloc operation have to be combined by computing
279 /// the element-wise disjunction of them.
280 ///
281 /// Example:
282 /// ```mlir
283 /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
284 ///                           if (%cond0, %cond1)
285 ///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
286 /// return %0#0, %0#1
287 /// ```
288 /// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
289 /// canonicalized to the following, thus reducing the number of runtime alias
290 /// checks by 1 and potentially enabling further canonicalization of the new
291 /// split-up dealloc operations.
292 /// ```mlir
293 /// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
294 ///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
295 /// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
296 ///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
297 /// %2 = arith.ori %0#0, %1#0
298 /// %3 = arith.ori %0#1, %1#1
299 /// return %2, %3
300 /// ```
301 struct SplitDeallocWhenNotAliasingAnyOther
302     : public OpRewritePattern<DeallocOp> {
303   SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
304                                       BufferOriginAnalysis &analysis)
305       : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
306 
307   LogicalResult matchAndRewrite(DeallocOp deallocOp,
308                                 PatternRewriter &rewriter) const override {
309     Location loc = deallocOp.getLoc();
310     if (deallocOp.getMemrefs().size() <= 1)
311       return failure();
312 
313     SmallVector<Value> remainingMemrefs, remainingConditions;
314     SmallVector<SmallVector<Value>> updatedConditions;
315     for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {
316       Value memref = deallocOp.getMemrefs()[i];
317       Value cond = deallocOp.getConditions()[i];
318       SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
319       otherMemrefs.erase(otherMemrefs.begin() + i);
320       // Check if `memref` can split off into a separate bufferization.dealloc.
321       if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
322         // `memref` alias with other memrefs, do not split off.
323         remainingMemrefs.push_back(memref);
324         remainingConditions.push_back(cond);
325         continue;
326       }
327 
328       // Create new bufferization.dealloc op for `memref`.
329       auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond,
330                                                      deallocOp.getRetained());
331       updatedConditions.push_back(
332           llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
333     }
334 
335     // Fail if no memref was split off.
336     if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
337       return failure();
338 
339     // Create bufferization.dealloc op for all remaining memrefs.
340     auto newDeallocOp = rewriter.create<DeallocOp>(
341         loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());
342 
343     // Bit-or all conditions.
344     SmallVector<Value> replacements =
345         llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
346     for (auto additionalConditions : updatedConditions) {
347       assert(replacements.size() == additionalConditions.size() &&
348              "expected same number of updated conditions");
349       for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
350         replacements[i] = rewriter.create<arith::OrIOp>(
351             loc, replacements[i], additionalConditions[i]);
352       }
353     }
354     rewriter.replaceOp(deallocOp, replacements);
355     return success();
356   }
357 
358 private:
359   BufferOriginAnalysis &analysis;
360 };
361 
362 /// Check for every retained memref if a must-aliasing memref exists in the
363 /// 'memref' operand list with constant 'true' condition. If so, we can replace
364 /// the operation result corresponding to that retained memref with 'true'. If
365 /// this condition holds for all retained memrefs we can also remove the
366 /// aliasing memrefs and their conditions since they will never be deallocated
367 /// due to the must-alias and we don't need them to compute the result value
368 /// anymore since it got replaced with 'true'.
369 ///
370 /// Example:
371 /// ```mlir
372 /// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...)
373 ///                           if (%true, %true, %true)
374 ///                       retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
375 /// ```
376 /// becomes
377 /// ```mlir
378 /// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true)
379 ///                       retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
380 /// // replace %0#0 with %true
381 /// // replace %0#1 with %true
382 /// ```
383 /// Note that the dealloc operation will still have the result values, but they
384 /// don't have uses anymore.
385 struct RetainedMemrefAliasingAlwaysDeallocatedMemref
386     : public OpRewritePattern<DeallocOp> {
387   RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
388                                                 BufferOriginAnalysis &analysis)
389       : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
390 
391   LogicalResult matchAndRewrite(DeallocOp deallocOp,
392                                 PatternRewriter &rewriter) const override {
393     BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
394     SmallVector<Value> newMemrefs, newConditions;
395     for (auto [memref, cond] :
396          llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
397       bool canDropMemref = false;
398       for (auto [i, retained, res] : llvm::enumerate(
399                deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
400         if (!matchPattern(cond, m_One()))
401           continue;
402 
403         std::optional<bool> analysisResult =
404             analysis.isSameAllocation(retained, memref);
405         if (analysisResult == true) {
406           rewriter.replaceAllUsesWith(res, cond);
407           aliasesWithConstTrueMemref[i] = true;
408           canDropMemref = true;
409           continue;
410         }
411 
412         // TODO: once our alias analysis is powerful enough we can remove the
413         // rest of this loop body
414         auto extractOp =
415             memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
416         if (!extractOp)
417           continue;
418 
419         std::optional<bool> extractAnalysisResult =
420             analysis.isSameAllocation(retained, extractOp.getOperand());
421         if (extractAnalysisResult == true) {
422           rewriter.replaceAllUsesWith(res, cond);
423           aliasesWithConstTrueMemref[i] = true;
424           canDropMemref = true;
425         }
426       }
427 
428       if (!canDropMemref) {
429         newMemrefs.push_back(memref);
430         newConditions.push_back(cond);
431       }
432     }
433     if (!aliasesWithConstTrueMemref.all())
434       return failure();
435 
436     return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
437                                   rewriter);
438   }
439 
440 private:
441   BufferOriginAnalysis &analysis;
442 };
443 
444 } // namespace
445 
446 //===----------------------------------------------------------------------===//
447 // BufferDeallocationSimplificationPass
448 //===----------------------------------------------------------------------===//
449 
450 namespace {
451 
452 /// The actual buffer deallocation pass that inserts and moves dealloc nodes
453 /// into the right positions. Furthermore, it inserts additional clones if
454 /// necessary. It uses the algorithm described at the top of the file.
455 struct BufferDeallocationSimplificationPass
456     : public bufferization::impl::BufferDeallocationSimplificationBase<
457           BufferDeallocationSimplificationPass> {
458   void runOnOperation() override {
459     BufferOriginAnalysis analysis(getOperation());
460     RewritePatternSet patterns(&getContext());
461     patterns.add<RemoveDeallocMemrefsContainedInRetained,
462                  RemoveRetainedMemrefsGuaranteedToNotAlias,
463                  SplitDeallocWhenNotAliasingAnyOther,
464                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
465                                                                 analysis);
466     // We don't want that the block structure changes invalidating the
467     // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
468     // region simplification
469     GreedyRewriteConfig config;
470     config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
471     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
472 
473     if (failed(
474             applyPatternsGreedily(getOperation(), std::move(patterns), config)))
475       signalPassFailure();
476   }
477 };
478 
479 } // namespace
480 
481 std::unique_ptr<Pass>
482 mlir::bufferization::createBufferDeallocationSimplificationPass() {
483   return std::make_unique<BufferDeallocationSimplificationPass>();
484 }
485