xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (revision 26645ae2eea00456d98b497f348426c375409ce4)
1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Pass/Pass.h"
15 
16 namespace mlir {
17 namespace bufferization {
18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
20 } // namespace bufferization
21 } // namespace mlir
22 
23 using namespace mlir;
24 using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
25 
26 /// Return `true` if the given MemRef type has a fully dynamic layout.
27 static bool hasFullyDynamicLayoutMap(MemRefType type) {
28   int64_t offset;
29   SmallVector<int64_t, 4> strides;
30   if (failed(getStridesAndOffset(type, strides, offset)))
31     return false;
32   if (!llvm::all_of(strides, ShapedType::isDynamic))
33     return false;
34   if (!ShapedType::isDynamic(offset))
35     return false;
36   return true;
37 }
38 
39 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
40 /// no layout).
41 static bool hasStaticIdentityLayout(MemRefType type) {
42   return type.getLayout().isIdentity();
43 }
44 
45 // Updates the func op and entry block.
46 //
47 // Any args appended to the entry block are added to `appendedEntryArgs`.
48 // If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
49 // to each newly created function argument.
50 static LogicalResult
51 updateFuncOp(func::FuncOp func,
52              SmallVectorImpl<BlockArgument> &appendedEntryArgs,
53              bool addResultAttribute) {
54   auto functionType = func.getFunctionType();
55 
56   // Collect information about the results will become appended arguments.
57   SmallVector<Type, 6> erasedResultTypes;
58   BitVector erasedResultIndices(functionType.getNumResults());
59   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
60     if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
61       if (!hasStaticIdentityLayout(memrefType) &&
62           !hasFullyDynamicLayoutMap(memrefType)) {
63         // Only buffers with static identity layout can be allocated. These can
64         // be casted to memrefs with fully dynamic layout map. Other layout maps
65         // are not supported.
66         return func->emitError()
67                << "cannot create out param for result with unsupported layout";
68       }
69       erasedResultIndices.set(resultType.index());
70       erasedResultTypes.push_back(memrefType);
71     }
72   }
73 
74   // Add the new arguments to the function type.
75   auto newArgTypes = llvm::to_vector<6>(
76       llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
77   auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
78                                            functionType.getResults());
79   func.setType(newFunctionType);
80 
81   // Transfer the result attributes to arg attributes.
82   auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
83   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
84     func.setArgAttrs(functionType.getNumInputs() + i,
85                      func.getResultAttrs(*erasedIndicesIt));
86     if (addResultAttribute)
87       func.setArgAttr(functionType.getNumInputs() + i,
88                       StringAttr::get(func.getContext(), "bufferize.result"),
89                       UnitAttr::get(func.getContext()));
90   }
91 
92   // Erase the results.
93   func.eraseResults(erasedResultIndices);
94 
95   // Add the new arguments to the entry block if the function is not external.
96   if (func.isExternal())
97     return success();
98   Location loc = func.getLoc();
99   for (Type type : erasedResultTypes)
100     appendedEntryArgs.push_back(func.front().addArgument(type, loc));
101 
102   return success();
103 }
104 
105 // Updates all ReturnOps in the scope of the given func::FuncOp by either
106 // keeping them as return values or copying the associated buffer contents into
107 // the given out-params.
108 static LogicalResult updateReturnOps(func::FuncOp func,
109                                      ArrayRef<BlockArgument> appendedEntryArgs,
110                                      MemCpyFn memCpyFn,
111                                      bool hoistStaticAllocs) {
112   auto res = func.walk([&](func::ReturnOp op) {
113     SmallVector<Value, 6> copyIntoOutParams;
114     SmallVector<Value, 6> keepAsReturnOperands;
115     for (Value operand : op.getOperands()) {
116       if (isa<MemRefType>(operand.getType()))
117         copyIntoOutParams.push_back(operand);
118       else
119         keepAsReturnOperands.push_back(operand);
120     }
121     OpBuilder builder(op);
122     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
123       if (hoistStaticAllocs &&
124           isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
125           mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
126         orig.replaceAllUsesWith(arg);
127         orig.getDefiningOp()->erase();
128       } else {
129         if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
130           return WalkResult::interrupt();
131       }
132     }
133     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
134     op.erase();
135     return WalkResult::advance();
136   });
137   return failure(res.wasInterrupted());
138 }
139 
140 // Updates all CallOps in the scope of the given ModuleOp by allocating
141 // temporary buffers for newly introduced out params.
142 static LogicalResult
143 updateCalls(ModuleOp module,
144             const bufferization::BufferResultsToOutParamsOpts &options) {
145   bool didFail = false;
146   SymbolTable symtab(module);
147   module.walk([&](func::CallOp op) {
148     auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
149     if (!callee) {
150       op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
151                      << "symbol table";
152       didFail = true;
153       return;
154     }
155     if (!options.filterFn(&callee))
156       return;
157     SmallVector<Value, 6> replaceWithNewCallResults;
158     SmallVector<Value, 6> replaceWithOutParams;
159     for (OpResult result : op.getResults()) {
160       if (isa<MemRefType>(result.getType()))
161         replaceWithOutParams.push_back(result);
162       else
163         replaceWithNewCallResults.push_back(result);
164     }
165     SmallVector<Value, 6> outParams;
166     OpBuilder builder(op);
167     for (Value memref : replaceWithOutParams) {
168       if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
169         op.emitError()
170             << "cannot create out param for dynamically shaped result";
171         didFail = true;
172         return;
173       }
174       auto memrefType = cast<MemRefType>(memref.getType());
175       auto allocType =
176           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
177                           AffineMap(), memrefType.getMemorySpace());
178       Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
179       if (!hasStaticIdentityLayout(memrefType)) {
180         // Layout maps are already checked in `updateFuncOp`.
181         assert(hasFullyDynamicLayoutMap(memrefType) &&
182                "layout map not supported");
183         outParam =
184             builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
185       }
186       memref.replaceAllUsesWith(outParam);
187       outParams.push_back(outParam);
188     }
189 
190     auto newOperands = llvm::to_vector<6>(op.getOperands());
191     newOperands.append(outParams.begin(), outParams.end());
192     auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
193         replaceWithNewCallResults, [](Value v) { return v.getType(); }));
194     auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
195                                                 newResultTypes, newOperands);
196     for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
197       std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
198     op.erase();
199   });
200 
201   return failure(didFail);
202 }
203 
204 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
205     ModuleOp module,
206     const bufferization::BufferResultsToOutParamsOpts &options) {
207   for (auto func : module.getOps<func::FuncOp>()) {
208     if (!options.filterFn(&func))
209       continue;
210     SmallVector<BlockArgument, 6> appendedEntryArgs;
211     if (failed(
212             updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
213       return failure();
214     if (func.isExternal())
215       continue;
216     auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
217                               Value to) {
218       builder.create<memref::CopyOp>(loc, from, to);
219       return success();
220     };
221     if (failed(updateReturnOps(func, appendedEntryArgs,
222                                options.memCpyFn.value_or(defaultMemCpyFn),
223                                options.hoistStaticAllocs))) {
224       return failure();
225     }
226   }
227   if (failed(updateCalls(module, options)))
228     return failure();
229   return success();
230 }
231 
232 namespace {
233 struct BufferResultsToOutParamsPass
234     : bufferization::impl::BufferResultsToOutParamsBase<
235           BufferResultsToOutParamsPass> {
236   explicit BufferResultsToOutParamsPass(
237       const bufferization::BufferResultsToOutParamsOpts &options)
238       : options(options) {}
239 
240   void runOnOperation() override {
241     // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
242     if (addResultAttribute)
243       options.addResultAttribute = true;
244     if (hoistStaticAllocs)
245       options.hoistStaticAllocs = true;
246 
247     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
248                                                               options)))
249       return signalPassFailure();
250   }
251 
252 private:
253   bufferization::BufferResultsToOutParamsOpts options;
254 };
255 } // namespace
256 
257 std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
258     const bufferization::BufferResultsToOutParamsOpts &options) {
259   return std::make_unique<BufferResultsToOutParamsPass>(options);
260 }
261