xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Pass/Pass.h"
15 
16 using namespace mlir;
17 
18 /// Return `true` if the given MemRef type has a fully dynamic layout.
19 static bool hasFullyDynamicLayoutMap(MemRefType type) {
20   int64_t offset;
21   SmallVector<int64_t, 4> strides;
22   if (failed(getStridesAndOffset(type, strides, offset)))
23     return false;
24   if (!llvm::all_of(strides, ShapedType::isDynamicStrideOrOffset))
25     return false;
26   if (!ShapedType::isDynamicStrideOrOffset(offset))
27     return false;
28   return true;
29 }
30 
31 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
32 /// no layout).
33 static bool hasStaticIdentityLayout(MemRefType type) {
34   return type.getLayout().isIdentity();
35 }
36 
37 // Updates the func op and entry block.
38 //
39 // Any args appended to the entry block are added to `appendedEntryArgs`.
40 static LogicalResult
41 updateFuncOp(func::FuncOp func,
42              SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
43   auto functionType = func.getFunctionType();
44 
45   // Collect information about the results will become appended arguments.
46   SmallVector<Type, 6> erasedResultTypes;
47   BitVector erasedResultIndices(functionType.getNumResults());
48   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
49     if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) {
50       if (!hasStaticIdentityLayout(memrefType) &&
51           !hasFullyDynamicLayoutMap(memrefType)) {
52         // Only buffers with static identity layout can be allocated. These can
53         // be casted to memrefs with fully dynamic layout map. Other layout maps
54         // are not supported.
55         return func->emitError()
56                << "cannot create out param for result with unsupported layout";
57       }
58       erasedResultIndices.set(resultType.index());
59       erasedResultTypes.push_back(memrefType);
60     }
61   }
62 
63   // Add the new arguments to the function type.
64   auto newArgTypes = llvm::to_vector<6>(
65       llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
66   auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
67                                            functionType.getResults());
68   func.setType(newFunctionType);
69 
70   // Transfer the result attributes to arg attributes.
71   auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
72   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
73     func.setArgAttrs(functionType.getNumInputs() + i,
74                      func.getResultAttrs(*erasedIndicesIt));
75   }
76 
77   // Erase the results.
78   func.eraseResults(erasedResultIndices);
79 
80   // Add the new arguments to the entry block if the function is not external.
81   if (func.isExternal())
82     return success();
83   Location loc = func.getLoc();
84   for (Type type : erasedResultTypes)
85     appendedEntryArgs.push_back(func.front().addArgument(type, loc));
86 
87   return success();
88 }
89 
90 // Updates all ReturnOps in the scope of the given func::FuncOp by either
91 // keeping them as return values or copying the associated buffer contents into
92 // the given out-params.
93 static void updateReturnOps(func::FuncOp func,
94                             ArrayRef<BlockArgument> appendedEntryArgs) {
95   func.walk([&](func::ReturnOp op) {
96     SmallVector<Value, 6> copyIntoOutParams;
97     SmallVector<Value, 6> keepAsReturnOperands;
98     for (Value operand : op.getOperands()) {
99       if (operand.getType().isa<MemRefType>())
100         copyIntoOutParams.push_back(operand);
101       else
102         keepAsReturnOperands.push_back(operand);
103     }
104     OpBuilder builder(op);
105     for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
106       builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
107                                      std::get<1>(t));
108     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
109     op.erase();
110   });
111 }
112 
113 // Updates all CallOps in the scope of the given ModuleOp by allocating
114 // temporary buffers for newly introduced out params.
115 static LogicalResult updateCalls(ModuleOp module) {
116   bool didFail = false;
117   module.walk([&](func::CallOp op) {
118     SmallVector<Value, 6> replaceWithNewCallResults;
119     SmallVector<Value, 6> replaceWithOutParams;
120     for (OpResult result : op.getResults()) {
121       if (result.getType().isa<MemRefType>())
122         replaceWithOutParams.push_back(result);
123       else
124         replaceWithNewCallResults.push_back(result);
125     }
126     SmallVector<Value, 6> outParams;
127     OpBuilder builder(op);
128     for (Value memref : replaceWithOutParams) {
129       if (!memref.getType().cast<MemRefType>().hasStaticShape()) {
130         op.emitError()
131             << "cannot create out param for dynamically shaped result";
132         didFail = true;
133         return;
134       }
135       auto memrefType = memref.getType().cast<MemRefType>();
136       auto allocType =
137           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
138                           AffineMap(), memrefType.getMemorySpaceAsInt());
139       Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
140       if (!hasStaticIdentityLayout(memrefType)) {
141         // Layout maps are already checked in `updateFuncOp`.
142         assert(hasFullyDynamicLayoutMap(memrefType) &&
143                "layout map not supported");
144         outParam =
145             builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
146       }
147       memref.replaceAllUsesWith(outParam);
148       outParams.push_back(outParam);
149     }
150 
151     auto newOperands = llvm::to_vector<6>(op.getOperands());
152     newOperands.append(outParams.begin(), outParams.end());
153     auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
154         replaceWithNewCallResults, [](Value v) { return v.getType(); }));
155     auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
156                                                 newResultTypes, newOperands);
157     for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
158       std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
159     op.erase();
160   });
161 
162   return failure(didFail);
163 }
164 
165 LogicalResult
166 mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
167   for (auto func : module.getOps<func::FuncOp>()) {
168     SmallVector<BlockArgument, 6> appendedEntryArgs;
169     if (failed(updateFuncOp(func, appendedEntryArgs)))
170       return failure();
171     if (func.isExternal())
172       continue;
173     updateReturnOps(func, appendedEntryArgs);
174   }
175   if (failed(updateCalls(module)))
176     return failure();
177   return success();
178 }
179 
180 namespace {
181 struct BufferResultsToOutParamsPass
182     : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> {
183   void runOnOperation() override {
184     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
185       return signalPassFailure();
186   }
187 };
188 } // namespace
189 
190 std::unique_ptr<Pass>
191 mlir::bufferization::createBufferResultsToOutParamsPass() {
192   return std::make_unique<BufferResultsToOutParamsPass>();
193 }
194