1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 10 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/Pass/Pass.h" 15 16 namespace mlir { 17 namespace bufferization { 18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS 19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 20 } // namespace bufferization 21 } // namespace mlir 22 23 using namespace mlir; 24 25 /// Return `true` if the given MemRef type has a fully dynamic layout. 26 static bool hasFullyDynamicLayoutMap(MemRefType type) { 27 int64_t offset; 28 SmallVector<int64_t, 4> strides; 29 if (failed(getStridesAndOffset(type, strides, offset))) 30 return false; 31 if (!llvm::all_of(strides, ShapedType::isDynamic)) 32 return false; 33 if (!ShapedType::isDynamic(offset)) 34 return false; 35 return true; 36 } 37 38 /// Return `true` if the given MemRef type has a static identity layout (i.e., 39 /// no layout). 40 static bool hasStaticIdentityLayout(MemRefType type) { 41 return type.getLayout().isIdentity(); 42 } 43 44 // Updates the func op and entry block. 45 // 46 // Any args appended to the entry block are added to `appendedEntryArgs`. 47 static LogicalResult 48 updateFuncOp(func::FuncOp func, 49 SmallVectorImpl<BlockArgument> &appendedEntryArgs) { 50 auto functionType = func.getFunctionType(); 51 52 // Collect information about the results will become appended arguments. 53 SmallVector<Type, 6> erasedResultTypes; 54 BitVector erasedResultIndices(functionType.getNumResults()); 55 for (const auto &resultType : llvm::enumerate(functionType.getResults())) { 56 if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) { 57 if (!hasStaticIdentityLayout(memrefType) && 58 !hasFullyDynamicLayoutMap(memrefType)) { 59 // Only buffers with static identity layout can be allocated. These can 60 // be casted to memrefs with fully dynamic layout map. Other layout maps 61 // are not supported. 62 return func->emitError() 63 << "cannot create out param for result with unsupported layout"; 64 } 65 erasedResultIndices.set(resultType.index()); 66 erasedResultTypes.push_back(memrefType); 67 } 68 } 69 70 // Add the new arguments to the function type. 71 auto newArgTypes = llvm::to_vector<6>( 72 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); 73 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, 74 functionType.getResults()); 75 func.setType(newFunctionType); 76 77 // Transfer the result attributes to arg attributes. 78 auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); 79 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { 80 func.setArgAttrs(functionType.getNumInputs() + i, 81 func.getResultAttrs(*erasedIndicesIt)); 82 } 83 84 // Erase the results. 85 func.eraseResults(erasedResultIndices); 86 87 // Add the new arguments to the entry block if the function is not external. 88 if (func.isExternal()) 89 return success(); 90 Location loc = func.getLoc(); 91 for (Type type : erasedResultTypes) 92 appendedEntryArgs.push_back(func.front().addArgument(type, loc)); 93 94 return success(); 95 } 96 97 // Updates all ReturnOps in the scope of the given func::FuncOp by either 98 // keeping them as return values or copying the associated buffer contents into 99 // the given out-params. 100 static void updateReturnOps(func::FuncOp func, 101 ArrayRef<BlockArgument> appendedEntryArgs) { 102 func.walk([&](func::ReturnOp op) { 103 SmallVector<Value, 6> copyIntoOutParams; 104 SmallVector<Value, 6> keepAsReturnOperands; 105 for (Value operand : op.getOperands()) { 106 if (operand.getType().isa<MemRefType>()) 107 copyIntoOutParams.push_back(operand); 108 else 109 keepAsReturnOperands.push_back(operand); 110 } 111 OpBuilder builder(op); 112 for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) 113 builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t), 114 std::get<1>(t)); 115 builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); 116 op.erase(); 117 }); 118 } 119 120 // Updates all CallOps in the scope of the given ModuleOp by allocating 121 // temporary buffers for newly introduced out params. 122 static LogicalResult 123 updateCalls(ModuleOp module, 124 const bufferization::BufferResultsToOutParamsOptions &options) { 125 bool didFail = false; 126 SymbolTable symtab(module); 127 module.walk([&](func::CallOp op) { 128 auto callee = symtab.lookup<func::FuncOp>(op.getCallee()); 129 if (!callee) { 130 op.emitError() << "cannot find callee '" << op.getCallee() << "' in " 131 << "symbol table"; 132 didFail = true; 133 return; 134 } 135 if (!options.filterFn(&callee)) 136 return; 137 SmallVector<Value, 6> replaceWithNewCallResults; 138 SmallVector<Value, 6> replaceWithOutParams; 139 for (OpResult result : op.getResults()) { 140 if (result.getType().isa<MemRefType>()) 141 replaceWithOutParams.push_back(result); 142 else 143 replaceWithNewCallResults.push_back(result); 144 } 145 SmallVector<Value, 6> outParams; 146 OpBuilder builder(op); 147 for (Value memref : replaceWithOutParams) { 148 if (!memref.getType().cast<MemRefType>().hasStaticShape()) { 149 op.emitError() 150 << "cannot create out param for dynamically shaped result"; 151 didFail = true; 152 return; 153 } 154 auto memrefType = memref.getType().cast<MemRefType>(); 155 auto allocType = 156 MemRefType::get(memrefType.getShape(), memrefType.getElementType(), 157 AffineMap(), memrefType.getMemorySpace()); 158 Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType); 159 if (!hasStaticIdentityLayout(memrefType)) { 160 // Layout maps are already checked in `updateFuncOp`. 161 assert(hasFullyDynamicLayoutMap(memrefType) && 162 "layout map not supported"); 163 outParam = 164 builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam); 165 } 166 memref.replaceAllUsesWith(outParam); 167 outParams.push_back(outParam); 168 } 169 170 auto newOperands = llvm::to_vector<6>(op.getOperands()); 171 newOperands.append(outParams.begin(), outParams.end()); 172 auto newResultTypes = llvm::to_vector<6>(llvm::map_range( 173 replaceWithNewCallResults, [](Value v) { return v.getType(); })); 174 auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), 175 newResultTypes, newOperands); 176 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) 177 std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); 178 op.erase(); 179 }); 180 181 return failure(didFail); 182 } 183 184 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( 185 ModuleOp module, 186 const bufferization::BufferResultsToOutParamsOptions &options) { 187 for (auto func : module.getOps<func::FuncOp>()) { 188 if (!options.filterFn(&func)) 189 continue; 190 SmallVector<BlockArgument, 6> appendedEntryArgs; 191 if (failed(updateFuncOp(func, appendedEntryArgs))) 192 return failure(); 193 if (func.isExternal()) 194 continue; 195 updateReturnOps(func, appendedEntryArgs); 196 } 197 if (failed(updateCalls(module, options))) 198 return failure(); 199 return success(); 200 } 201 202 namespace { 203 struct BufferResultsToOutParamsPass 204 : bufferization::impl::BufferResultsToOutParamsBase< 205 BufferResultsToOutParamsPass> { 206 explicit BufferResultsToOutParamsPass( 207 const bufferization::BufferResultsToOutParamsOptions &options) 208 : options(options) {} 209 210 void runOnOperation() override { 211 if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), 212 options))) 213 return signalPassFailure(); 214 } 215 216 private: 217 bufferization::BufferResultsToOutParamsOptions options; 218 }; 219 } // namespace 220 221 std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass( 222 const bufferization::BufferResultsToOutParamsOptions &options) { 223 return std::make_unique<BufferResultsToOutParamsPass>(options); 224 } 225