xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp (revision 6867324eeec7c4f297c2f787d9c7b4d751a384c7)
1 //===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass drops return values from functions if they are equivalent to one of
10 // their arguments. E.g.:
11 //
12 // ```
13 // func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
14 //   return %m : memref<?xf32>
15 // }
16 // ```
17 //
18 // This function is rewritten to:
19 //
20 // ```
21 // func.func @foo(%m : memref<?xf32>) {
22 //   return
23 // }
24 // ```
25 //
26 // All call sites are updated accordingly. If a function returns a cast of a
27 // function argument, it is also considered equivalent. A cast is inserted at
28 // the call site in that case.
29 
30 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
31 
32 #include "mlir/Dialect/Func/IR/FuncOps.h"
33 #include "mlir/Dialect/MemRef/IR/MemRef.h"
34 #include "mlir/IR/Operation.h"
35 #include "mlir/Pass/Pass.h"
36 
37 namespace mlir {
38 namespace bufferization {
39 #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTS
40 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
41 } // namespace bufferization
42 } // namespace mlir
43 
44 using namespace mlir;
45 
46 /// Return the unique ReturnOp that terminates `funcOp`.
47 /// Return nullptr if there is no such unique ReturnOp.
48 static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
49   func::ReturnOp returnOp;
50   for (Block &b : funcOp.getBody()) {
51     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
52       if (returnOp)
53         return nullptr;
54       returnOp = candidateOp;
55     }
56   }
57   return returnOp;
58 }
59 
60 /// Return the func::FuncOp called by `callOp`.
61 static func::FuncOp getCalledFunction(CallOpInterface callOp) {
62   SymbolRefAttr sym =
63       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
64   if (!sym)
65     return nullptr;
66   return dyn_cast_or_null<func::FuncOp>(
67       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
68 }
69 
70 LogicalResult
71 mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
72   IRRewriter rewriter(module.getContext());
73 
74   DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
75   // Collect the mapping of functions to their call sites.
76   module.walk([&](func::CallOp callOp) {
77     if (func::FuncOp calledFunc = getCalledFunction(callOp)) {
78       callerMap[calledFunc].insert(callOp);
79     }
80   });
81 
82   for (auto funcOp : module.getOps<func::FuncOp>()) {
83     if (funcOp.isExternal())
84       continue;
85     func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
86     // TODO: Support functions with multiple blocks.
87     if (!returnOp)
88       continue;
89 
90     // Compute erased results.
91     SmallVector<Value> newReturnValues;
92     BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
93     DenseMap<int64_t, int64_t> resultToArgs;
94     for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
95       bool erased = false;
96       for (BlockArgument bbArg : funcOp.getArguments()) {
97         Value val = it.value();
98         while (auto castOp = val.getDefiningOp<memref::CastOp>())
99           val = castOp.getSource();
100 
101         if (val == bbArg) {
102           resultToArgs[it.index()] = bbArg.getArgNumber();
103           erased = true;
104           break;
105         }
106       }
107 
108       if (erased) {
109         erasedResultIndices.set(it.index());
110       } else {
111         newReturnValues.push_back(it.value());
112       }
113     }
114 
115     // Update function.
116     funcOp.eraseResults(erasedResultIndices);
117     returnOp.getOperandsMutable().assign(newReturnValues);
118 
119     // Update function calls.
120     for (func::CallOp callOp : callerMap[funcOp]) {
121       rewriter.setInsertionPoint(callOp);
122       auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp,
123                                                      callOp.getOperands());
124       SmallVector<Value> newResults;
125       int64_t nextResult = 0;
126       for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
127         if (!resultToArgs.count(i)) {
128           // This result was not erased.
129           newResults.push_back(newCallOp.getResult(nextResult++));
130           continue;
131         }
132 
133         // This result was erased.
134         Value replacement = callOp.getOperand(resultToArgs[i]);
135         Type expectedType = callOp.getResult(i).getType();
136         if (replacement.getType() != expectedType) {
137           // A cast must be inserted at the call site.
138           replacement = rewriter.create<memref::CastOp>(
139               callOp.getLoc(), expectedType, replacement);
140         }
141         newResults.push_back(replacement);
142       }
143       rewriter.replaceOp(callOp, newResults);
144     }
145   }
146 
147   return success();
148 }
149 
150 namespace {
151 struct DropEquivalentBufferResultsPass
152     : bufferization::impl::DropEquivalentBufferResultsBase<
153           DropEquivalentBufferResultsPass> {
154   void runOnOperation() override {
155     if (failed(bufferization::dropEquivalentBufferResults(getOperation())))
156       return signalPassFailure();
157   }
158 };
159 } // namespace
160 
161 std::unique_ptr<Pass>
162 mlir::bufferization::createDropEquivalentBufferResultsPass() {
163   return std::make_unique<DropEquivalentBufferResultsPass>();
164 }
165