xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
10e9a4a3bSRiver Riddle //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
20e9a4a3bSRiver Riddle //
30e9a4a3bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40e9a4a3bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
50e9a4a3bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60e9a4a3bSRiver Riddle //
70e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
80e9a4a3bSRiver Riddle 
98906b7beSsrcarroll #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
100e9a4a3bSRiver Riddle #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
1167d0d7acSMichele Scuttari 
1223aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
130e9a4a3bSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
140e9a4a3bSRiver Riddle #include "mlir/IR/Operation.h"
150e9a4a3bSRiver Riddle #include "mlir/Pass/Pass.h"
160e9a4a3bSRiver Riddle 
1767d0d7acSMichele Scuttari namespace mlir {
1867d0d7acSMichele Scuttari namespace bufferization {
1967d0d7acSMichele Scuttari #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
2067d0d7acSMichele Scuttari #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
2167d0d7acSMichele Scuttari } // namespace bufferization
2267d0d7acSMichele Scuttari } // namespace mlir
2367d0d7acSMichele Scuttari 
240e9a4a3bSRiver Riddle using namespace mlir;
258906b7beSsrcarroll using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
26e6048b72SMatthias Gehre using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
270e9a4a3bSRiver Riddle 
28598c5ddbSMatthias Springer /// Return `true` if the given MemRef type has a fully dynamic layout.
29598c5ddbSMatthias Springer static bool hasFullyDynamicLayoutMap(MemRefType type) {
30598c5ddbSMatthias Springer   int64_t offset;
31598c5ddbSMatthias Springer   SmallVector<int64_t, 4> strides;
32*6aaa8f25SMatthias Springer   if (failed(type.getStridesAndOffset(strides, offset)))
33598c5ddbSMatthias Springer     return false;
34399638f9SAliia Khasanova   if (!llvm::all_of(strides, ShapedType::isDynamic))
35598c5ddbSMatthias Springer     return false;
36399638f9SAliia Khasanova   if (!ShapedType::isDynamic(offset))
37598c5ddbSMatthias Springer     return false;
38598c5ddbSMatthias Springer   return true;
39598c5ddbSMatthias Springer }
40598c5ddbSMatthias Springer 
41598c5ddbSMatthias Springer /// Return `true` if the given MemRef type has a static identity layout (i.e.,
42598c5ddbSMatthias Springer /// no layout).
43598c5ddbSMatthias Springer static bool hasStaticIdentityLayout(MemRefType type) {
44598c5ddbSMatthias Springer   return type.getLayout().isIdentity();
45598c5ddbSMatthias Springer }
46598c5ddbSMatthias Springer 
470e9a4a3bSRiver Riddle // Updates the func op and entry block.
480e9a4a3bSRiver Riddle //
490e9a4a3bSRiver Riddle // Any args appended to the entry block are added to `appendedEntryArgs`.
50e6048b72SMatthias Gehre // If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
51e6048b72SMatthias Gehre // to each newly created function argument.
52598c5ddbSMatthias Springer static LogicalResult
53598c5ddbSMatthias Springer updateFuncOp(func::FuncOp func,
54e6048b72SMatthias Gehre              SmallVectorImpl<BlockArgument> &appendedEntryArgs,
55e6048b72SMatthias Gehre              bool addResultAttribute) {
564a3460a7SRiver Riddle   auto functionType = func.getFunctionType();
570e9a4a3bSRiver Riddle 
580e9a4a3bSRiver Riddle   // Collect information about the results will become appended arguments.
590e9a4a3bSRiver Riddle   SmallVector<Type, 6> erasedResultTypes;
60d10d49dcSRiver Riddle   BitVector erasedResultIndices(functionType.getNumResults());
610e9a4a3bSRiver Riddle   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
625550c821STres Popp     if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
63598c5ddbSMatthias Springer       if (!hasStaticIdentityLayout(memrefType) &&
64598c5ddbSMatthias Springer           !hasFullyDynamicLayoutMap(memrefType)) {
65598c5ddbSMatthias Springer         // Only buffers with static identity layout can be allocated. These can
66598c5ddbSMatthias Springer         // be casted to memrefs with fully dynamic layout map. Other layout maps
67598c5ddbSMatthias Springer         // are not supported.
68598c5ddbSMatthias Springer         return func->emitError()
69598c5ddbSMatthias Springer                << "cannot create out param for result with unsupported layout";
70598c5ddbSMatthias Springer       }
71e3cd80eaSRiver Riddle       erasedResultIndices.set(resultType.index());
72598c5ddbSMatthias Springer       erasedResultTypes.push_back(memrefType);
730e9a4a3bSRiver Riddle     }
740e9a4a3bSRiver Riddle   }
750e9a4a3bSRiver Riddle 
760e9a4a3bSRiver Riddle   // Add the new arguments to the function type.
770e9a4a3bSRiver Riddle   auto newArgTypes = llvm::to_vector<6>(
780e9a4a3bSRiver Riddle       llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
790e9a4a3bSRiver Riddle   auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
800e9a4a3bSRiver Riddle                                            functionType.getResults());
810e9a4a3bSRiver Riddle   func.setType(newFunctionType);
820e9a4a3bSRiver Riddle 
830e9a4a3bSRiver Riddle   // Transfer the result attributes to arg attributes.
84e3cd80eaSRiver Riddle   auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
85e3cd80eaSRiver Riddle   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
860e9a4a3bSRiver Riddle     func.setArgAttrs(functionType.getNumInputs() + i,
87e3cd80eaSRiver Riddle                      func.getResultAttrs(*erasedIndicesIt));
88e6048b72SMatthias Gehre     if (addResultAttribute)
89e6048b72SMatthias Gehre       func.setArgAttr(functionType.getNumInputs() + i,
90e6048b72SMatthias Gehre                       StringAttr::get(func.getContext(), "bufferize.result"),
91e6048b72SMatthias Gehre                       UnitAttr::get(func.getContext()));
92e3cd80eaSRiver Riddle   }
930e9a4a3bSRiver Riddle 
940e9a4a3bSRiver Riddle   // Erase the results.
950e9a4a3bSRiver Riddle   func.eraseResults(erasedResultIndices);
960e9a4a3bSRiver Riddle 
970e9a4a3bSRiver Riddle   // Add the new arguments to the entry block if the function is not external.
980e9a4a3bSRiver Riddle   if (func.isExternal())
99598c5ddbSMatthias Springer     return success();
1000e9a4a3bSRiver Riddle   Location loc = func.getLoc();
1010e9a4a3bSRiver Riddle   for (Type type : erasedResultTypes)
1020e9a4a3bSRiver Riddle     appendedEntryArgs.push_back(func.front().addArgument(type, loc));
103598c5ddbSMatthias Springer 
104598c5ddbSMatthias Springer   return success();
1050e9a4a3bSRiver Riddle }
1060e9a4a3bSRiver Riddle 
10758ceae95SRiver Riddle // Updates all ReturnOps in the scope of the given func::FuncOp by either
10858ceae95SRiver Riddle // keeping them as return values or copying the associated buffer contents into
10958ceae95SRiver Riddle // the given out-params.
1108906b7beSsrcarroll static LogicalResult
1118906b7beSsrcarroll updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
1128906b7beSsrcarroll                 const bufferization::BufferResultsToOutParamsOpts &options) {
113afac64ceSMatthias Gehre   auto res = func.walk([&](func::ReturnOp op) {
1140e9a4a3bSRiver Riddle     SmallVector<Value, 6> copyIntoOutParams;
1150e9a4a3bSRiver Riddle     SmallVector<Value, 6> keepAsReturnOperands;
1160e9a4a3bSRiver Riddle     for (Value operand : op.getOperands()) {
1175550c821STres Popp       if (isa<MemRefType>(operand.getType()))
1180e9a4a3bSRiver Riddle         copyIntoOutParams.push_back(operand);
1190e9a4a3bSRiver Riddle       else
1200e9a4a3bSRiver Riddle         keepAsReturnOperands.push_back(operand);
1210e9a4a3bSRiver Riddle     }
1220e9a4a3bSRiver Riddle     OpBuilder builder(op);
1230af448b7SMenooker     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
1248906b7beSsrcarroll       if (options.hoistStaticAllocs &&
1258906b7beSsrcarroll           isa_and_nonnull<bufferization::AllocationOpInterface>(
1268906b7beSsrcarroll               orig.getDefiningOp()) &&
1271c8c2fddSJie Fu           mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
1280af448b7SMenooker         orig.replaceAllUsesWith(arg);
1290af448b7SMenooker         orig.getDefiningOp()->erase();
1300af448b7SMenooker       } else {
1318906b7beSsrcarroll         if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
132afac64ceSMatthias Gehre           return WalkResult::interrupt();
133afac64ceSMatthias Gehre       }
1340af448b7SMenooker     }
13523aa5a74SRiver Riddle     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
1360e9a4a3bSRiver Riddle     op.erase();
137afac64ceSMatthias Gehre     return WalkResult::advance();
1380e9a4a3bSRiver Riddle   });
139afac64ceSMatthias Gehre   return failure(res.wasInterrupted());
1400e9a4a3bSRiver Riddle }
1410e9a4a3bSRiver Riddle 
1420e9a4a3bSRiver Riddle // Updates all CallOps in the scope of the given ModuleOp by allocating
1430e9a4a3bSRiver Riddle // temporary buffers for newly introduced out params.
1447e133eb4SEmilio Cota static LogicalResult
1457e133eb4SEmilio Cota updateCalls(ModuleOp module,
146e6048b72SMatthias Gehre             const bufferization::BufferResultsToOutParamsOpts &options) {
1470e9a4a3bSRiver Riddle   bool didFail = false;
1487e133eb4SEmilio Cota   SymbolTable symtab(module);
14923aa5a74SRiver Riddle   module.walk([&](func::CallOp op) {
1507e133eb4SEmilio Cota     auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
1517e133eb4SEmilio Cota     if (!callee) {
1527e133eb4SEmilio Cota       op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
1537e133eb4SEmilio Cota                      << "symbol table";
1547e133eb4SEmilio Cota       didFail = true;
1557e133eb4SEmilio Cota       return;
1567e133eb4SEmilio Cota     }
1577e133eb4SEmilio Cota     if (!options.filterFn(&callee))
1587e133eb4SEmilio Cota       return;
1590e9a4a3bSRiver Riddle     SmallVector<Value, 6> replaceWithNewCallResults;
1600e9a4a3bSRiver Riddle     SmallVector<Value, 6> replaceWithOutParams;
1610e9a4a3bSRiver Riddle     for (OpResult result : op.getResults()) {
1625550c821STres Popp       if (isa<MemRefType>(result.getType()))
1630e9a4a3bSRiver Riddle         replaceWithOutParams.push_back(result);
1640e9a4a3bSRiver Riddle       else
1650e9a4a3bSRiver Riddle         replaceWithNewCallResults.push_back(result);
1660e9a4a3bSRiver Riddle     }
1670e9a4a3bSRiver Riddle     SmallVector<Value, 6> outParams;
1680e9a4a3bSRiver Riddle     OpBuilder builder(op);
1690e9a4a3bSRiver Riddle     for (Value memref : replaceWithOutParams) {
1705550c821STres Popp       if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
1710e9a4a3bSRiver Riddle         op.emitError()
1720e9a4a3bSRiver Riddle             << "cannot create out param for dynamically shaped result";
1730e9a4a3bSRiver Riddle         didFail = true;
1740e9a4a3bSRiver Riddle         return;
1750e9a4a3bSRiver Riddle       }
1765550c821STres Popp       auto memrefType = cast<MemRefType>(memref.getType());
177598c5ddbSMatthias Springer       auto allocType =
178598c5ddbSMatthias Springer           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
1795d04f0c9SMatthias Springer                           AffineMap(), memrefType.getMemorySpace());
1808906b7beSsrcarroll       auto maybeOutParam =
1818906b7beSsrcarroll           options.allocationFn(builder, op.getLoc(), allocType);
1828906b7beSsrcarroll       if (failed(maybeOutParam)) {
1838906b7beSsrcarroll         op.emitError() << "failed to create allocation op";
1848906b7beSsrcarroll         didFail = true;
1858906b7beSsrcarroll         return;
1868906b7beSsrcarroll       }
1878906b7beSsrcarroll       Value outParam = maybeOutParam.value();
188598c5ddbSMatthias Springer       if (!hasStaticIdentityLayout(memrefType)) {
189598c5ddbSMatthias Springer         // Layout maps are already checked in `updateFuncOp`.
190598c5ddbSMatthias Springer         assert(hasFullyDynamicLayoutMap(memrefType) &&
191598c5ddbSMatthias Springer                "layout map not supported");
192598c5ddbSMatthias Springer         outParam =
193598c5ddbSMatthias Springer             builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
194598c5ddbSMatthias Springer       }
1950e9a4a3bSRiver Riddle       memref.replaceAllUsesWith(outParam);
1960e9a4a3bSRiver Riddle       outParams.push_back(outParam);
1970e9a4a3bSRiver Riddle     }
1980e9a4a3bSRiver Riddle 
1990e9a4a3bSRiver Riddle     auto newOperands = llvm::to_vector<6>(op.getOperands());
2000e9a4a3bSRiver Riddle     newOperands.append(outParams.begin(), outParams.end());
2010e9a4a3bSRiver Riddle     auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
2020e9a4a3bSRiver Riddle         replaceWithNewCallResults, [](Value v) { return v.getType(); }));
20323aa5a74SRiver Riddle     auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
2040e9a4a3bSRiver Riddle                                                 newResultTypes, newOperands);
2050e9a4a3bSRiver Riddle     for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
2060e9a4a3bSRiver Riddle       std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
2070e9a4a3bSRiver Riddle     op.erase();
2080e9a4a3bSRiver Riddle   });
2090e9a4a3bSRiver Riddle 
2100e9a4a3bSRiver Riddle   return failure(didFail);
2110e9a4a3bSRiver Riddle }
2120e9a4a3bSRiver Riddle 
2137e133eb4SEmilio Cota LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
2147e133eb4SEmilio Cota     ModuleOp module,
215e6048b72SMatthias Gehre     const bufferization::BufferResultsToOutParamsOpts &options) {
21658ceae95SRiver Riddle   for (auto func : module.getOps<func::FuncOp>()) {
2177e133eb4SEmilio Cota     if (!options.filterFn(&func))
2187e133eb4SEmilio Cota       continue;
2190e9a4a3bSRiver Riddle     SmallVector<BlockArgument, 6> appendedEntryArgs;
220e6048b72SMatthias Gehre     if (failed(
221e6048b72SMatthias Gehre             updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
222598c5ddbSMatthias Springer       return failure();
2230e9a4a3bSRiver Riddle     if (func.isExternal())
2240e9a4a3bSRiver Riddle       continue;
2258906b7beSsrcarroll     if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
226afac64ceSMatthias Gehre       return failure();
227afac64ceSMatthias Gehre     }
2280e9a4a3bSRiver Riddle   }
2297e133eb4SEmilio Cota   if (failed(updateCalls(module, options)))
23082ea0d8bSMatthias Springer     return failure();
23182ea0d8bSMatthias Springer   return success();
23282ea0d8bSMatthias Springer }
23382ea0d8bSMatthias Springer 
23482ea0d8bSMatthias Springer namespace {
23582ea0d8bSMatthias Springer struct BufferResultsToOutParamsPass
23667d0d7acSMichele Scuttari     : bufferization::impl::BufferResultsToOutParamsBase<
23767d0d7acSMichele Scuttari           BufferResultsToOutParamsPass> {
2387e133eb4SEmilio Cota   explicit BufferResultsToOutParamsPass(
239e6048b72SMatthias Gehre       const bufferization::BufferResultsToOutParamsOpts &options)
2407e133eb4SEmilio Cota       : options(options) {}
2417e133eb4SEmilio Cota 
24282ea0d8bSMatthias Springer   void runOnOperation() override {
243e6048b72SMatthias Gehre     // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
244e6048b72SMatthias Gehre     if (addResultAttribute)
245e6048b72SMatthias Gehre       options.addResultAttribute = true;
2460af448b7SMenooker     if (hoistStaticAllocs)
2470af448b7SMenooker       options.hoistStaticAllocs = true;
248e6048b72SMatthias Gehre 
2497e133eb4SEmilio Cota     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
2507e133eb4SEmilio Cota                                                               options)))
2510e9a4a3bSRiver Riddle       return signalPassFailure();
2520e9a4a3bSRiver Riddle   }
2537e133eb4SEmilio Cota 
2547e133eb4SEmilio Cota private:
255e6048b72SMatthias Gehre   bufferization::BufferResultsToOutParamsOpts options;
2560e9a4a3bSRiver Riddle };
2570e9a4a3bSRiver Riddle } // namespace
2580e9a4a3bSRiver Riddle 
2597e133eb4SEmilio Cota std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
260e6048b72SMatthias Gehre     const bufferization::BufferResultsToOutParamsOpts &options) {
2617e133eb4SEmilio Cota   return std::make_unique<BufferResultsToOutParamsPass>(options);
2620e9a4a3bSRiver Riddle }
263