xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (revision b0a4e958e85784cff46303c92b6a3a14b20fa1d8)
1e07a7fd5SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2e07a7fd5SMatthias Springer //
3e07a7fd5SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e07a7fd5SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5e07a7fd5SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e07a7fd5SMatthias Springer //
7e07a7fd5SMatthias Springer //===----------------------------------------------------------------------===//
8e07a7fd5SMatthias Springer 
9e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
10e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
126ecebb49SMatthias Springer #include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
13a88732d9SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
16e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
17e07a7fd5SMatthias Springer #include "mlir/IR/Dialect.h"
18e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h"
19a1fe1f5fSKazu Hirata #include <optional>
20e07a7fd5SMatthias Springer 
21e07a7fd5SMatthias Springer namespace mlir {
22*b0a4e958SMatthias Springer /// Return all func.return ops in the given function.
23*b0a4e958SMatthias Springer SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
24*b0a4e958SMatthias Springer   SmallVector<func::ReturnOp> result;
25*b0a4e958SMatthias Springer   for (Block &b : funcOp.getBody())
26*b0a4e958SMatthias Springer     if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
27*b0a4e958SMatthias Springer       result.push_back(returnOp);
28*b0a4e958SMatthias Springer   return result;
29*b0a4e958SMatthias Springer }
30*b0a4e958SMatthias Springer 
31e07a7fd5SMatthias Springer namespace bufferization {
32e07a7fd5SMatthias Springer namespace func_ext {
33e07a7fd5SMatthias Springer 
3491c11574SAndrzej Warzyński void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
35e07a7fd5SMatthias Springer   analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
36e07a7fd5SMatthias Springer   auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
37e07a7fd5SMatthias Springer   auto createdAliasingResults =
38e07a7fd5SMatthias Springer       aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
39e07a7fd5SMatthias Springer   auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
40e07a7fd5SMatthias Springer   auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
41e07a7fd5SMatthias Springer   (void)createdEquiv;
42e07a7fd5SMatthias Springer   (void)createdAliasingResults;
43e07a7fd5SMatthias Springer   (void)createdRead;
44e07a7fd5SMatthias Springer   (void)createdWritten;
45e07a7fd5SMatthias Springer #ifndef NDEBUG
46e07a7fd5SMatthias Springer   assert(createdEquiv.second && "equivalence info exists already");
47e07a7fd5SMatthias Springer   assert(createdAliasingResults.second && "aliasing info exists already");
48e07a7fd5SMatthias Springer   assert(createdRead.second && "bbarg access info exists already");
49e07a7fd5SMatthias Springer   assert(createdWritten.second && "bbarg access info exists already");
50e07a7fd5SMatthias Springer #endif // NDEBUG
51e07a7fd5SMatthias Springer }
52e07a7fd5SMatthias Springer 
53e07a7fd5SMatthias Springer /// Return the index-th bufferized function argument type. This assumes that the
54e07a7fd5SMatthias Springer /// specified argument is a tensor. If the tensor is ranked, a layout map may be
5575ef84bfSOleg Shyshkov /// specified by the user (as per `options.functionArgTypeConverterFn`).
56e07a7fd5SMatthias Springer static BaseMemRefType
57e07a7fd5SMatthias Springer getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
58e07a7fd5SMatthias Springer                              const BufferizationOptions &options) {
59e07a7fd5SMatthias Springer   auto tensorType =
605550c821STres Popp       dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
61e07a7fd5SMatthias Springer   assert(tensorType && "expected TensorType");
62f287da8aSMatthias Springer 
6375ef84bfSOleg Shyshkov   BaseMemRefType memrefType = options.functionArgTypeConverterFn(
64067d2779Sian Bearman       tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
65e07a7fd5SMatthias Springer 
66e07a7fd5SMatthias Springer   auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
67e07a7fd5SMatthias Springer       index, BufferizationDialect::kBufferLayoutAttrName);
68e07a7fd5SMatthias Springer   if (!layoutAttr)
69e07a7fd5SMatthias Springer     return memrefType;
70e07a7fd5SMatthias Springer 
715550c821STres Popp   auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
72e07a7fd5SMatthias Springer   assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
73e07a7fd5SMatthias Springer   return MemRefType::get(
74e07a7fd5SMatthias Springer       rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
759bb63374SLei Zhang       layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
76e07a7fd5SMatthias Springer }
77e07a7fd5SMatthias Springer 
78e07a7fd5SMatthias Springer /// Return the FuncOp called by `callOp`.
79e07a7fd5SMatthias Springer static FuncOp getCalledFunction(CallOpInterface callOp) {
80217700baSMatthias Springer   SymbolRefAttr sym =
81217700baSMatthias Springer       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
82e07a7fd5SMatthias Springer   if (!sym)
83e07a7fd5SMatthias Springer     return nullptr;
84e07a7fd5SMatthias Springer   return dyn_cast_or_null<FuncOp>(
85e07a7fd5SMatthias Springer       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
86e07a7fd5SMatthias Springer }
87e07a7fd5SMatthias Springer 
88e07a7fd5SMatthias Springer /// Get FuncAnalysisState.
89e07a7fd5SMatthias Springer static const FuncAnalysisState &
90e07a7fd5SMatthias Springer getFuncAnalysisState(const AnalysisState &state) {
91faa9be75SMatthias Springer   assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
92faa9be75SMatthias Springer   auto *result = static_cast<const OneShotAnalysisState &>(state)
93faa9be75SMatthias Springer                      .getExtension<FuncAnalysisState>();
94faa9be75SMatthias Springer   assert(result && "FuncAnalysisState does not exist");
95faa9be75SMatthias Springer   return *result;
96e07a7fd5SMatthias Springer }
97e07a7fd5SMatthias Springer 
98e07a7fd5SMatthias Springer /// Return the state (phase) of analysis of the FuncOp.
99e07a7fd5SMatthias Springer static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
100e07a7fd5SMatthias Springer                                                   FuncOp funcOp) {
101faa9be75SMatthias Springer   if (!isa<OneShotAnalysisState>(state))
102cd80617aSMatthias Springer     return FuncOpAnalysisState::NotAnalyzed;
103faa9be75SMatthias Springer   auto *funcState = static_cast<const OneShotAnalysisState &>(state)
104faa9be75SMatthias Springer                         .getExtension<FuncAnalysisState>();
105faa9be75SMatthias Springer   if (!funcState)
106faa9be75SMatthias Springer     return FuncOpAnalysisState::NotAnalyzed;
107faa9be75SMatthias Springer   const auto &analyzedFuncOps = funcState->analyzedFuncOps;
108cd80617aSMatthias Springer   auto it = analyzedFuncOps.find(funcOp);
109cd80617aSMatthias Springer   if (it == analyzedFuncOps.end())
110e07a7fd5SMatthias Springer     return FuncOpAnalysisState::NotAnalyzed;
111e07a7fd5SMatthias Springer   return it->second;
112e07a7fd5SMatthias Springer }
113e07a7fd5SMatthias Springer 
114e07a7fd5SMatthias Springer /// Return the index of the bbArg in the given FuncOp that is equivalent to the
115e07a7fd5SMatthias Springer /// specified return value (if any).
1160a81ace0SKazu Hirata static std::optional<int64_t>
1170a81ace0SKazu Hirata getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state,
118e07a7fd5SMatthias Springer                         int64_t returnValIdx) {
119e07a7fd5SMatthias Springer   auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
120e07a7fd5SMatthias Springer   if (funcOpIt == state.equivalentFuncArgs.end())
121e07a7fd5SMatthias Springer     // No equivalence info stores for funcOp.
1221a36588eSKazu Hirata     return std::nullopt;
123e07a7fd5SMatthias Springer 
124e07a7fd5SMatthias Springer   auto retValIt = funcOpIt->getSecond().find(returnValIdx);
125e07a7fd5SMatthias Springer   if (retValIt == funcOpIt->getSecond().end())
126e07a7fd5SMatthias Springer     // Return value has no equivalent bbArg.
1271a36588eSKazu Hirata     return std::nullopt;
128e07a7fd5SMatthias Springer 
129e07a7fd5SMatthias Springer   return retValIt->getSecond();
130e07a7fd5SMatthias Springer }
131e07a7fd5SMatthias Springer 
132e07a7fd5SMatthias Springer struct CallOpInterface
133e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CallOpInterface,
134e07a7fd5SMatthias Springer                                                     func::CallOp> {
135e07a7fd5SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
136e07a7fd5SMatthias Springer                               const AnalysisState &state) const {
137e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
138e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
139e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
140e07a7fd5SMatthias Springer 
141e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
142e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Assume that OpOperand is read.
143e07a7fd5SMatthias Springer       return true;
144e07a7fd5SMatthias Springer 
145cd80617aSMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
146e07a7fd5SMatthias Springer     return funcState.readBbArgs.lookup(funcOp).contains(
147e07a7fd5SMatthias Springer         opOperand.getOperandNumber());
148e07a7fd5SMatthias Springer   }
149e07a7fd5SMatthias Springer 
150e07a7fd5SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
151e07a7fd5SMatthias Springer                                const AnalysisState &state) const {
152e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
153e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
154e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
155e07a7fd5SMatthias Springer 
156e07a7fd5SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
157e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Assume that OpOperand is written.
158e07a7fd5SMatthias Springer       return true;
159e07a7fd5SMatthias Springer 
160cd80617aSMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
161e07a7fd5SMatthias Springer     return funcState.writtenBbArgs.lookup(funcOp).contains(
162e07a7fd5SMatthias Springer         opOperand.getOperandNumber());
163e07a7fd5SMatthias Springer   }
164e07a7fd5SMatthias Springer 
165a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
166e07a7fd5SMatthias Springer                                       const AnalysisState &state) const {
167e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
168e07a7fd5SMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
169e07a7fd5SMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
170f3483c23SMatthias Springer     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
171e07a7fd5SMatthias Springer       // FuncOp not analyzed yet. Any OpResult may be aliasing.
172a02ad6c1SMatthias Springer       return detail::unknownGetAliasingValues(opOperand);
173e07a7fd5SMatthias Springer 
174e07a7fd5SMatthias Springer     // Get aliasing results from state.
175cd80617aSMatthias Springer     const FuncAnalysisState &funcState = getFuncAnalysisState(state);
176e07a7fd5SMatthias Springer     auto aliasingReturnVals =
177e07a7fd5SMatthias Springer         funcState.aliasingReturnVals.lookup(funcOp).lookup(
178e07a7fd5SMatthias Springer             opOperand.getOperandNumber());
1799fa6b350SMatthias Springer 
1809fa6b350SMatthias Springer     // Check if the aliasing OpResult is equivalent to the OpOperand.
1819fa6b350SMatthias Springer     std::optional<int64_t> equivalent = {};
1829fa6b350SMatthias Springer     if (aliasingReturnVals.size() == 1) {
1839fa6b350SMatthias Springer       equivalent = getEquivalentFuncArgIdx(funcOp, funcState,
1849fa6b350SMatthias Springer                                            aliasingReturnVals.front());
1859fa6b350SMatthias Springer       assert((!equivalent.has_value() ||
1869fa6b350SMatthias Springer               *equivalent == opOperand.getOperandNumber()) &&
1879fa6b350SMatthias Springer              "inconsistent analysis state");
1889fa6b350SMatthias Springer     }
189a02ad6c1SMatthias Springer     AliasingValueList result;
190e07a7fd5SMatthias Springer     for (int64_t resultIdx : aliasingReturnVals)
1919fa6b350SMatthias Springer       result.addAlias({callOp->getOpResult(resultIdx),
1929fa6b350SMatthias Springer                        equivalent.has_value() ? BufferRelation::Equivalent
1939fa6b350SMatthias Springer                                               : BufferRelation::Unknown,
1949fa6b350SMatthias Springer                        /*isDefinite=*/equivalent.has_value()});
195e07a7fd5SMatthias Springer     return result;
196e07a7fd5SMatthias Springer   }
197e07a7fd5SMatthias Springer 
19806dacf5eSMatthias Springer   FailureOr<BaseMemRefType>
19906dacf5eSMatthias Springer   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
200878950b8SMatthias Springer                 SmallVector<Value> &invocationStack) const {
20106dacf5eSMatthias Springer     auto callOp = cast<func::CallOp>(op);
20206dacf5eSMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
20306dacf5eSMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
20406dacf5eSMatthias Springer 
205c271ba7fSMatthias Springer     // If the callee was already bufferized, we can directly take the type from
20606dacf5eSMatthias Springer     // its signature.
20706dacf5eSMatthias Springer     FunctionType funcType = funcOp.getFunctionType();
208c271ba7fSMatthias Springer     Type resultType =
209c271ba7fSMatthias Springer         funcType.getResult(cast<OpResult>(value).getResultNumber());
210c271ba7fSMatthias Springer     if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
211c271ba7fSMatthias Springer       return bufferizedType;
212c271ba7fSMatthias Springer 
213c271ba7fSMatthias Springer     // Otherwise, call the type converter to compute the bufferized type.
214c271ba7fSMatthias Springer     auto tensorType = cast<TensorType>(resultType);
215c271ba7fSMatthias Springer     return options.functionArgTypeConverterFn(
216c271ba7fSMatthias Springer         tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
21706dacf5eSMatthias Springer   }
21806dacf5eSMatthias Springer 
219e07a7fd5SMatthias Springer   /// All function arguments are writable. It is the responsibility of the
220e07a7fd5SMatthias Springer   /// CallOp to insert buffer copies where necessary.
221e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
222b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
223e07a7fd5SMatthias Springer     func::CallOp callOp = cast<func::CallOp>(op);
224e07a7fd5SMatthias Springer 
22588539c5bSMatthias Springer     // 1. Compute the result types of the new CallOp.
22606dacf5eSMatthias Springer     SmallVector<Type> resultTypes;
22706dacf5eSMatthias Springer     for (Value result : callOp.getResults()) {
22806dacf5eSMatthias Springer       Type returnType = result.getType();
2295550c821STres Popp       if (!isa<TensorType>(returnType)) {
230e07a7fd5SMatthias Springer         // Non-tensor values are returned.
231e07a7fd5SMatthias Springer         resultTypes.push_back(returnType);
232e07a7fd5SMatthias Springer         continue;
233e07a7fd5SMatthias Springer       }
234e07a7fd5SMatthias Springer 
23588539c5bSMatthias Springer       // Returning a memref.
23606dacf5eSMatthias Springer       FailureOr<BaseMemRefType> resultType =
23706dacf5eSMatthias Springer           bufferization::getBufferType(result, options);
23806dacf5eSMatthias Springer       if (failed(resultType))
23906dacf5eSMatthias Springer         return failure();
24006dacf5eSMatthias Springer       resultTypes.push_back(*resultType);
241e07a7fd5SMatthias Springer     }
242e07a7fd5SMatthias Springer 
24306dacf5eSMatthias Springer     // 2. Rewrite tensor operands as memrefs based on type of the already
24406dacf5eSMatthias Springer     //    bufferized callee.
24506dacf5eSMatthias Springer     SmallVector<Value> newOperands;
24606dacf5eSMatthias Springer     FuncOp funcOp = getCalledFunction(callOp);
24706dacf5eSMatthias Springer     assert(funcOp && "expected CallOp to a FuncOp");
24806dacf5eSMatthias Springer     FunctionType funcType = funcOp.getFunctionType();
249e07a7fd5SMatthias Springer 
25006dacf5eSMatthias Springer     for (OpOperand &opOperand : callOp->getOpOperands()) {
251e07a7fd5SMatthias Springer       // Non-tensor operands are just copied.
25206dacf5eSMatthias Springer       if (!isa<TensorType>(opOperand.get().getType())) {
25306dacf5eSMatthias Springer         newOperands.push_back(opOperand.get());
254e07a7fd5SMatthias Springer         continue;
255e07a7fd5SMatthias Springer       }
256e07a7fd5SMatthias Springer 
25788539c5bSMatthias Springer       // Retrieve buffers for tensor operands.
2585d50f51cSMatthias Springer       FailureOr<Value> maybeBuffer =
2595d50f51cSMatthias Springer           getBuffer(rewriter, opOperand.get(), options);
2605d50f51cSMatthias Springer       if (failed(maybeBuffer))
2615d50f51cSMatthias Springer         return failure();
26206dacf5eSMatthias Springer       Value buffer = *maybeBuffer;
263e07a7fd5SMatthias Springer 
2647f04a8adSLongsheng Mou       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
26506dacf5eSMatthias Springer       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
266c271ba7fSMatthias Springer       if (!isa<BaseMemRefType>(memRefType)) {
267c271ba7fSMatthias Springer         // The called function was not bufferized yet. This can happen when
268c271ba7fSMatthias Springer         // there cycles in the function call graph. Compute the bufferized
269c271ba7fSMatthias Springer         // result type.
270c271ba7fSMatthias Springer         FailureOr<BaseMemRefType> maybeMemRefType =
271c271ba7fSMatthias Springer             bufferization::getBufferType(
272c271ba7fSMatthias Springer                 funcOp.getArgument(opOperand.getOperandNumber()), options);
273c271ba7fSMatthias Springer         if (failed(maybeMemRefType))
274c271ba7fSMatthias Springer           return failure();
275c271ba7fSMatthias Springer         memRefType = *maybeMemRefType;
276c271ba7fSMatthias Springer       }
277c271ba7fSMatthias Springer 
278e07a7fd5SMatthias Springer       // Since we don't yet have a clear layout story, to_memref may
279e07a7fd5SMatthias Springer       // conservatively turn tensors into more dynamic memref than necessary.
280e07a7fd5SMatthias Springer       // If the memref type of the callee fails, introduce an extra memref.cast
281e07a7fd5SMatthias Springer       // that will either canonicalize away or fail compilation until we can do
2827f04a8adSLongsheng Mou       // something better. Insert a reallocation + copy if it cannot be
2837f04a8adSLongsheng Mou       // statically guaranteed that a direct cast would be valid.
284e07a7fd5SMatthias Springer       if (buffer.getType() != memRefType) {
2857f04a8adSLongsheng Mou         auto memrefDstType = dyn_cast<MemRefType>(memRefType);
2867f04a8adSLongsheng Mou         assert(memrefDstType &&
2877f04a8adSLongsheng Mou                "buffer layout not supported on unranked tensors");
2887f04a8adSLongsheng Mou         FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
2897f04a8adSLongsheng Mou             rewriter, buffer, memrefDstType, options);
2907f04a8adSLongsheng Mou         if (failed(replacement))
2917f04a8adSLongsheng Mou           return failure();
2927f04a8adSLongsheng Mou         buffer = *replacement;
293e07a7fd5SMatthias Springer       }
29406dacf5eSMatthias Springer       newOperands.push_back(buffer);
295e07a7fd5SMatthias Springer     }
296e07a7fd5SMatthias Springer 
297e07a7fd5SMatthias Springer     // 3. Create the new CallOp.
298e07a7fd5SMatthias Springer     Operation *newCallOp = rewriter.create<func::CallOp>(
299e07a7fd5SMatthias Springer         callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
300e07a7fd5SMatthias Springer     newCallOp->setAttrs(callOp->getAttrs());
301e07a7fd5SMatthias Springer 
302e07a7fd5SMatthias Springer     // 4. Replace the old op with the new op.
30306dacf5eSMatthias Springer     replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults());
304e07a7fd5SMatthias Springer 
305e07a7fd5SMatthias Springer     return success();
306e07a7fd5SMatthias Springer   }
307e07a7fd5SMatthias Springer };
308e07a7fd5SMatthias Springer 
309e07a7fd5SMatthias Springer struct ReturnOpInterface
310e07a7fd5SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
311e07a7fd5SMatthias Springer                                                     func::ReturnOp> {
312e07a7fd5SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
313e07a7fd5SMatthias Springer                               const AnalysisState &state) const {
314e07a7fd5SMatthias Springer     return true;
315e07a7fd5SMatthias Springer   }
316e07a7fd5SMatthias Springer 
317e07a7fd5SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
318e07a7fd5SMatthias Springer                                const AnalysisState &state) const {
319e07a7fd5SMatthias Springer     return false;
320e07a7fd5SMatthias Springer   }
321e07a7fd5SMatthias Springer 
322a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
323e07a7fd5SMatthias Springer                                       const AnalysisState &state) const {
324e07a7fd5SMatthias Springer     return {};
325e07a7fd5SMatthias Springer   }
326e07a7fd5SMatthias Springer 
327e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
328b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
329e07a7fd5SMatthias Springer #ifndef NDEBUG
330e07a7fd5SMatthias Springer     auto returnOp = cast<func::ReturnOp>(op);
331e07a7fd5SMatthias Springer     assert(isa<FuncOp>(returnOp->getParentOp()) &&
332e07a7fd5SMatthias Springer            "only support FuncOp parent for ReturnOp");
333e07a7fd5SMatthias Springer #endif // NDEBUG
334e07a7fd5SMatthias Springer 
335e07a7fd5SMatthias Springer     // ReturnOps are bufferized as part of FuncOps.
3360b293bf0SMatthias Springer     return success();
337e07a7fd5SMatthias Springer   }
338e07a7fd5SMatthias Springer };
339e07a7fd5SMatthias Springer 
340e07a7fd5SMatthias Springer struct FuncOpInterface
3416ecebb49SMatthias Springer     : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
3426ecebb49SMatthias Springer           FuncOpInterface, FuncOp> {
3436ecebb49SMatthias Springer 
3446ecebb49SMatthias Springer   static bool supportsUnstructuredControlFlow() { return true; }
3456ecebb49SMatthias Springer 
3468f2d83daSMatthias Springer   bool hasTensorSemantics(Operation *op) const {
347971b8525SJakub Kuderski     auto isaTensor = llvm::IsaPred<TensorType>;
3488f2d83daSMatthias Springer 
3498f2d83daSMatthias Springer     // A function has tensor semantics if it has tensor arguments/results.
3508f2d83daSMatthias Springer     auto funcOp = cast<FuncOp>(op);
3518f2d83daSMatthias Springer     bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
3528f2d83daSMatthias Springer     bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
3538f2d83daSMatthias Springer     if (hasTensorArg || hasTensorResult)
3548f2d83daSMatthias Springer       return true;
3558f2d83daSMatthias Springer 
3568f2d83daSMatthias Springer     // It also has tensor semantics if it has tensor block arguments.
3578f2d83daSMatthias Springer     // TODO: Decouple bufferization of unstructured control flow from
3588f2d83daSMatthias Springer     // BufferizableOpInterface implementations. We should only care about
3598f2d83daSMatthias Springer     // region entry block arguments here (which are already covered by the
3608f2d83daSMatthias Springer     // argument types of the function).
3618f2d83daSMatthias Springer     for (Block &block : funcOp.getBody())
3628f2d83daSMatthias Springer       if (any_of(block.getArgumentTypes(), isaTensor))
3638f2d83daSMatthias Springer         return true;
3648f2d83daSMatthias Springer 
3658f2d83daSMatthias Springer     return false;
3668f2d83daSMatthias Springer   }
3678f2d83daSMatthias Springer 
3686ecebb49SMatthias Springer   AliasingOpOperandList
3696ecebb49SMatthias Springer   getAliasingOpOperands(Operation *op, Value value,
3706ecebb49SMatthias Springer                         const AnalysisState &state) const {
3716ecebb49SMatthias Springer     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
3726ecebb49SMatthias Springer   }
3736ecebb49SMatthias Springer 
37406dacf5eSMatthias Springer   FailureOr<BaseMemRefType>
37506dacf5eSMatthias Springer   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
376878950b8SMatthias Springer                 SmallVector<Value> &invocationStack) const {
37706dacf5eSMatthias Springer     auto funcOp = cast<FuncOp>(op);
37806dacf5eSMatthias Springer     auto bbArg = cast<BlockArgument>(value);
3796ecebb49SMatthias Springer 
3806ecebb49SMatthias Springer     // Function arguments are special.
3816ecebb49SMatthias Springer     if (bbArg.getOwner() == &funcOp.getBody().front())
3826ecebb49SMatthias Springer       return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
3836ecebb49SMatthias Springer                                           options);
3846ecebb49SMatthias Springer 
3856ecebb49SMatthias Springer     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
3866ecebb49SMatthias Springer         getBufferType(op, value, options, invocationStack);
3876ecebb49SMatthias Springer   }
3886ecebb49SMatthias Springer 
389f287da8aSMatthias Springer   /// Rewrite function bbArgs and return values into buffer form. This function
390f287da8aSMatthias Springer   /// bufferizes the function signature and the ReturnOp. When the entire
391f287da8aSMatthias Springer   /// function body has been bufferized, function return types can be switched
392f287da8aSMatthias Springer   /// to more concise memref types as part of `foldMemRefCasts`.
393e07a7fd5SMatthias Springer   ///
394e07a7fd5SMatthias Springer   /// All function bbArgs are writable unless they are explicitly marked as
395e07a7fd5SMatthias Springer   /// read-only. Callers must insert copies when needed.
396e07a7fd5SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
397b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
398e07a7fd5SMatthias Springer     auto funcOp = cast<FuncOp>(op);
399e07a7fd5SMatthias Springer     FunctionType funcType = funcOp.getFunctionType();
400e07a7fd5SMatthias Springer 
401217700baSMatthias Springer     // Compute the argument types.
402e07a7fd5SMatthias Springer     SmallVector<Type> argTypes;
403e07a7fd5SMatthias Springer     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
404e07a7fd5SMatthias Springer       Type argType = it.value();
405217700baSMatthias Springer       if (isa<TensorType>(argType)) {
406e07a7fd5SMatthias Springer         argTypes.push_back(
407e07a7fd5SMatthias Springer             getBufferizedFunctionArgType(funcOp, it.index(), options));
408e07a7fd5SMatthias Springer         continue;
409e07a7fd5SMatthias Springer       }
410e07a7fd5SMatthias Springer       argTypes.push_back(argType);
411e07a7fd5SMatthias Springer     }
412e07a7fd5SMatthias Springer 
413217700baSMatthias Springer     // Compute the result types.
414e07a7fd5SMatthias Springer     SmallVector<Type> retTypes;
415e07a7fd5SMatthias Springer     for (Type resultType : funcType.getResults()) {
416217700baSMatthias Springer       if (auto tensorType = dyn_cast<TensorType>(resultType)) {
417217700baSMatthias Springer         BaseMemRefType resultType = options.functionArgTypeConverterFn(
418217700baSMatthias Springer             tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
419217700baSMatthias Springer             options);
420217700baSMatthias Springer         retTypes.push_back(resultType);
421217700baSMatthias Springer         continue;
422217700baSMatthias Springer       }
423e07a7fd5SMatthias Springer       retTypes.push_back(resultType);
424e07a7fd5SMatthias Springer     }
425217700baSMatthias Springer 
426217700baSMatthias Springer     // Compute the new function type.
427217700baSMatthias Springer     auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);
428217700baSMatthias Springer 
429217700baSMatthias Springer     // If the function has no body, set the new function type and we are done.
430217700baSMatthias Springer     if (funcOp.isExternal()) {
431217700baSMatthias Springer       funcOp.setType(newFuncType);
432e07a7fd5SMatthias Springer       return success();
433e07a7fd5SMatthias Springer     }
434e07a7fd5SMatthias Springer 
435a88732d9SMatthias Springer     // 1. Bufferize every block.
436a88732d9SMatthias Springer     for (Block &block : funcOp.getBody())
437a88732d9SMatthias Springer       if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
438a88732d9SMatthias Springer                                                         options)))
43906dacf5eSMatthias Springer         return failure();
440e07a7fd5SMatthias Springer 
441*b0a4e958SMatthias Springer     // 2. Bufferize the operands of the all return op.
442*b0a4e958SMatthias Springer     for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
443*b0a4e958SMatthias Springer       assert(returnOp->getNumOperands() == retTypes.size() &&
444*b0a4e958SMatthias Springer              "incorrect number of return values");
445e07a7fd5SMatthias Springer       SmallVector<Value> returnValues;
446217700baSMatthias Springer       for (auto [returnVal, bufferizedType] :
447217700baSMatthias Springer            llvm::zip_equal(returnOp->getOperands(), retTypes)) {
4485550c821STres Popp         auto tensorType = dyn_cast<TensorType>(returnVal.getType());
449f287da8aSMatthias Springer         rewriter.setInsertionPoint(returnOp);
450e07a7fd5SMatthias Springer 
451e07a7fd5SMatthias Springer         // If not a tensor type just forward it.
452f287da8aSMatthias Springer         if (!tensorType) {
453e07a7fd5SMatthias Springer           returnValues.push_back(returnVal);
454e07a7fd5SMatthias Springer           continue;
455e07a7fd5SMatthias Springer         }
456e07a7fd5SMatthias Springer 
457217700baSMatthias Springer         // Note: If `inferFunctionResultLayout = true`, casts are later folded
45875ef84bfSOleg Shyshkov         // away.
459f287da8aSMatthias Springer         Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
460*b0a4e958SMatthias Springer             returnOp.getLoc(), bufferizedType, returnVal);
461f287da8aSMatthias Springer         returnValues.push_back(toMemrefOp);
462e07a7fd5SMatthias Springer       }
463e07a7fd5SMatthias Springer 
464b74192b7SRiver Riddle       returnOp.getOperandsMutable().assign(returnValues);
465*b0a4e958SMatthias Springer     }
466e07a7fd5SMatthias Springer 
467217700baSMatthias Springer     // 3. Set the new function type.
468217700baSMatthias Springer     funcOp.setType(newFuncType);
469e07a7fd5SMatthias Springer     return success();
470e07a7fd5SMatthias Springer   }
471e07a7fd5SMatthias Springer 
472e07a7fd5SMatthias Springer   /// Return `true` if the given function argument is writable.
473e07a7fd5SMatthias Springer   bool isWritable(Operation *op, Value value,
474e07a7fd5SMatthias Springer                   const AnalysisState &state) const {
475e07a7fd5SMatthias Springer     auto funcOp = cast<FuncOp>(op);
4765550c821STres Popp     BlockArgument bbArg = dyn_cast<BlockArgument>(value);
477e07a7fd5SMatthias Springer     assert(bbArg && "expected BlockArgument");
478e07a7fd5SMatthias Springer 
4796ecebb49SMatthias Springer     // Non-entry block arguments are always writable. (They may alias with
4806ecebb49SMatthias Springer     // values that are not writable, which will turn them into read-only.)
4816ecebb49SMatthias Springer     if (bbArg.getOwner() != &funcOp.getBody().front())
4826ecebb49SMatthias Springer       return true;
4836ecebb49SMatthias Springer 
484e07a7fd5SMatthias Springer     // "bufferization.writable" overrides other writability decisions. This is
485e07a7fd5SMatthias Springer     // currently used for testing only.
486e07a7fd5SMatthias Springer     if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
487e07a7fd5SMatthias Springer             bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
488e07a7fd5SMatthias Springer       return writable.getValue();
489e07a7fd5SMatthias Springer 
490e07a7fd5SMatthias Springer     // All function arguments are writable by default.
491e07a7fd5SMatthias Springer     return true;
492e07a7fd5SMatthias Springer   }
493e07a7fd5SMatthias Springer };
494e07a7fd5SMatthias Springer 
495e07a7fd5SMatthias Springer } // namespace func_ext
496e07a7fd5SMatthias Springer } // namespace bufferization
497e07a7fd5SMatthias Springer } // namespace mlir
498e07a7fd5SMatthias Springer 
499e07a7fd5SMatthias Springer void mlir::bufferization::func_ext::
500e07a7fd5SMatthias Springer     registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
501e07a7fd5SMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
502e07a7fd5SMatthias Springer     func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
503e07a7fd5SMatthias Springer     func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
504e07a7fd5SMatthias Springer     func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
505e07a7fd5SMatthias Springer   });
506e07a7fd5SMatthias Springer }
507