xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (revision 5d04f0c937582357dc51230e17aef398e0e48cd6)
1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Pass/Pass.h"
15 
16 namespace mlir {
17 namespace bufferization {
18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
20 } // namespace bufferization
21 } // namespace mlir
22 
23 using namespace mlir;
24 
25 /// Return `true` if the given MemRef type has a fully dynamic layout.
26 static bool hasFullyDynamicLayoutMap(MemRefType type) {
27   int64_t offset;
28   SmallVector<int64_t, 4> strides;
29   if (failed(getStridesAndOffset(type, strides, offset)))
30     return false;
31   if (!llvm::all_of(strides, ShapedType::isDynamic))
32     return false;
33   if (!ShapedType::isDynamic(offset))
34     return false;
35   return true;
36 }
37 
38 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
39 /// no layout).
40 static bool hasStaticIdentityLayout(MemRefType type) {
41   return type.getLayout().isIdentity();
42 }
43 
44 // Updates the func op and entry block.
45 //
46 // Any args appended to the entry block are added to `appendedEntryArgs`.
47 static LogicalResult
48 updateFuncOp(func::FuncOp func,
49              SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
50   auto functionType = func.getFunctionType();
51 
52   // Collect information about the results will become appended arguments.
53   SmallVector<Type, 6> erasedResultTypes;
54   BitVector erasedResultIndices(functionType.getNumResults());
55   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
56     if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) {
57       if (!hasStaticIdentityLayout(memrefType) &&
58           !hasFullyDynamicLayoutMap(memrefType)) {
59         // Only buffers with static identity layout can be allocated. These can
60         // be casted to memrefs with fully dynamic layout map. Other layout maps
61         // are not supported.
62         return func->emitError()
63                << "cannot create out param for result with unsupported layout";
64       }
65       erasedResultIndices.set(resultType.index());
66       erasedResultTypes.push_back(memrefType);
67     }
68   }
69 
70   // Add the new arguments to the function type.
71   auto newArgTypes = llvm::to_vector<6>(
72       llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
73   auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
74                                            functionType.getResults());
75   func.setType(newFunctionType);
76 
77   // Transfer the result attributes to arg attributes.
78   auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
79   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
80     func.setArgAttrs(functionType.getNumInputs() + i,
81                      func.getResultAttrs(*erasedIndicesIt));
82   }
83 
84   // Erase the results.
85   func.eraseResults(erasedResultIndices);
86 
87   // Add the new arguments to the entry block if the function is not external.
88   if (func.isExternal())
89     return success();
90   Location loc = func.getLoc();
91   for (Type type : erasedResultTypes)
92     appendedEntryArgs.push_back(func.front().addArgument(type, loc));
93 
94   return success();
95 }
96 
97 // Updates all ReturnOps in the scope of the given func::FuncOp by either
98 // keeping them as return values or copying the associated buffer contents into
99 // the given out-params.
100 static void updateReturnOps(func::FuncOp func,
101                             ArrayRef<BlockArgument> appendedEntryArgs) {
102   func.walk([&](func::ReturnOp op) {
103     SmallVector<Value, 6> copyIntoOutParams;
104     SmallVector<Value, 6> keepAsReturnOperands;
105     for (Value operand : op.getOperands()) {
106       if (operand.getType().isa<MemRefType>())
107         copyIntoOutParams.push_back(operand);
108       else
109         keepAsReturnOperands.push_back(operand);
110     }
111     OpBuilder builder(op);
112     for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
113       builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
114                                      std::get<1>(t));
115     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
116     op.erase();
117   });
118 }
119 
120 // Updates all CallOps in the scope of the given ModuleOp by allocating
121 // temporary buffers for newly introduced out params.
122 static LogicalResult
123 updateCalls(ModuleOp module,
124             const bufferization::BufferResultsToOutParamsOptions &options) {
125   bool didFail = false;
126   SymbolTable symtab(module);
127   module.walk([&](func::CallOp op) {
128     auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
129     if (!callee) {
130       op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
131                      << "symbol table";
132       didFail = true;
133       return;
134     }
135     if (!options.filterFn(&callee))
136       return;
137     SmallVector<Value, 6> replaceWithNewCallResults;
138     SmallVector<Value, 6> replaceWithOutParams;
139     for (OpResult result : op.getResults()) {
140       if (result.getType().isa<MemRefType>())
141         replaceWithOutParams.push_back(result);
142       else
143         replaceWithNewCallResults.push_back(result);
144     }
145     SmallVector<Value, 6> outParams;
146     OpBuilder builder(op);
147     for (Value memref : replaceWithOutParams) {
148       if (!memref.getType().cast<MemRefType>().hasStaticShape()) {
149         op.emitError()
150             << "cannot create out param for dynamically shaped result";
151         didFail = true;
152         return;
153       }
154       auto memrefType = memref.getType().cast<MemRefType>();
155       auto allocType =
156           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
157                           AffineMap(), memrefType.getMemorySpace());
158       Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
159       if (!hasStaticIdentityLayout(memrefType)) {
160         // Layout maps are already checked in `updateFuncOp`.
161         assert(hasFullyDynamicLayoutMap(memrefType) &&
162                "layout map not supported");
163         outParam =
164             builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
165       }
166       memref.replaceAllUsesWith(outParam);
167       outParams.push_back(outParam);
168     }
169 
170     auto newOperands = llvm::to_vector<6>(op.getOperands());
171     newOperands.append(outParams.begin(), outParams.end());
172     auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
173         replaceWithNewCallResults, [](Value v) { return v.getType(); }));
174     auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
175                                                 newResultTypes, newOperands);
176     for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
177       std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
178     op.erase();
179   });
180 
181   return failure(didFail);
182 }
183 
184 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
185     ModuleOp module,
186     const bufferization::BufferResultsToOutParamsOptions &options) {
187   for (auto func : module.getOps<func::FuncOp>()) {
188     if (!options.filterFn(&func))
189       continue;
190     SmallVector<BlockArgument, 6> appendedEntryArgs;
191     if (failed(updateFuncOp(func, appendedEntryArgs)))
192       return failure();
193     if (func.isExternal())
194       continue;
195     updateReturnOps(func, appendedEntryArgs);
196   }
197   if (failed(updateCalls(module, options)))
198     return failure();
199   return success();
200 }
201 
202 namespace {
203 struct BufferResultsToOutParamsPass
204     : bufferization::impl::BufferResultsToOutParamsBase<
205           BufferResultsToOutParamsPass> {
206   explicit BufferResultsToOutParamsPass(
207       const bufferization::BufferResultsToOutParamsOptions &options)
208       : options(options) {}
209 
210   void runOnOperation() override {
211     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
212                                                               options)))
213       return signalPassFailure();
214   }
215 
216 private:
217   bufferization::BufferResultsToOutParamsOptions options;
218 };
219 } // namespace
220 
221 std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
222     const bufferization::BufferResultsToOutParamsOptions &options) {
223   return std::make_unique<BufferResultsToOutParamsPass>(options);
224 }
225