1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 18 /// Return `true` if the given MemRef type has a fully dynamic layout. 19 static bool hasFullyDynamicLayoutMap(MemRefType type) { 20 int64_t offset; 21 SmallVector<int64_t, 4> strides; 22 if (failed(getStridesAndOffset(type, strides, offset))) 23 return false; 24 if (!llvm::all_of(strides, [](int64_t stride) { 25 return ShapedType::isDynamicStrideOrOffset(stride); 26 })) 27 return false; 28 if (!ShapedType::isDynamicStrideOrOffset(offset)) 29 return false; 30 return true; 31 } 32 33 /// Return `true` if the given MemRef type has a static identity layout (i.e., 34 /// no layout). 35 static bool hasStaticIdentityLayout(MemRefType type) { 36 return type.getLayout().isIdentity(); 37 } 38 39 // Updates the func op and entry block. 40 // 41 // Any args appended to the entry block are added to `appendedEntryArgs`. 42 static LogicalResult 43 updateFuncOp(func::FuncOp func, 44 SmallVectorImpl<BlockArgument> &appendedEntryArgs) { 45 auto functionType = func.getFunctionType(); 46 47 // Collect information about the results will become appended arguments. 48 SmallVector<Type, 6> erasedResultTypes; 49 BitVector erasedResultIndices(functionType.getNumResults()); 50 for (const auto &resultType : llvm::enumerate(functionType.getResults())) { 51 if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) { 52 if (!hasStaticIdentityLayout(memrefType) && 53 !hasFullyDynamicLayoutMap(memrefType)) { 54 // Only buffers with static identity layout can be allocated. These can 55 // be casted to memrefs with fully dynamic layout map. Other layout maps 56 // are not supported. 57 return func->emitError() 58 << "cannot create out param for result with unsupported layout"; 59 } 60 erasedResultIndices.set(resultType.index()); 61 erasedResultTypes.push_back(memrefType); 62 } 63 } 64 65 // Add the new arguments to the function type. 66 auto newArgTypes = llvm::to_vector<6>( 67 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); 68 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, 69 functionType.getResults()); 70 func.setType(newFunctionType); 71 72 // Transfer the result attributes to arg attributes. 73 auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); 74 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { 75 func.setArgAttrs(functionType.getNumInputs() + i, 76 func.getResultAttrs(*erasedIndicesIt)); 77 } 78 79 // Erase the results. 80 func.eraseResults(erasedResultIndices); 81 82 // Add the new arguments to the entry block if the function is not external. 83 if (func.isExternal()) 84 return success(); 85 Location loc = func.getLoc(); 86 for (Type type : erasedResultTypes) 87 appendedEntryArgs.push_back(func.front().addArgument(type, loc)); 88 89 return success(); 90 } 91 92 // Updates all ReturnOps in the scope of the given func::FuncOp by either 93 // keeping them as return values or copying the associated buffer contents into 94 // the given out-params. 95 static void updateReturnOps(func::FuncOp func, 96 ArrayRef<BlockArgument> appendedEntryArgs) { 97 func.walk([&](func::ReturnOp op) { 98 SmallVector<Value, 6> copyIntoOutParams; 99 SmallVector<Value, 6> keepAsReturnOperands; 100 for (Value operand : op.getOperands()) { 101 if (operand.getType().isa<MemRefType>()) 102 copyIntoOutParams.push_back(operand); 103 else 104 keepAsReturnOperands.push_back(operand); 105 } 106 OpBuilder builder(op); 107 for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) 108 builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t), 109 std::get<1>(t)); 110 builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); 111 op.erase(); 112 }); 113 } 114 115 // Updates all CallOps in the scope of the given ModuleOp by allocating 116 // temporary buffers for newly introduced out params. 117 static LogicalResult updateCalls(ModuleOp module) { 118 bool didFail = false; 119 module.walk([&](func::CallOp op) { 120 SmallVector<Value, 6> replaceWithNewCallResults; 121 SmallVector<Value, 6> replaceWithOutParams; 122 for (OpResult result : op.getResults()) { 123 if (result.getType().isa<MemRefType>()) 124 replaceWithOutParams.push_back(result); 125 else 126 replaceWithNewCallResults.push_back(result); 127 } 128 SmallVector<Value, 6> outParams; 129 OpBuilder builder(op); 130 for (Value memref : replaceWithOutParams) { 131 if (!memref.getType().cast<MemRefType>().hasStaticShape()) { 132 op.emitError() 133 << "cannot create out param for dynamically shaped result"; 134 didFail = true; 135 return; 136 } 137 auto memrefType = memref.getType().cast<MemRefType>(); 138 auto allocType = 139 MemRefType::get(memrefType.getShape(), memrefType.getElementType(), 140 AffineMap(), memrefType.getMemorySpaceAsInt()); 141 Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType); 142 if (!hasStaticIdentityLayout(memrefType)) { 143 // Layout maps are already checked in `updateFuncOp`. 144 assert(hasFullyDynamicLayoutMap(memrefType) && 145 "layout map not supported"); 146 outParam = 147 builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam); 148 } 149 memref.replaceAllUsesWith(outParam); 150 outParams.push_back(outParam); 151 } 152 153 auto newOperands = llvm::to_vector<6>(op.getOperands()); 154 newOperands.append(outParams.begin(), outParams.end()); 155 auto newResultTypes = llvm::to_vector<6>(llvm::map_range( 156 replaceWithNewCallResults, [](Value v) { return v.getType(); })); 157 auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), 158 newResultTypes, newOperands); 159 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) 160 std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); 161 op.erase(); 162 }); 163 164 return failure(didFail); 165 } 166 167 LogicalResult 168 mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) { 169 for (auto func : module.getOps<func::FuncOp>()) { 170 SmallVector<BlockArgument, 6> appendedEntryArgs; 171 if (failed(updateFuncOp(func, appendedEntryArgs))) 172 return failure(); 173 if (func.isExternal()) 174 continue; 175 updateReturnOps(func, appendedEntryArgs); 176 } 177 if (failed(updateCalls(module))) 178 return failure(); 179 return success(); 180 } 181 182 namespace { 183 struct BufferResultsToOutParamsPass 184 : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> { 185 void runOnOperation() override { 186 if (failed(bufferization::promoteBufferResultsToOutParams(getOperation()))) 187 return signalPassFailure(); 188 } 189 }; 190 } // namespace 191 192 std::unique_ptr<Pass> 193 mlir::bufferization::createBufferResultsToOutParamsPass() { 194 return std::make_unique<BufferResultsToOutParamsPass>(); 195 } 196