xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (revision 598c5ddba6b0fd1e7226ffd4dc34e44ec7c7c513)
1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Pass/Pass.h"
15 
16 using namespace mlir;
17 
18 /// Return `true` if the given MemRef type has a fully dynamic layout.
19 static bool hasFullyDynamicLayoutMap(MemRefType type) {
20   int64_t offset;
21   SmallVector<int64_t, 4> strides;
22   if (failed(getStridesAndOffset(type, strides, offset)))
23     return false;
24   if (!llvm::all_of(strides, [](int64_t stride) {
25         return ShapedType::isDynamicStrideOrOffset(stride);
26       }))
27     return false;
28   if (!ShapedType::isDynamicStrideOrOffset(offset))
29     return false;
30   return true;
31 }
32 
33 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
34 /// no layout).
35 static bool hasStaticIdentityLayout(MemRefType type) {
36   return type.getLayout().isIdentity();
37 }
38 
39 // Updates the func op and entry block.
40 //
41 // Any args appended to the entry block are added to `appendedEntryArgs`.
42 static LogicalResult
43 updateFuncOp(func::FuncOp func,
44              SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
45   auto functionType = func.getFunctionType();
46 
47   // Collect information about the results will become appended arguments.
48   SmallVector<Type, 6> erasedResultTypes;
49   BitVector erasedResultIndices(functionType.getNumResults());
50   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
51     if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) {
52       if (!hasStaticIdentityLayout(memrefType) &&
53           !hasFullyDynamicLayoutMap(memrefType)) {
54         // Only buffers with static identity layout can be allocated. These can
55         // be casted to memrefs with fully dynamic layout map. Other layout maps
56         // are not supported.
57         return func->emitError()
58                << "cannot create out param for result with unsupported layout";
59       }
60       erasedResultIndices.set(resultType.index());
61       erasedResultTypes.push_back(memrefType);
62     }
63   }
64 
65   // Add the new arguments to the function type.
66   auto newArgTypes = llvm::to_vector<6>(
67       llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
68   auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
69                                            functionType.getResults());
70   func.setType(newFunctionType);
71 
72   // Transfer the result attributes to arg attributes.
73   auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
74   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
75     func.setArgAttrs(functionType.getNumInputs() + i,
76                      func.getResultAttrs(*erasedIndicesIt));
77   }
78 
79   // Erase the results.
80   func.eraseResults(erasedResultIndices);
81 
82   // Add the new arguments to the entry block if the function is not external.
83   if (func.isExternal())
84     return success();
85   Location loc = func.getLoc();
86   for (Type type : erasedResultTypes)
87     appendedEntryArgs.push_back(func.front().addArgument(type, loc));
88 
89   return success();
90 }
91 
92 // Updates all ReturnOps in the scope of the given func::FuncOp by either
93 // keeping them as return values or copying the associated buffer contents into
94 // the given out-params.
95 static void updateReturnOps(func::FuncOp func,
96                             ArrayRef<BlockArgument> appendedEntryArgs) {
97   func.walk([&](func::ReturnOp op) {
98     SmallVector<Value, 6> copyIntoOutParams;
99     SmallVector<Value, 6> keepAsReturnOperands;
100     for (Value operand : op.getOperands()) {
101       if (operand.getType().isa<MemRefType>())
102         copyIntoOutParams.push_back(operand);
103       else
104         keepAsReturnOperands.push_back(operand);
105     }
106     OpBuilder builder(op);
107     for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
108       builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
109                                      std::get<1>(t));
110     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
111     op.erase();
112   });
113 }
114 
115 // Updates all CallOps in the scope of the given ModuleOp by allocating
116 // temporary buffers for newly introduced out params.
117 static LogicalResult updateCalls(ModuleOp module) {
118   bool didFail = false;
119   module.walk([&](func::CallOp op) {
120     SmallVector<Value, 6> replaceWithNewCallResults;
121     SmallVector<Value, 6> replaceWithOutParams;
122     for (OpResult result : op.getResults()) {
123       if (result.getType().isa<MemRefType>())
124         replaceWithOutParams.push_back(result);
125       else
126         replaceWithNewCallResults.push_back(result);
127     }
128     SmallVector<Value, 6> outParams;
129     OpBuilder builder(op);
130     for (Value memref : replaceWithOutParams) {
131       if (!memref.getType().cast<MemRefType>().hasStaticShape()) {
132         op.emitError()
133             << "cannot create out param for dynamically shaped result";
134         didFail = true;
135         return;
136       }
137       auto memrefType = memref.getType().cast<MemRefType>();
138       auto allocType =
139           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
140                           AffineMap(), memrefType.getMemorySpaceAsInt());
141       Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
142       if (!hasStaticIdentityLayout(memrefType)) {
143         // Layout maps are already checked in `updateFuncOp`.
144         assert(hasFullyDynamicLayoutMap(memrefType) &&
145                "layout map not supported");
146         outParam =
147             builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
148       }
149       memref.replaceAllUsesWith(outParam);
150       outParams.push_back(outParam);
151     }
152 
153     auto newOperands = llvm::to_vector<6>(op.getOperands());
154     newOperands.append(outParams.begin(), outParams.end());
155     auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
156         replaceWithNewCallResults, [](Value v) { return v.getType(); }));
157     auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
158                                                 newResultTypes, newOperands);
159     for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
160       std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
161     op.erase();
162   });
163 
164   return failure(didFail);
165 }
166 
167 LogicalResult
168 mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
169   for (auto func : module.getOps<func::FuncOp>()) {
170     SmallVector<BlockArgument, 6> appendedEntryArgs;
171     if (failed(updateFuncOp(func, appendedEntryArgs)))
172       return failure();
173     if (func.isExternal())
174       continue;
175     updateReturnOps(func, appendedEntryArgs);
176   }
177   if (failed(updateCalls(module)))
178     return failure();
179   return success();
180 }
181 
182 namespace {
183 struct BufferResultsToOutParamsPass
184     : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> {
185   void runOnOperation() override {
186     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
187       return signalPassFailure();
188   }
189 };
190 } // namespace
191 
192 std::unique_ptr<Pass>
193 mlir::bufferization::createBufferResultsToOutParamsPass() {
194   return std::make_unique<BufferResultsToOutParamsPass>();
195 }
196