1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" 10 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 11 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/Pass/Pass.h" 16 17 namespace mlir { 18 namespace bufferization { 19 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS 20 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 21 } // namespace bufferization 22 } // namespace mlir 23 24 using namespace mlir; 25 using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn; 26 using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; 27 28 /// Return `true` if the given MemRef type has a fully dynamic layout. 29 static bool hasFullyDynamicLayoutMap(MemRefType type) { 30 int64_t offset; 31 SmallVector<int64_t, 4> strides; 32 if (failed(type.getStridesAndOffset(strides, offset))) 33 return false; 34 if (!llvm::all_of(strides, ShapedType::isDynamic)) 35 return false; 36 if (!ShapedType::isDynamic(offset)) 37 return false; 38 return true; 39 } 40 41 /// Return `true` if the given MemRef type has a static identity layout (i.e., 42 /// no layout). 43 static bool hasStaticIdentityLayout(MemRefType type) { 44 return type.getLayout().isIdentity(); 45 } 46 47 // Updates the func op and entry block. 48 // 49 // Any args appended to the entry block are added to `appendedEntryArgs`. 50 // If `addResultAttribute` is true, adds the unit attribute `bufferize.result` 51 // to each newly created function argument. 52 static LogicalResult 53 updateFuncOp(func::FuncOp func, 54 SmallVectorImpl<BlockArgument> &appendedEntryArgs, 55 bool addResultAttribute) { 56 auto functionType = func.getFunctionType(); 57 58 // Collect information about the results will become appended arguments. 59 SmallVector<Type, 6> erasedResultTypes; 60 BitVector erasedResultIndices(functionType.getNumResults()); 61 for (const auto &resultType : llvm::enumerate(functionType.getResults())) { 62 if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) { 63 if (!hasStaticIdentityLayout(memrefType) && 64 !hasFullyDynamicLayoutMap(memrefType)) { 65 // Only buffers with static identity layout can be allocated. These can 66 // be casted to memrefs with fully dynamic layout map. Other layout maps 67 // are not supported. 68 return func->emitError() 69 << "cannot create out param for result with unsupported layout"; 70 } 71 erasedResultIndices.set(resultType.index()); 72 erasedResultTypes.push_back(memrefType); 73 } 74 } 75 76 // Add the new arguments to the function type. 77 auto newArgTypes = llvm::to_vector<6>( 78 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); 79 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, 80 functionType.getResults()); 81 func.setType(newFunctionType); 82 83 // Transfer the result attributes to arg attributes. 84 auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); 85 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { 86 func.setArgAttrs(functionType.getNumInputs() + i, 87 func.getResultAttrs(*erasedIndicesIt)); 88 if (addResultAttribute) 89 func.setArgAttr(functionType.getNumInputs() + i, 90 StringAttr::get(func.getContext(), "bufferize.result"), 91 UnitAttr::get(func.getContext())); 92 } 93 94 // Erase the results. 95 func.eraseResults(erasedResultIndices); 96 97 // Add the new arguments to the entry block if the function is not external. 98 if (func.isExternal()) 99 return success(); 100 Location loc = func.getLoc(); 101 for (Type type : erasedResultTypes) 102 appendedEntryArgs.push_back(func.front().addArgument(type, loc)); 103 104 return success(); 105 } 106 107 // Updates all ReturnOps in the scope of the given func::FuncOp by either 108 // keeping them as return values or copying the associated buffer contents into 109 // the given out-params. 110 static LogicalResult 111 updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs, 112 const bufferization::BufferResultsToOutParamsOpts &options) { 113 auto res = func.walk([&](func::ReturnOp op) { 114 SmallVector<Value, 6> copyIntoOutParams; 115 SmallVector<Value, 6> keepAsReturnOperands; 116 for (Value operand : op.getOperands()) { 117 if (isa<MemRefType>(operand.getType())) 118 copyIntoOutParams.push_back(operand); 119 else 120 keepAsReturnOperands.push_back(operand); 121 } 122 OpBuilder builder(op); 123 for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { 124 if (options.hoistStaticAllocs && 125 isa_and_nonnull<bufferization::AllocationOpInterface>( 126 orig.getDefiningOp()) && 127 mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) { 128 orig.replaceAllUsesWith(arg); 129 orig.getDefiningOp()->erase(); 130 } else { 131 if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg))) 132 return WalkResult::interrupt(); 133 } 134 } 135 builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); 136 op.erase(); 137 return WalkResult::advance(); 138 }); 139 return failure(res.wasInterrupted()); 140 } 141 142 // Updates all CallOps in the scope of the given ModuleOp by allocating 143 // temporary buffers for newly introduced out params. 144 static LogicalResult 145 updateCalls(ModuleOp module, 146 const bufferization::BufferResultsToOutParamsOpts &options) { 147 bool didFail = false; 148 SymbolTable symtab(module); 149 module.walk([&](func::CallOp op) { 150 auto callee = symtab.lookup<func::FuncOp>(op.getCallee()); 151 if (!callee) { 152 op.emitError() << "cannot find callee '" << op.getCallee() << "' in " 153 << "symbol table"; 154 didFail = true; 155 return; 156 } 157 if (!options.filterFn(&callee)) 158 return; 159 SmallVector<Value, 6> replaceWithNewCallResults; 160 SmallVector<Value, 6> replaceWithOutParams; 161 for (OpResult result : op.getResults()) { 162 if (isa<MemRefType>(result.getType())) 163 replaceWithOutParams.push_back(result); 164 else 165 replaceWithNewCallResults.push_back(result); 166 } 167 SmallVector<Value, 6> outParams; 168 OpBuilder builder(op); 169 for (Value memref : replaceWithOutParams) { 170 if (!cast<MemRefType>(memref.getType()).hasStaticShape()) { 171 op.emitError() 172 << "cannot create out param for dynamically shaped result"; 173 didFail = true; 174 return; 175 } 176 auto memrefType = cast<MemRefType>(memref.getType()); 177 auto allocType = 178 MemRefType::get(memrefType.getShape(), memrefType.getElementType(), 179 AffineMap(), memrefType.getMemorySpace()); 180 auto maybeOutParam = 181 options.allocationFn(builder, op.getLoc(), allocType); 182 if (failed(maybeOutParam)) { 183 op.emitError() << "failed to create allocation op"; 184 didFail = true; 185 return; 186 } 187 Value outParam = maybeOutParam.value(); 188 if (!hasStaticIdentityLayout(memrefType)) { 189 // Layout maps are already checked in `updateFuncOp`. 190 assert(hasFullyDynamicLayoutMap(memrefType) && 191 "layout map not supported"); 192 outParam = 193 builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam); 194 } 195 memref.replaceAllUsesWith(outParam); 196 outParams.push_back(outParam); 197 } 198 199 auto newOperands = llvm::to_vector<6>(op.getOperands()); 200 newOperands.append(outParams.begin(), outParams.end()); 201 auto newResultTypes = llvm::to_vector<6>(llvm::map_range( 202 replaceWithNewCallResults, [](Value v) { return v.getType(); })); 203 auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), 204 newResultTypes, newOperands); 205 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) 206 std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); 207 op.erase(); 208 }); 209 210 return failure(didFail); 211 } 212 213 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( 214 ModuleOp module, 215 const bufferization::BufferResultsToOutParamsOpts &options) { 216 for (auto func : module.getOps<func::FuncOp>()) { 217 if (!options.filterFn(&func)) 218 continue; 219 SmallVector<BlockArgument, 6> appendedEntryArgs; 220 if (failed( 221 updateFuncOp(func, appendedEntryArgs, options.addResultAttribute))) 222 return failure(); 223 if (func.isExternal()) 224 continue; 225 if (failed(updateReturnOps(func, appendedEntryArgs, options))) { 226 return failure(); 227 } 228 } 229 if (failed(updateCalls(module, options))) 230 return failure(); 231 return success(); 232 } 233 234 namespace { 235 struct BufferResultsToOutParamsPass 236 : bufferization::impl::BufferResultsToOutParamsBase< 237 BufferResultsToOutParamsPass> { 238 explicit BufferResultsToOutParamsPass( 239 const bufferization::BufferResultsToOutParamsOpts &options) 240 : options(options) {} 241 242 void runOnOperation() override { 243 // Convert from pass options in tablegen to BufferResultsToOutParamsOpts. 244 if (addResultAttribute) 245 options.addResultAttribute = true; 246 if (hoistStaticAllocs) 247 options.hoistStaticAllocs = true; 248 249 if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), 250 options))) 251 return signalPassFailure(); 252 } 253 254 private: 255 bufferization::BufferResultsToOutParamsOpts options; 256 }; 257 } // namespace 258 259 std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass( 260 const bufferization::BufferResultsToOutParamsOpts &options) { 261 return std::make_unique<BufferResultsToOutParamsPass>(options); 262 } 263