1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 18 /// Return `true` if the given MemRef type has a fully dynamic layout. 19 static bool hasFullyDynamicLayoutMap(MemRefType type) { 20 int64_t offset; 21 SmallVector<int64_t, 4> strides; 22 if (failed(getStridesAndOffset(type, strides, offset))) 23 return false; 24 if (!llvm::all_of(strides, ShapedType::isDynamicStrideOrOffset)) 25 return false; 26 if (!ShapedType::isDynamicStrideOrOffset(offset)) 27 return false; 28 return true; 29 } 30 31 /// Return `true` if the given MemRef type has a static identity layout (i.e., 32 /// no layout). 33 static bool hasStaticIdentityLayout(MemRefType type) { 34 return type.getLayout().isIdentity(); 35 } 36 37 // Updates the func op and entry block. 38 // 39 // Any args appended to the entry block are added to `appendedEntryArgs`. 40 static LogicalResult 41 updateFuncOp(func::FuncOp func, 42 SmallVectorImpl<BlockArgument> &appendedEntryArgs) { 43 auto functionType = func.getFunctionType(); 44 45 // Collect information about the results will become appended arguments. 46 SmallVector<Type, 6> erasedResultTypes; 47 BitVector erasedResultIndices(functionType.getNumResults()); 48 for (const auto &resultType : llvm::enumerate(functionType.getResults())) { 49 if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) { 50 if (!hasStaticIdentityLayout(memrefType) && 51 !hasFullyDynamicLayoutMap(memrefType)) { 52 // Only buffers with static identity layout can be allocated. These can 53 // be casted to memrefs with fully dynamic layout map. Other layout maps 54 // are not supported. 55 return func->emitError() 56 << "cannot create out param for result with unsupported layout"; 57 } 58 erasedResultIndices.set(resultType.index()); 59 erasedResultTypes.push_back(memrefType); 60 } 61 } 62 63 // Add the new arguments to the function type. 64 auto newArgTypes = llvm::to_vector<6>( 65 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); 66 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, 67 functionType.getResults()); 68 func.setType(newFunctionType); 69 70 // Transfer the result attributes to arg attributes. 71 auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); 72 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { 73 func.setArgAttrs(functionType.getNumInputs() + i, 74 func.getResultAttrs(*erasedIndicesIt)); 75 } 76 77 // Erase the results. 78 func.eraseResults(erasedResultIndices); 79 80 // Add the new arguments to the entry block if the function is not external. 81 if (func.isExternal()) 82 return success(); 83 Location loc = func.getLoc(); 84 for (Type type : erasedResultTypes) 85 appendedEntryArgs.push_back(func.front().addArgument(type, loc)); 86 87 return success(); 88 } 89 90 // Updates all ReturnOps in the scope of the given func::FuncOp by either 91 // keeping them as return values or copying the associated buffer contents into 92 // the given out-params. 93 static void updateReturnOps(func::FuncOp func, 94 ArrayRef<BlockArgument> appendedEntryArgs) { 95 func.walk([&](func::ReturnOp op) { 96 SmallVector<Value, 6> copyIntoOutParams; 97 SmallVector<Value, 6> keepAsReturnOperands; 98 for (Value operand : op.getOperands()) { 99 if (operand.getType().isa<MemRefType>()) 100 copyIntoOutParams.push_back(operand); 101 else 102 keepAsReturnOperands.push_back(operand); 103 } 104 OpBuilder builder(op); 105 for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) 106 builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t), 107 std::get<1>(t)); 108 builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); 109 op.erase(); 110 }); 111 } 112 113 // Updates all CallOps in the scope of the given ModuleOp by allocating 114 // temporary buffers for newly introduced out params. 115 static LogicalResult updateCalls(ModuleOp module) { 116 bool didFail = false; 117 module.walk([&](func::CallOp op) { 118 SmallVector<Value, 6> replaceWithNewCallResults; 119 SmallVector<Value, 6> replaceWithOutParams; 120 for (OpResult result : op.getResults()) { 121 if (result.getType().isa<MemRefType>()) 122 replaceWithOutParams.push_back(result); 123 else 124 replaceWithNewCallResults.push_back(result); 125 } 126 SmallVector<Value, 6> outParams; 127 OpBuilder builder(op); 128 for (Value memref : replaceWithOutParams) { 129 if (!memref.getType().cast<MemRefType>().hasStaticShape()) { 130 op.emitError() 131 << "cannot create out param for dynamically shaped result"; 132 didFail = true; 133 return; 134 } 135 auto memrefType = memref.getType().cast<MemRefType>(); 136 auto allocType = 137 MemRefType::get(memrefType.getShape(), memrefType.getElementType(), 138 AffineMap(), memrefType.getMemorySpaceAsInt()); 139 Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType); 140 if (!hasStaticIdentityLayout(memrefType)) { 141 // Layout maps are already checked in `updateFuncOp`. 142 assert(hasFullyDynamicLayoutMap(memrefType) && 143 "layout map not supported"); 144 outParam = 145 builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam); 146 } 147 memref.replaceAllUsesWith(outParam); 148 outParams.push_back(outParam); 149 } 150 151 auto newOperands = llvm::to_vector<6>(op.getOperands()); 152 newOperands.append(outParams.begin(), outParams.end()); 153 auto newResultTypes = llvm::to_vector<6>(llvm::map_range( 154 replaceWithNewCallResults, [](Value v) { return v.getType(); })); 155 auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), 156 newResultTypes, newOperands); 157 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) 158 std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); 159 op.erase(); 160 }); 161 162 return failure(didFail); 163 } 164 165 LogicalResult 166 mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) { 167 for (auto func : module.getOps<func::FuncOp>()) { 168 SmallVector<BlockArgument, 6> appendedEntryArgs; 169 if (failed(updateFuncOp(func, appendedEntryArgs))) 170 return failure(); 171 if (func.isExternal()) 172 continue; 173 updateReturnOps(func, appendedEntryArgs); 174 } 175 if (failed(updateCalls(module))) 176 return failure(); 177 return success(); 178 } 179 180 namespace { 181 struct BufferResultsToOutParamsPass 182 : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> { 183 void runOnOperation() override { 184 if (failed(bufferization::promoteBufferResultsToOutParams(getOperation()))) 185 return signalPassFailure(); 186 } 187 }; 188 } // namespace 189 190 std::unique_ptr<Pass> 191 mlir::bufferization::createBufferResultsToOutParamsPass() { 192 return std::make_unique<BufferResultsToOutParamsPass>(); 193 } 194