10e9a4a3bSRiver Riddle //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// 20e9a4a3bSRiver Riddle // 30e9a4a3bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 40e9a4a3bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 50e9a4a3bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60e9a4a3bSRiver Riddle // 70e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===// 80e9a4a3bSRiver Riddle 98906b7beSsrcarroll #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" 100e9a4a3bSRiver Riddle #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 1167d0d7acSMichele Scuttari 1223aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 130e9a4a3bSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 140e9a4a3bSRiver Riddle #include "mlir/IR/Operation.h" 150e9a4a3bSRiver Riddle #include "mlir/Pass/Pass.h" 160e9a4a3bSRiver Riddle 1767d0d7acSMichele Scuttari namespace mlir { 1867d0d7acSMichele Scuttari namespace bufferization { 1967d0d7acSMichele Scuttari #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS 2067d0d7acSMichele Scuttari #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 2167d0d7acSMichele Scuttari } // namespace bufferization 2267d0d7acSMichele Scuttari } // namespace mlir 2367d0d7acSMichele Scuttari 240e9a4a3bSRiver Riddle using namespace mlir; 258906b7beSsrcarroll using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn; 26e6048b72SMatthias Gehre using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; 270e9a4a3bSRiver Riddle 28598c5ddbSMatthias Springer /// Return `true` if the given MemRef type has a fully dynamic layout. 29598c5ddbSMatthias Springer static bool hasFullyDynamicLayoutMap(MemRefType type) { 30598c5ddbSMatthias Springer int64_t offset; 31598c5ddbSMatthias Springer SmallVector<int64_t, 4> strides; 32*6aaa8f25SMatthias Springer if (failed(type.getStridesAndOffset(strides, offset))) 33598c5ddbSMatthias Springer return false; 34399638f9SAliia Khasanova if (!llvm::all_of(strides, ShapedType::isDynamic)) 35598c5ddbSMatthias Springer return false; 36399638f9SAliia Khasanova if (!ShapedType::isDynamic(offset)) 37598c5ddbSMatthias Springer return false; 38598c5ddbSMatthias Springer return true; 39598c5ddbSMatthias Springer } 40598c5ddbSMatthias Springer 41598c5ddbSMatthias Springer /// Return `true` if the given MemRef type has a static identity layout (i.e., 42598c5ddbSMatthias Springer /// no layout). 43598c5ddbSMatthias Springer static bool hasStaticIdentityLayout(MemRefType type) { 44598c5ddbSMatthias Springer return type.getLayout().isIdentity(); 45598c5ddbSMatthias Springer } 46598c5ddbSMatthias Springer 470e9a4a3bSRiver Riddle // Updates the func op and entry block. 480e9a4a3bSRiver Riddle // 490e9a4a3bSRiver Riddle // Any args appended to the entry block are added to `appendedEntryArgs`. 50e6048b72SMatthias Gehre // If `addResultAttribute` is true, adds the unit attribute `bufferize.result` 51e6048b72SMatthias Gehre // to each newly created function argument. 52598c5ddbSMatthias Springer static LogicalResult 53598c5ddbSMatthias Springer updateFuncOp(func::FuncOp func, 54e6048b72SMatthias Gehre SmallVectorImpl<BlockArgument> &appendedEntryArgs, 55e6048b72SMatthias Gehre bool addResultAttribute) { 564a3460a7SRiver Riddle auto functionType = func.getFunctionType(); 570e9a4a3bSRiver Riddle 580e9a4a3bSRiver Riddle // Collect information about the results will become appended arguments. 590e9a4a3bSRiver Riddle SmallVector<Type, 6> erasedResultTypes; 60d10d49dcSRiver Riddle BitVector erasedResultIndices(functionType.getNumResults()); 610e9a4a3bSRiver Riddle for (const auto &resultType : llvm::enumerate(functionType.getResults())) { 625550c821STres Popp if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) { 63598c5ddbSMatthias Springer if (!hasStaticIdentityLayout(memrefType) && 64598c5ddbSMatthias Springer !hasFullyDynamicLayoutMap(memrefType)) { 65598c5ddbSMatthias Springer // Only buffers with static identity layout can be allocated. These can 66598c5ddbSMatthias Springer // be casted to memrefs with fully dynamic layout map. Other layout maps 67598c5ddbSMatthias Springer // are not supported. 68598c5ddbSMatthias Springer return func->emitError() 69598c5ddbSMatthias Springer << "cannot create out param for result with unsupported layout"; 70598c5ddbSMatthias Springer } 71e3cd80eaSRiver Riddle erasedResultIndices.set(resultType.index()); 72598c5ddbSMatthias Springer erasedResultTypes.push_back(memrefType); 730e9a4a3bSRiver Riddle } 740e9a4a3bSRiver Riddle } 750e9a4a3bSRiver Riddle 760e9a4a3bSRiver Riddle // Add the new arguments to the function type. 770e9a4a3bSRiver Riddle auto newArgTypes = llvm::to_vector<6>( 780e9a4a3bSRiver Riddle llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); 790e9a4a3bSRiver Riddle auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, 800e9a4a3bSRiver Riddle functionType.getResults()); 810e9a4a3bSRiver Riddle func.setType(newFunctionType); 820e9a4a3bSRiver Riddle 830e9a4a3bSRiver Riddle // Transfer the result attributes to arg attributes. 84e3cd80eaSRiver Riddle auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); 85e3cd80eaSRiver Riddle for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { 860e9a4a3bSRiver Riddle func.setArgAttrs(functionType.getNumInputs() + i, 87e3cd80eaSRiver Riddle func.getResultAttrs(*erasedIndicesIt)); 88e6048b72SMatthias Gehre if (addResultAttribute) 89e6048b72SMatthias Gehre func.setArgAttr(functionType.getNumInputs() + i, 90e6048b72SMatthias Gehre StringAttr::get(func.getContext(), "bufferize.result"), 91e6048b72SMatthias Gehre UnitAttr::get(func.getContext())); 92e3cd80eaSRiver Riddle } 930e9a4a3bSRiver Riddle 940e9a4a3bSRiver Riddle // Erase the results. 950e9a4a3bSRiver Riddle func.eraseResults(erasedResultIndices); 960e9a4a3bSRiver Riddle 970e9a4a3bSRiver Riddle // Add the new arguments to the entry block if the function is not external. 980e9a4a3bSRiver Riddle if (func.isExternal()) 99598c5ddbSMatthias Springer return success(); 1000e9a4a3bSRiver Riddle Location loc = func.getLoc(); 1010e9a4a3bSRiver Riddle for (Type type : erasedResultTypes) 1020e9a4a3bSRiver Riddle appendedEntryArgs.push_back(func.front().addArgument(type, loc)); 103598c5ddbSMatthias Springer 104598c5ddbSMatthias Springer return success(); 1050e9a4a3bSRiver Riddle } 1060e9a4a3bSRiver Riddle 10758ceae95SRiver Riddle // Updates all ReturnOps in the scope of the given func::FuncOp by either 10858ceae95SRiver Riddle // keeping them as return values or copying the associated buffer contents into 10958ceae95SRiver Riddle // the given out-params. 1108906b7beSsrcarroll static LogicalResult 1118906b7beSsrcarroll updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs, 1128906b7beSsrcarroll const bufferization::BufferResultsToOutParamsOpts &options) { 113afac64ceSMatthias Gehre auto res = func.walk([&](func::ReturnOp op) { 1140e9a4a3bSRiver Riddle SmallVector<Value, 6> copyIntoOutParams; 1150e9a4a3bSRiver Riddle SmallVector<Value, 6> keepAsReturnOperands; 1160e9a4a3bSRiver Riddle for (Value operand : op.getOperands()) { 1175550c821STres Popp if (isa<MemRefType>(operand.getType())) 1180e9a4a3bSRiver Riddle copyIntoOutParams.push_back(operand); 1190e9a4a3bSRiver Riddle else 1200e9a4a3bSRiver Riddle keepAsReturnOperands.push_back(operand); 1210e9a4a3bSRiver Riddle } 1220e9a4a3bSRiver Riddle OpBuilder builder(op); 1230af448b7SMenooker for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { 1248906b7beSsrcarroll if (options.hoistStaticAllocs && 1258906b7beSsrcarroll isa_and_nonnull<bufferization::AllocationOpInterface>( 1268906b7beSsrcarroll orig.getDefiningOp()) && 1271c8c2fddSJie Fu mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) { 1280af448b7SMenooker orig.replaceAllUsesWith(arg); 1290af448b7SMenooker orig.getDefiningOp()->erase(); 1300af448b7SMenooker } else { 1318906b7beSsrcarroll if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg))) 132afac64ceSMatthias Gehre return WalkResult::interrupt(); 133afac64ceSMatthias Gehre } 1340af448b7SMenooker } 13523aa5a74SRiver Riddle builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); 1360e9a4a3bSRiver Riddle op.erase(); 137afac64ceSMatthias Gehre return WalkResult::advance(); 1380e9a4a3bSRiver Riddle }); 139afac64ceSMatthias Gehre return failure(res.wasInterrupted()); 1400e9a4a3bSRiver Riddle } 1410e9a4a3bSRiver Riddle 1420e9a4a3bSRiver Riddle // Updates all CallOps in the scope of the given ModuleOp by allocating 1430e9a4a3bSRiver Riddle // temporary buffers for newly introduced out params. 1447e133eb4SEmilio Cota static LogicalResult 1457e133eb4SEmilio Cota updateCalls(ModuleOp module, 146e6048b72SMatthias Gehre const bufferization::BufferResultsToOutParamsOpts &options) { 1470e9a4a3bSRiver Riddle bool didFail = false; 1487e133eb4SEmilio Cota SymbolTable symtab(module); 14923aa5a74SRiver Riddle module.walk([&](func::CallOp op) { 1507e133eb4SEmilio Cota auto callee = symtab.lookup<func::FuncOp>(op.getCallee()); 1517e133eb4SEmilio Cota if (!callee) { 1527e133eb4SEmilio Cota op.emitError() << "cannot find callee '" << op.getCallee() << "' in " 1537e133eb4SEmilio Cota << "symbol table"; 1547e133eb4SEmilio Cota didFail = true; 1557e133eb4SEmilio Cota return; 1567e133eb4SEmilio Cota } 1577e133eb4SEmilio Cota if (!options.filterFn(&callee)) 1587e133eb4SEmilio Cota return; 1590e9a4a3bSRiver Riddle SmallVector<Value, 6> replaceWithNewCallResults; 1600e9a4a3bSRiver Riddle SmallVector<Value, 6> replaceWithOutParams; 1610e9a4a3bSRiver Riddle for (OpResult result : op.getResults()) { 1625550c821STres Popp if (isa<MemRefType>(result.getType())) 1630e9a4a3bSRiver Riddle replaceWithOutParams.push_back(result); 1640e9a4a3bSRiver Riddle else 1650e9a4a3bSRiver Riddle replaceWithNewCallResults.push_back(result); 1660e9a4a3bSRiver Riddle } 1670e9a4a3bSRiver Riddle SmallVector<Value, 6> outParams; 1680e9a4a3bSRiver Riddle OpBuilder builder(op); 1690e9a4a3bSRiver Riddle for (Value memref : replaceWithOutParams) { 1705550c821STres Popp if (!cast<MemRefType>(memref.getType()).hasStaticShape()) { 1710e9a4a3bSRiver Riddle op.emitError() 1720e9a4a3bSRiver Riddle << "cannot create out param for dynamically shaped result"; 1730e9a4a3bSRiver Riddle didFail = true; 1740e9a4a3bSRiver Riddle return; 1750e9a4a3bSRiver Riddle } 1765550c821STres Popp auto memrefType = cast<MemRefType>(memref.getType()); 177598c5ddbSMatthias Springer auto allocType = 178598c5ddbSMatthias Springer MemRefType::get(memrefType.getShape(), memrefType.getElementType(), 1795d04f0c9SMatthias Springer AffineMap(), memrefType.getMemorySpace()); 1808906b7beSsrcarroll auto maybeOutParam = 1818906b7beSsrcarroll options.allocationFn(builder, op.getLoc(), allocType); 1828906b7beSsrcarroll if (failed(maybeOutParam)) { 1838906b7beSsrcarroll op.emitError() << "failed to create allocation op"; 1848906b7beSsrcarroll didFail = true; 1858906b7beSsrcarroll return; 1868906b7beSsrcarroll } 1878906b7beSsrcarroll Value outParam = maybeOutParam.value(); 188598c5ddbSMatthias Springer if (!hasStaticIdentityLayout(memrefType)) { 189598c5ddbSMatthias Springer // Layout maps are already checked in `updateFuncOp`. 190598c5ddbSMatthias Springer assert(hasFullyDynamicLayoutMap(memrefType) && 191598c5ddbSMatthias Springer "layout map not supported"); 192598c5ddbSMatthias Springer outParam = 193598c5ddbSMatthias Springer builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam); 194598c5ddbSMatthias Springer } 1950e9a4a3bSRiver Riddle memref.replaceAllUsesWith(outParam); 1960e9a4a3bSRiver Riddle outParams.push_back(outParam); 1970e9a4a3bSRiver Riddle } 1980e9a4a3bSRiver Riddle 1990e9a4a3bSRiver Riddle auto newOperands = llvm::to_vector<6>(op.getOperands()); 2000e9a4a3bSRiver Riddle newOperands.append(outParams.begin(), outParams.end()); 2010e9a4a3bSRiver Riddle auto newResultTypes = llvm::to_vector<6>(llvm::map_range( 2020e9a4a3bSRiver Riddle replaceWithNewCallResults, [](Value v) { return v.getType(); })); 20323aa5a74SRiver Riddle auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), 2040e9a4a3bSRiver Riddle newResultTypes, newOperands); 2050e9a4a3bSRiver Riddle for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) 2060e9a4a3bSRiver Riddle std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); 2070e9a4a3bSRiver Riddle op.erase(); 2080e9a4a3bSRiver Riddle }); 2090e9a4a3bSRiver Riddle 2100e9a4a3bSRiver Riddle return failure(didFail); 2110e9a4a3bSRiver Riddle } 2120e9a4a3bSRiver Riddle 2137e133eb4SEmilio Cota LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( 2147e133eb4SEmilio Cota ModuleOp module, 215e6048b72SMatthias Gehre const bufferization::BufferResultsToOutParamsOpts &options) { 21658ceae95SRiver Riddle for (auto func : module.getOps<func::FuncOp>()) { 2177e133eb4SEmilio Cota if (!options.filterFn(&func)) 2187e133eb4SEmilio Cota continue; 2190e9a4a3bSRiver Riddle SmallVector<BlockArgument, 6> appendedEntryArgs; 220e6048b72SMatthias Gehre if (failed( 221e6048b72SMatthias Gehre updateFuncOp(func, appendedEntryArgs, options.addResultAttribute))) 222598c5ddbSMatthias Springer return failure(); 2230e9a4a3bSRiver Riddle if (func.isExternal()) 2240e9a4a3bSRiver Riddle continue; 2258906b7beSsrcarroll if (failed(updateReturnOps(func, appendedEntryArgs, options))) { 226afac64ceSMatthias Gehre return failure(); 227afac64ceSMatthias Gehre } 2280e9a4a3bSRiver Riddle } 2297e133eb4SEmilio Cota if (failed(updateCalls(module, options))) 23082ea0d8bSMatthias Springer return failure(); 23182ea0d8bSMatthias Springer return success(); 23282ea0d8bSMatthias Springer } 23382ea0d8bSMatthias Springer 23482ea0d8bSMatthias Springer namespace { 23582ea0d8bSMatthias Springer struct BufferResultsToOutParamsPass 23667d0d7acSMichele Scuttari : bufferization::impl::BufferResultsToOutParamsBase< 23767d0d7acSMichele Scuttari BufferResultsToOutParamsPass> { 2387e133eb4SEmilio Cota explicit BufferResultsToOutParamsPass( 239e6048b72SMatthias Gehre const bufferization::BufferResultsToOutParamsOpts &options) 2407e133eb4SEmilio Cota : options(options) {} 2417e133eb4SEmilio Cota 24282ea0d8bSMatthias Springer void runOnOperation() override { 243e6048b72SMatthias Gehre // Convert from pass options in tablegen to BufferResultsToOutParamsOpts. 244e6048b72SMatthias Gehre if (addResultAttribute) 245e6048b72SMatthias Gehre options.addResultAttribute = true; 2460af448b7SMenooker if (hoistStaticAllocs) 2470af448b7SMenooker options.hoistStaticAllocs = true; 248e6048b72SMatthias Gehre 2497e133eb4SEmilio Cota if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), 2507e133eb4SEmilio Cota options))) 2510e9a4a3bSRiver Riddle return signalPassFailure(); 2520e9a4a3bSRiver Riddle } 2537e133eb4SEmilio Cota 2547e133eb4SEmilio Cota private: 255e6048b72SMatthias Gehre bufferization::BufferResultsToOutParamsOpts options; 2560e9a4a3bSRiver Riddle }; 2570e9a4a3bSRiver Riddle } // namespace 2580e9a4a3bSRiver Riddle 2597e133eb4SEmilio Cota std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass( 260e6048b72SMatthias Gehre const bufferization::BufferResultsToOutParamsOpts &options) { 2617e133eb4SEmilio Cota return std::make_unique<BufferResultsToOutParamsPass>(options); 2620e9a4a3bSRiver Riddle } 263