xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
10 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
11 
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/Pass/Pass.h"
16 
17 namespace mlir {
18 namespace bufferization {
19 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
20 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
21 } // namespace bufferization
22 } // namespace mlir
23 
24 using namespace mlir;
25 using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
26 using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
27 
28 /// Return `true` if the given MemRef type has a fully dynamic layout.
29 static bool hasFullyDynamicLayoutMap(MemRefType type) {
30   int64_t offset;
31   SmallVector<int64_t, 4> strides;
32   if (failed(type.getStridesAndOffset(strides, offset)))
33     return false;
34   if (!llvm::all_of(strides, ShapedType::isDynamic))
35     return false;
36   if (!ShapedType::isDynamic(offset))
37     return false;
38   return true;
39 }
40 
41 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
42 /// no layout).
43 static bool hasStaticIdentityLayout(MemRefType type) {
44   return type.getLayout().isIdentity();
45 }
46 
47 // Updates the func op and entry block.
48 //
49 // Any args appended to the entry block are added to `appendedEntryArgs`.
50 // If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
51 // to each newly created function argument.
52 static LogicalResult
53 updateFuncOp(func::FuncOp func,
54              SmallVectorImpl<BlockArgument> &appendedEntryArgs,
55              bool addResultAttribute) {
56   auto functionType = func.getFunctionType();
57 
58   // Collect information about the results will become appended arguments.
59   SmallVector<Type, 6> erasedResultTypes;
60   BitVector erasedResultIndices(functionType.getNumResults());
61   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
62     if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
63       if (!hasStaticIdentityLayout(memrefType) &&
64           !hasFullyDynamicLayoutMap(memrefType)) {
65         // Only buffers with static identity layout can be allocated. These can
66         // be casted to memrefs with fully dynamic layout map. Other layout maps
67         // are not supported.
68         return func->emitError()
69                << "cannot create out param for result with unsupported layout";
70       }
71       erasedResultIndices.set(resultType.index());
72       erasedResultTypes.push_back(memrefType);
73     }
74   }
75 
76   // Add the new arguments to the function type.
77   auto newArgTypes = llvm::to_vector<6>(
78       llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
79   auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
80                                            functionType.getResults());
81   func.setType(newFunctionType);
82 
83   // Transfer the result attributes to arg attributes.
84   auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
85   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
86     func.setArgAttrs(functionType.getNumInputs() + i,
87                      func.getResultAttrs(*erasedIndicesIt));
88     if (addResultAttribute)
89       func.setArgAttr(functionType.getNumInputs() + i,
90                       StringAttr::get(func.getContext(), "bufferize.result"),
91                       UnitAttr::get(func.getContext()));
92   }
93 
94   // Erase the results.
95   func.eraseResults(erasedResultIndices);
96 
97   // Add the new arguments to the entry block if the function is not external.
98   if (func.isExternal())
99     return success();
100   Location loc = func.getLoc();
101   for (Type type : erasedResultTypes)
102     appendedEntryArgs.push_back(func.front().addArgument(type, loc));
103 
104   return success();
105 }
106 
107 // Updates all ReturnOps in the scope of the given func::FuncOp by either
108 // keeping them as return values or copying the associated buffer contents into
109 // the given out-params.
110 static LogicalResult
111 updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
112                 const bufferization::BufferResultsToOutParamsOpts &options) {
113   auto res = func.walk([&](func::ReturnOp op) {
114     SmallVector<Value, 6> copyIntoOutParams;
115     SmallVector<Value, 6> keepAsReturnOperands;
116     for (Value operand : op.getOperands()) {
117       if (isa<MemRefType>(operand.getType()))
118         copyIntoOutParams.push_back(operand);
119       else
120         keepAsReturnOperands.push_back(operand);
121     }
122     OpBuilder builder(op);
123     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
124       if (options.hoistStaticAllocs &&
125           isa_and_nonnull<bufferization::AllocationOpInterface>(
126               orig.getDefiningOp()) &&
127           mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
128         orig.replaceAllUsesWith(arg);
129         orig.getDefiningOp()->erase();
130       } else {
131         if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
132           return WalkResult::interrupt();
133       }
134     }
135     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
136     op.erase();
137     return WalkResult::advance();
138   });
139   return failure(res.wasInterrupted());
140 }
141 
142 // Updates all CallOps in the scope of the given ModuleOp by allocating
143 // temporary buffers for newly introduced out params.
144 static LogicalResult
145 updateCalls(ModuleOp module,
146             const bufferization::BufferResultsToOutParamsOpts &options) {
147   bool didFail = false;
148   SymbolTable symtab(module);
149   module.walk([&](func::CallOp op) {
150     auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
151     if (!callee) {
152       op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
153                      << "symbol table";
154       didFail = true;
155       return;
156     }
157     if (!options.filterFn(&callee))
158       return;
159     SmallVector<Value, 6> replaceWithNewCallResults;
160     SmallVector<Value, 6> replaceWithOutParams;
161     for (OpResult result : op.getResults()) {
162       if (isa<MemRefType>(result.getType()))
163         replaceWithOutParams.push_back(result);
164       else
165         replaceWithNewCallResults.push_back(result);
166     }
167     SmallVector<Value, 6> outParams;
168     OpBuilder builder(op);
169     for (Value memref : replaceWithOutParams) {
170       if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
171         op.emitError()
172             << "cannot create out param for dynamically shaped result";
173         didFail = true;
174         return;
175       }
176       auto memrefType = cast<MemRefType>(memref.getType());
177       auto allocType =
178           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
179                           AffineMap(), memrefType.getMemorySpace());
180       auto maybeOutParam =
181           options.allocationFn(builder, op.getLoc(), allocType);
182       if (failed(maybeOutParam)) {
183         op.emitError() << "failed to create allocation op";
184         didFail = true;
185         return;
186       }
187       Value outParam = maybeOutParam.value();
188       if (!hasStaticIdentityLayout(memrefType)) {
189         // Layout maps are already checked in `updateFuncOp`.
190         assert(hasFullyDynamicLayoutMap(memrefType) &&
191                "layout map not supported");
192         outParam =
193             builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
194       }
195       memref.replaceAllUsesWith(outParam);
196       outParams.push_back(outParam);
197     }
198 
199     auto newOperands = llvm::to_vector<6>(op.getOperands());
200     newOperands.append(outParams.begin(), outParams.end());
201     auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
202         replaceWithNewCallResults, [](Value v) { return v.getType(); }));
203     auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
204                                                 newResultTypes, newOperands);
205     for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
206       std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
207     op.erase();
208   });
209 
210   return failure(didFail);
211 }
212 
213 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
214     ModuleOp module,
215     const bufferization::BufferResultsToOutParamsOpts &options) {
216   for (auto func : module.getOps<func::FuncOp>()) {
217     if (!options.filterFn(&func))
218       continue;
219     SmallVector<BlockArgument, 6> appendedEntryArgs;
220     if (failed(
221             updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
222       return failure();
223     if (func.isExternal())
224       continue;
225     if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
226       return failure();
227     }
228   }
229   if (failed(updateCalls(module, options)))
230     return failure();
231   return success();
232 }
233 
234 namespace {
235 struct BufferResultsToOutParamsPass
236     : bufferization::impl::BufferResultsToOutParamsBase<
237           BufferResultsToOutParamsPass> {
238   explicit BufferResultsToOutParamsPass(
239       const bufferization::BufferResultsToOutParamsOpts &options)
240       : options(options) {}
241 
242   void runOnOperation() override {
243     // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
244     if (addResultAttribute)
245       options.addResultAttribute = true;
246     if (hoistStaticAllocs)
247       options.hoistStaticAllocs = true;
248 
249     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
250                                                               options)))
251       return signalPassFailure();
252   }
253 
254 private:
255   bufferization::BufferResultsToOutParamsOpts options;
256 };
257 } // namespace
258 
259 std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
260     const bufferization::BufferResultsToOutParamsOpts &options) {
261   return std::make_unique<BufferResultsToOutParamsPass>(options);
262 }
263