1 //===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass drops return values from functions if they are equivalent to one of 10 // their arguments. E.g.: 11 // 12 // ``` 13 // func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) { 14 // return %m : memref<?xf32> 15 // } 16 // ``` 17 // 18 // This functions is rewritten to: 19 // 20 // ``` 21 // func.func @foo(%m : memref<?xf32>) { 22 // return 23 // } 24 // ``` 25 // 26 // All call sites are updated accordingly. If a function returns a cast of a 27 // function argument, it is also considered equivalent. A cast is inserted at 28 // the call site in that case. 29 30 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 31 32 #include "mlir/Dialect/Func/IR/FuncOps.h" 33 #include "mlir/Dialect/MemRef/IR/MemRef.h" 34 #include "mlir/IR/Operation.h" 35 #include "mlir/Pass/Pass.h" 36 37 namespace mlir { 38 namespace bufferization { 39 #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTS 40 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 41 } // namespace bufferization 42 } // namespace mlir 43 44 using namespace mlir; 45 46 /// Return the unique ReturnOp that terminates `funcOp`. 47 /// Return nullptr if there is no such unique ReturnOp. 48 static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { 49 func::ReturnOp returnOp; 50 for (Block &b : funcOp.getBody()) { 51 if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { 52 if (returnOp) 53 return nullptr; 54 returnOp = candidateOp; 55 } 56 } 57 return returnOp; 58 } 59 60 /// Return the func::FuncOp called by `callOp`. 61 static func::FuncOp getCalledFunction(CallOpInterface callOp) { 62 SymbolRefAttr sym = 63 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); 64 if (!sym) 65 return nullptr; 66 return dyn_cast_or_null<func::FuncOp>( 67 SymbolTable::lookupNearestSymbolFrom(callOp, sym)); 68 } 69 70 LogicalResult 71 mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { 72 IRRewriter rewriter(module.getContext()); 73 74 DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap; 75 // Collect the mapping of functions to their call sites. 76 module.walk([&](func::CallOp callOp) { 77 if (func::FuncOp calledFunc = getCalledFunction(callOp)) { 78 callerMap[calledFunc].insert(callOp); 79 } 80 }); 81 82 for (auto funcOp : module.getOps<func::FuncOp>()) { 83 if (funcOp.isExternal()) 84 continue; 85 func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 86 // TODO: Support functions with multiple blocks. 87 if (!returnOp) 88 continue; 89 90 // Compute erased results. 91 SmallVector<Value> newReturnValues; 92 BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults()); 93 DenseMap<int64_t, int64_t> resultToArgs; 94 for (const auto &it : llvm::enumerate(returnOp.getOperands())) { 95 bool erased = false; 96 for (BlockArgument bbArg : funcOp.getArguments()) { 97 Value val = it.value(); 98 while (auto castOp = val.getDefiningOp<memref::CastOp>()) 99 val = castOp.getSource(); 100 101 if (val == bbArg) { 102 resultToArgs[it.index()] = bbArg.getArgNumber(); 103 erased = true; 104 break; 105 } 106 } 107 108 if (erased) { 109 erasedResultIndices.set(it.index()); 110 } else { 111 newReturnValues.push_back(it.value()); 112 } 113 } 114 115 // Update function. 116 funcOp.eraseResults(erasedResultIndices); 117 returnOp.getOperandsMutable().assign(newReturnValues); 118 119 // Update function calls. 120 for (func::CallOp callOp : callerMap[funcOp]) { 121 rewriter.setInsertionPoint(callOp); 122 auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp, 123 callOp.getOperands()); 124 SmallVector<Value> newResults; 125 int64_t nextResult = 0; 126 for (int64_t i = 0; i < callOp.getNumResults(); ++i) { 127 if (!resultToArgs.count(i)) { 128 // This result was not erased. 129 newResults.push_back(newCallOp.getResult(nextResult++)); 130 continue; 131 } 132 133 // This result was erased. 134 Value replacement = callOp.getOperand(resultToArgs[i]); 135 Type expectedType = callOp.getResult(i).getType(); 136 if (replacement.getType() != expectedType) { 137 // A cast must be inserted at the call site. 138 replacement = rewriter.create<memref::CastOp>( 139 callOp.getLoc(), expectedType, replacement); 140 } 141 newResults.push_back(replacement); 142 } 143 rewriter.replaceOp(callOp, newResults); 144 } 145 } 146 147 return success(); 148 } 149 150 namespace { 151 struct DropEquivalentBufferResultsPass 152 : bufferization::impl::DropEquivalentBufferResultsBase< 153 DropEquivalentBufferResultsPass> { 154 void runOnOperation() override { 155 if (failed(bufferization::dropEquivalentBufferResults(getOperation()))) 156 return signalPassFailure(); 157 } 158 }; 159 } // namespace 160 161 std::unique_ptr<Pass> 162 mlir::bufferization::createDropEquivalentBufferResultsPass() { 163 return std::make_unique<DropEquivalentBufferResultsPass>(); 164 } 165