xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (revision b0a4e958e85784cff46303c92b6a3a14b20fa1d8)
1e07a7fd5SMatthias Springer //===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
2e07a7fd5SMatthias Springer //
3e07a7fd5SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e07a7fd5SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5e07a7fd5SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e07a7fd5SMatthias Springer //
7e07a7fd5SMatthias Springer //===----------------------------------------------------------------------===//
8e07a7fd5SMatthias Springer //
9e07a7fd5SMatthias Springer // Module Bufferization is an extension of One-Shot Bufferize that
10e07a7fd5SMatthias Springer // bufferizes function boundaries. It provides `BufferizableOpInterface`
11e07a7fd5SMatthias Springer // implementations for FuncOp, CallOp and ReturnOp.
12e07a7fd5SMatthias Springer //
13e07a7fd5SMatthias Springer // Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
14e07a7fd5SMatthias Springer // This function analyzes the given module and determines the order of analysis
15e07a7fd5SMatthias Springer // and bufferization: Functions that are called are processed before their
16e07a7fd5SMatthias Springer // respective callers.
17e07a7fd5SMatthias Springer //
18e07a7fd5SMatthias Springer // After analyzing a FuncOp, additional information about its bbArgs is
193490aadfSMatthias Springer // gathered and stored in `FuncAnalysisState`.
20e07a7fd5SMatthias Springer //
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
24e07a7fd5SMatthias Springer // * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
25e07a7fd5SMatthias Springer //   read/written.
26e07a7fd5SMatthias Springer //
27e07a7fd5SMatthias Springer // Module Bufferization implements the following calling convention.
28e07a7fd5SMatthias Springer //
29e07a7fd5SMatthias Springer // * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
30e07a7fd5SMatthias Springer //   be written to in-place.
31e07a7fd5SMatthias Springer // * If a tensor operand of a CallOp is read after the CallOp, the operand of
32e07a7fd5SMatthias Springer //   the CallOp must bufferize out-of-place.
33e07a7fd5SMatthias Springer //
34e07a7fd5SMatthias Springer // Example: The tensor.insert op bufferizes in-place because it is allowed to
35e07a7fd5SMatthias Springer // modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
36e07a7fd5SMatthias Springer // out-of-place because `%t0` is modified by the callee but read by the
37e07a7fd5SMatthias Springer // tensor.extract op. The analysis of CallOps decides whether an OpOperand must
38e07a7fd5SMatthias Springer // bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
39e07a7fd5SMatthias Springer // ```
40e07a7fd5SMatthias Springer // func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
41e07a7fd5SMatthias Springer //   %f = ... : f32
42e07a7fd5SMatthias Springer //   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
43e07a7fd5SMatthias Springer //   return %0 : tensor<?xf32>
44e07a7fd5SMatthias Springer // }
45e07a7fd5SMatthias Springer //
46e07a7fd5SMatthias Springer // func @caller() -> () {
47e07a7fd5SMatthias Springer //   %t0 = ... : tensor<?xf32>
48e07a7fd5SMatthias Springer //   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
49e07a7fd5SMatthias Springer //   %2 = tensor.extract %1[...]  : tensor<?xf32>
50e07a7fd5SMatthias Springer // }
51e07a7fd5SMatthias Springer // ```
52e07a7fd5SMatthias Springer //
53e07a7fd5SMatthias Springer // Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
54e07a7fd5SMatthias Springer // analyze the function body. In such a case, the CallOp analysis conservatively
55e07a7fd5SMatthias Springer // assumes that each tensor OpOperand is both read and written.
56e07a7fd5SMatthias Springer //
57e07a7fd5SMatthias Springer // TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
58e07a7fd5SMatthias Springer // as "not reading" and/or "not writing".
59e07a7fd5SMatthias Springer 
60e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
61e07a7fd5SMatthias Springer 
62e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
63e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
64e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
65e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
66e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
6728b2f792SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
68e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
69e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
70971b8525SJakub Kuderski #include "mlir/IR/BuiltinTypes.h"
71e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h"
72e07a7fd5SMatthias Springer 
73e07a7fd5SMatthias Springer using namespace mlir;
74e07a7fd5SMatthias Springer using namespace mlir::bufferization;
75e07a7fd5SMatthias Springer using namespace mlir::bufferization::func_ext;
76e07a7fd5SMatthias Springer 
77e07a7fd5SMatthias Springer /// A mapping of FuncOps to their callers.
7891c11574SAndrzej Warzyński using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
79e07a7fd5SMatthias Springer 
80e07a7fd5SMatthias Springer /// Get or create FuncAnalysisState.
81faa9be75SMatthias Springer static FuncAnalysisState &
82faa9be75SMatthias Springer getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
83faa9be75SMatthias Springer   auto *result = state.getExtension<FuncAnalysisState>();
84faa9be75SMatthias Springer   if (result)
85faa9be75SMatthias Springer     return *result;
86faa9be75SMatthias Springer   return state.addExtension<FuncAnalysisState>();
87e07a7fd5SMatthias Springer }
88e07a7fd5SMatthias Springer 
89e07a7fd5SMatthias Springer namespace {
90e07a7fd5SMatthias Springer 
91e07a7fd5SMatthias Springer /// Annotate IR with the results of the analysis. For testing purposes only.
92e07a7fd5SMatthias Springer static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
93e07a7fd5SMatthias Springer                                           BlockArgument bbArg) {
94e07a7fd5SMatthias Springer   const char *kEquivalentArgsAttr = "__equivalent_func_args__";
95e07a7fd5SMatthias Springer   Operation *op = returnVal.getOwner();
96e07a7fd5SMatthias Springer 
97e07a7fd5SMatthias Springer   SmallVector<int64_t> equivBbArgs;
98e07a7fd5SMatthias Springer   if (op->hasAttr(kEquivalentArgsAttr)) {
995550c821STres Popp     auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
100e07a7fd5SMatthias Springer     equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
1015550c821STres Popp       return cast<IntegerAttr>(a).getValue().getSExtValue();
102e07a7fd5SMatthias Springer     }));
103e07a7fd5SMatthias Springer   } else {
104e07a7fd5SMatthias Springer     equivBbArgs.append(op->getNumOperands(), -1);
105e07a7fd5SMatthias Springer   }
106e07a7fd5SMatthias Springer   equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
107e07a7fd5SMatthias Springer 
108e07a7fd5SMatthias Springer   OpBuilder b(op->getContext());
109e07a7fd5SMatthias Springer   op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
110e07a7fd5SMatthias Springer }
111e07a7fd5SMatthias Springer 
112e07a7fd5SMatthias Springer /// Store function BlockArguments that are equivalent to/aliasing a returned
113e07a7fd5SMatthias Springer /// value in FuncAnalysisState.
114faa9be75SMatthias Springer static LogicalResult
11591c11574SAndrzej Warzyński aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
116faa9be75SMatthias Springer                              FuncAnalysisState &funcState) {
11791c11574SAndrzej Warzyński   if (funcOp.getBody().empty()) {
1184002eaaaSMatthias Springer     // No function body available. Conservatively assume that every tensor
1194002eaaaSMatthias Springer     // return value may alias with any tensor bbArg.
12091c11574SAndrzej Warzyński     FunctionType type = funcOp.getFunctionType();
12191c11574SAndrzej Warzyński     for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
1225550c821STres Popp       if (!isa<TensorType>(inputIt.value()))
1234002eaaaSMatthias Springer         continue;
12491c11574SAndrzej Warzyński       for (const auto &resultIt : llvm::enumerate(type.getResults())) {
1255550c821STres Popp         if (!isa<TensorType>(resultIt.value()))
1264002eaaaSMatthias Springer           continue;
1274002eaaaSMatthias Springer         int64_t returnIdx = resultIt.index();
1284002eaaaSMatthias Springer         int64_t bbArgIdx = inputIt.index();
1294002eaaaSMatthias Springer         funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
1304002eaaaSMatthias Springer       }
1314002eaaaSMatthias Springer     }
1324002eaaaSMatthias Springer     return success();
1334002eaaaSMatthias Springer   }
1344002eaaaSMatthias Springer 
135*b0a4e958SMatthias Springer   // Find all func.return ops.
136*b0a4e958SMatthias Springer   SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
137*b0a4e958SMatthias Springer   assert(!returnOps.empty() && "expected at least one ReturnOp");
138e07a7fd5SMatthias Springer 
139*b0a4e958SMatthias Springer   // Build alias sets. Merge all aliases from all func.return ops.
140*b0a4e958SMatthias Springer   for (BlockArgument bbArg : funcOp.getArguments()) {
1415550c821STres Popp     if (isa<RankedTensorType>(bbArg.getType())) {
142e07a7fd5SMatthias Springer       int64_t bbArgIdx = bbArg.getArgNumber();
143*b0a4e958SMatthias Springer       // Store aliases in a set, so that we don't add the same alias twice.
144*b0a4e958SMatthias Springer       SetVector<int64_t> aliases;
145*b0a4e958SMatthias Springer       for (func::ReturnOp returnOp : returnOps) {
146*b0a4e958SMatthias Springer         for (OpOperand &returnVal : returnOp->getOpOperands()) {
147*b0a4e958SMatthias Springer           if (isa<RankedTensorType>(returnVal.get().getType())) {
148*b0a4e958SMatthias Springer             int64_t returnIdx = returnVal.getOperandNumber();
149b7858f85SMatthias Springer             if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
150*b0a4e958SMatthias Springer               aliases.insert(returnIdx);
151*b0a4e958SMatthias Springer           }
152*b0a4e958SMatthias Springer         }
153*b0a4e958SMatthias Springer       }
154*b0a4e958SMatthias Springer       for (int64_t alias : aliases)
155*b0a4e958SMatthias Springer         funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
156*b0a4e958SMatthias Springer     }
157*b0a4e958SMatthias Springer   }
158*b0a4e958SMatthias Springer 
159*b0a4e958SMatthias Springer   // Build equivalence sets.
160*b0a4e958SMatthias Springer   // Helper function that finds an equivalent block argument index for the
161*b0a4e958SMatthias Springer   // given OpOperand. Return std::nullopt if no equivalent block argument could
162*b0a4e958SMatthias Springer   // be found.
163*b0a4e958SMatthias Springer   auto findEquivalentBlockArgIdx =
164*b0a4e958SMatthias Springer       [&](OpOperand &opOperand) -> std::optional<int64_t> {
165*b0a4e958SMatthias Springer     Value v = opOperand.get();
166*b0a4e958SMatthias Springer     if (!isa<TensorType>(v.getType()))
167*b0a4e958SMatthias Springer       return std::nullopt;
168*b0a4e958SMatthias Springer     for (BlockArgument bbArg : funcOp.getArguments()) {
169*b0a4e958SMatthias Springer       if (isa<RankedTensorType>(bbArg.getType())) {
170*b0a4e958SMatthias Springer         if (state.areEquivalentBufferizedValues(v, bbArg)) {
171*b0a4e958SMatthias Springer           if (state.getOptions().testAnalysisOnly)
172*b0a4e958SMatthias Springer             annotateEquivalentReturnBbArg(opOperand, bbArg);
173*b0a4e958SMatthias Springer           return bbArg.getArgNumber();
174*b0a4e958SMatthias Springer         }
175*b0a4e958SMatthias Springer       }
176*b0a4e958SMatthias Springer     }
177*b0a4e958SMatthias Springer     return std::nullopt;
178*b0a4e958SMatthias Springer   };
179*b0a4e958SMatthias Springer 
180*b0a4e958SMatthias Springer   int64_t numResults = returnOps.front()->getNumOperands();
181*b0a4e958SMatthias Springer   for (int64_t i = 0; i < numResults; ++i) {
182*b0a4e958SMatthias Springer     // Find the equivalent block argument index for the i-th operand of the
183*b0a4e958SMatthias Springer     // first func.return op.
184*b0a4e958SMatthias Springer     std::optional<int64_t> maybeEquiv =
185*b0a4e958SMatthias Springer         findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
186*b0a4e958SMatthias Springer     if (!maybeEquiv.has_value())
187*b0a4e958SMatthias Springer       continue;
188*b0a4e958SMatthias Springer     int64_t bbArgIdx = *maybeEquiv;
189*b0a4e958SMatthias Springer     bool allEquiv = true;
190*b0a4e958SMatthias Springer 
191*b0a4e958SMatthias Springer     // Check if all other func.return ops have the same equivalent block
192*b0a4e958SMatthias Springer     // argument for the i-th operand. In contrast to aliasing information,
193*b0a4e958SMatthias Springer     // which is just "merged", equivalence information must match across all
194*b0a4e958SMatthias Springer     // func.return ops.
195*b0a4e958SMatthias Springer     for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
196*b0a4e958SMatthias Springer       std::optional<int64_t> maybeEquiv =
197*b0a4e958SMatthias Springer           findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
198*b0a4e958SMatthias Springer       if (maybeEquiv != bbArgIdx) {
199*b0a4e958SMatthias Springer         allEquiv = false;
200*b0a4e958SMatthias Springer         break;
201*b0a4e958SMatthias Springer       }
202*b0a4e958SMatthias Springer     }
203*b0a4e958SMatthias Springer 
204*b0a4e958SMatthias Springer     // All func.return ops have the same equivalent block argument for the i-th
205*b0a4e958SMatthias Springer     // operand.
206*b0a4e958SMatthias Springer     if (allEquiv)
207*b0a4e958SMatthias Springer       funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
208e07a7fd5SMatthias Springer   }
209e07a7fd5SMatthias Springer 
210e07a7fd5SMatthias Springer   return success();
211e07a7fd5SMatthias Springer }
212e07a7fd5SMatthias Springer 
21391c11574SAndrzej Warzyński static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
21491c11574SAndrzej Warzyński                                   bool isWritten) {
215e07a7fd5SMatthias Springer   OpBuilder b(funcOp.getContext());
216e07a7fd5SMatthias Springer   Attribute accessType;
217e07a7fd5SMatthias Springer   if (isRead && isWritten) {
218e07a7fd5SMatthias Springer     accessType = b.getStringAttr("read-write");
219e07a7fd5SMatthias Springer   } else if (isRead) {
220e07a7fd5SMatthias Springer     accessType = b.getStringAttr("read");
221e07a7fd5SMatthias Springer   } else if (isWritten) {
222e07a7fd5SMatthias Springer     accessType = b.getStringAttr("write");
223e07a7fd5SMatthias Springer   } else {
224e07a7fd5SMatthias Springer     accessType = b.getStringAttr("none");
225e07a7fd5SMatthias Springer   }
2264002eaaaSMatthias Springer   funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
2274002eaaaSMatthias Springer                     accessType);
228e07a7fd5SMatthias Springer }
229e07a7fd5SMatthias Springer 
2303490aadfSMatthias Springer /// Determine which FuncOp bbArgs are read and which are written. When run on a
2313490aadfSMatthias Springer /// function with unknown ops, we conservatively assume that such ops bufferize
2323490aadfSMatthias Springer /// to a read + write.
233faa9be75SMatthias Springer static LogicalResult
23491c11574SAndrzej Warzyński funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
235faa9be75SMatthias Springer                              FuncAnalysisState &funcState) {
23691c11574SAndrzej Warzyński   for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
23791c11574SAndrzej Warzyński        ++idx) {
2384002eaaaSMatthias Springer     // Skip non-tensor arguments.
23991c11574SAndrzej Warzyński     if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
2404002eaaaSMatthias Springer       continue;
2414002eaaaSMatthias Springer     bool isRead;
2424002eaaaSMatthias Springer     bool isWritten;
2434002eaaaSMatthias Springer     if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
2444002eaaaSMatthias Springer             idx, BufferizationDialect::kBufferAccessAttrName)) {
2454002eaaaSMatthias Springer       // Buffer access behavior is specified on the function. Skip the analysis.
2464002eaaaSMatthias Springer       StringRef str = accessAttr.getValue();
2474002eaaaSMatthias Springer       isRead = str == "read" || str == "read-write";
2484002eaaaSMatthias Springer       isWritten = str == "write" || str == "read-write";
24991c11574SAndrzej Warzyński     } else if (funcOp.getBody().empty()) {
250e07a7fd5SMatthias Springer       // If the function has no body, conservatively assume that all args are
251e07a7fd5SMatthias Springer       // read + written.
2524002eaaaSMatthias Springer       isRead = true;
2534002eaaaSMatthias Springer       isWritten = true;
2544002eaaaSMatthias Springer     } else {
2554002eaaaSMatthias Springer       // Analyze the body of the function.
2564002eaaaSMatthias Springer       BlockArgument bbArg = funcOp.getArgument(idx);
2574002eaaaSMatthias Springer       isRead = state.isValueRead(bbArg);
2584002eaaaSMatthias Springer       isWritten = state.isValueWritten(bbArg);
259e07a7fd5SMatthias Springer     }
260e07a7fd5SMatthias Springer 
261e07a7fd5SMatthias Springer     if (state.getOptions().testAnalysisOnly)
2624002eaaaSMatthias Springer       annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
263e07a7fd5SMatthias Springer     if (isRead)
2644002eaaaSMatthias Springer       funcState.readBbArgs[funcOp].insert(idx);
265e07a7fd5SMatthias Springer     if (isWritten)
2664002eaaaSMatthias Springer       funcState.writtenBbArgs[funcOp].insert(idx);
267e07a7fd5SMatthias Springer   }
268e07a7fd5SMatthias Springer 
269e07a7fd5SMatthias Springer   return success();
270e07a7fd5SMatthias Springer }
271e07a7fd5SMatthias Springer } // namespace
272e07a7fd5SMatthias Springer 
273e07a7fd5SMatthias Springer /// Remove bufferization attributes on FuncOp arguments.
274e07a7fd5SMatthias Springer static void removeBufferizationAttributes(BlockArgument bbArg) {
27591c11574SAndrzej Warzyński   auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
276e07a7fd5SMatthias Springer   funcOp.removeArgAttr(bbArg.getArgNumber(),
277e07a7fd5SMatthias Springer                        BufferizationDialect::kBufferLayoutAttrName);
278e07a7fd5SMatthias Springer   funcOp.removeArgAttr(bbArg.getArgNumber(),
279e07a7fd5SMatthias Springer                        BufferizationDialect::kWritableAttrName);
280e07a7fd5SMatthias Springer }
281e07a7fd5SMatthias Springer 
28291c11574SAndrzej Warzyński /// Return the func::FuncOp called by `callOp`.
28391c11574SAndrzej Warzyński static func::FuncOp getCalledFunction(func::CallOp callOp) {
2849d34c052SMatthias Springer   SymbolRefAttr sym =
2859d34c052SMatthias Springer       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
286e07a7fd5SMatthias Springer   if (!sym)
287e07a7fd5SMatthias Springer     return nullptr;
28891c11574SAndrzej Warzyński   return dyn_cast_or_null<func::FuncOp>(
289e07a7fd5SMatthias Springer       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
290e07a7fd5SMatthias Springer }
291e07a7fd5SMatthias Springer 
292e07a7fd5SMatthias Springer /// Gather equivalence info of CallOps.
293e07a7fd5SMatthias Springer /// Note: This only adds new equivalence info if the called function was already
294e07a7fd5SMatthias Springer /// analyzed.
295e07a7fd5SMatthias Springer // TODO: This does not handle cyclic function call graphs etc.
29691c11574SAndrzej Warzyński static void equivalenceAnalysis(func::FuncOp funcOp,
297faa9be75SMatthias Springer                                 OneShotAnalysisState &state,
298faa9be75SMatthias Springer                                 FuncAnalysisState &funcState) {
29991c11574SAndrzej Warzyński   funcOp->walk([&](func::CallOp callOp) {
30091c11574SAndrzej Warzyński     func::FuncOp calledFunction = getCalledFunction(callOp);
30191c11574SAndrzej Warzyński     assert(calledFunction && "could not retrieved called func::FuncOp");
302e07a7fd5SMatthias Springer 
303e07a7fd5SMatthias Springer     // No equivalence info available for the called function.
304e07a7fd5SMatthias Springer     if (!funcState.equivalentFuncArgs.count(calledFunction))
305e07a7fd5SMatthias Springer       return WalkResult::skip();
306e07a7fd5SMatthias Springer 
307e07a7fd5SMatthias Springer     for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
308e07a7fd5SMatthias Springer       int64_t returnIdx = it.first;
309e07a7fd5SMatthias Springer       int64_t bbargIdx = it.second;
310bf582569SMatthias Springer       if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
311bf582569SMatthias Springer         continue;
31291c11574SAndrzej Warzyński       Value returnVal = callOp.getResult(returnIdx);
313e07a7fd5SMatthias Springer       Value argVal = callOp->getOperand(bbargIdx);
314cf2d374eSMatthias Springer       state.unionEquivalenceClasses(returnVal, argVal);
315e07a7fd5SMatthias Springer     }
316e07a7fd5SMatthias Springer 
317e07a7fd5SMatthias Springer     return WalkResult::advance();
318e07a7fd5SMatthias Springer   });
319e07a7fd5SMatthias Springer }
320e07a7fd5SMatthias Springer 
3213d0ca2cfSMatthias Springer /// Return "true" if the given function signature has tensor semantics.
32291c11574SAndrzej Warzyński static bool hasTensorSignature(func::FuncOp funcOp) {
32391c11574SAndrzej Warzyński   return llvm::any_of(funcOp.getFunctionType().getInputs(),
32491c11574SAndrzej Warzyński                       llvm::IsaPred<TensorType>) ||
32591c11574SAndrzej Warzyński          llvm::any_of(funcOp.getFunctionType().getResults(),
32691c11574SAndrzej Warzyński                       llvm::IsaPred<TensorType>);
3273d0ca2cfSMatthias Springer }
3283d0ca2cfSMatthias Springer 
329e07a7fd5SMatthias Springer /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
330c271ba7fSMatthias Springer /// callee-caller order (i.e., callees without callers first). Store all
331c271ba7fSMatthias Springer /// remaining functions (i.e., the ones that call each other recursively) in
332c271ba7fSMatthias Springer /// `remainingFuncOps`.
333c271ba7fSMatthias Springer ///
334e07a7fd5SMatthias Springer /// Store the map of FuncOp to all its callers in `callerMap`.
335c271ba7fSMatthias Springer ///
336c271ba7fSMatthias Springer /// Return `failure()` if we are unable to retrieve the called FuncOp from
337c271ba7fSMatthias Springer /// any func::CallOp.
338c271ba7fSMatthias Springer static LogicalResult getFuncOpsOrderedByCalls(
339c271ba7fSMatthias Springer     ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
340c271ba7fSMatthias Springer     SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
341e07a7fd5SMatthias Springer   // For each FuncOp, the set of functions called by it (i.e. the union of
342dc700f1eSIngo Müller   // symbols of all nested func::CallOp).
34391c11574SAndrzej Warzyński   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
344dc700f1eSIngo Müller   // For each FuncOp, the number of func::CallOp it contains.
34591c11574SAndrzej Warzyński   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
34691c11574SAndrzej Warzyński   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
3473d0ca2cfSMatthias Springer     // Collect function calls and populate the caller map.
348e07a7fd5SMatthias Springer     numberCallOpsContainedInFuncOp[funcOp] = 0;
34991c11574SAndrzej Warzyński     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
35091c11574SAndrzej Warzyński       func::FuncOp calledFunction = getCalledFunction(callOp);
35191c11574SAndrzej Warzyński       assert(calledFunction && "could not retrieved called func::FuncOp");
3523d0ca2cfSMatthias Springer       // If the called function does not have any tensors in its signature, then
3533d0ca2cfSMatthias Springer       // it is not necessary to bufferize the callee before the caller.
3543d0ca2cfSMatthias Springer       if (!hasTensorSignature(calledFunction))
3553d0ca2cfSMatthias Springer         return WalkResult::skip();
3563d0ca2cfSMatthias Springer 
35786fd1c13SBenjamin Kramer       callerMap[calledFunction].insert(callOp);
35886fd1c13SBenjamin Kramer       if (calledBy[calledFunction].insert(funcOp).second) {
359e07a7fd5SMatthias Springer         numberCallOpsContainedInFuncOp[funcOp]++;
360e07a7fd5SMatthias Springer       }
361e07a7fd5SMatthias Springer       return WalkResult::advance();
362e07a7fd5SMatthias Springer     });
363e07a7fd5SMatthias Springer   });
364e07a7fd5SMatthias Springer   if (res.wasInterrupted())
365e07a7fd5SMatthias Springer     return failure();
366c271ba7fSMatthias Springer 
3673d0ca2cfSMatthias Springer   // Iteratively remove function operations that do not call any of the
368c271ba7fSMatthias Springer   // functions remaining in the callCounter map and add them to ordered list.
369e07a7fd5SMatthias Springer   while (!numberCallOpsContainedInFuncOp.empty()) {
370e07a7fd5SMatthias Springer     auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
371e07a7fd5SMatthias Springer                             [](auto entry) { return entry.getSecond() == 0; });
372e07a7fd5SMatthias Springer     if (it == numberCallOpsContainedInFuncOp.end())
373c271ba7fSMatthias Springer       break;
374e07a7fd5SMatthias Springer     orderedFuncOps.push_back(it->getFirst());
375e07a7fd5SMatthias Springer     for (auto callee : calledBy[it->getFirst()])
376e07a7fd5SMatthias Springer       numberCallOpsContainedInFuncOp[callee]--;
377e07a7fd5SMatthias Springer     numberCallOpsContainedInFuncOp.erase(it);
378e07a7fd5SMatthias Springer   }
379c271ba7fSMatthias Springer 
380c271ba7fSMatthias Springer   // Put all other functions in the list of remaining functions. These are
381c271ba7fSMatthias Springer   // functions that call each other circularly.
382c271ba7fSMatthias Springer   for (auto it : numberCallOpsContainedInFuncOp)
383c271ba7fSMatthias Springer     remainingFuncOps.push_back(it.first);
384c271ba7fSMatthias Springer 
385e07a7fd5SMatthias Springer   return success();
386e07a7fd5SMatthias Springer }
387e07a7fd5SMatthias Springer 
388*b0a4e958SMatthias Springer /// Helper function that extracts the source from a memref.cast. If the given
389*b0a4e958SMatthias Springer /// value is not a memref.cast result, simply returns the given value.
390*b0a4e958SMatthias Springer static Value unpackCast(Value v) {
391*b0a4e958SMatthias Springer   auto castOp = v.getDefiningOp<memref::CastOp>();
392*b0a4e958SMatthias Springer   if (!castOp)
393*b0a4e958SMatthias Springer     return v;
394*b0a4e958SMatthias Springer   return castOp.getSource();
395*b0a4e958SMatthias Springer }
396*b0a4e958SMatthias Springer 
397*b0a4e958SMatthias Springer /// Helper function that returns the return types (skipping casts) of the given
398*b0a4e958SMatthias Springer /// func.return ops. This function returns as many types as the return ops have
399*b0a4e958SMatthias Springer /// operands. If the i-th operand is not the same for all func.return ops, then
400*b0a4e958SMatthias Springer /// the i-th returned type is an "empty" type.
401*b0a4e958SMatthias Springer static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
402*b0a4e958SMatthias Springer   assert(!returnOps.empty() && "expected at least one ReturnOp");
403*b0a4e958SMatthias Springer   int numOperands = returnOps.front()->getNumOperands();
404*b0a4e958SMatthias Springer 
405*b0a4e958SMatthias Springer   // Helper function that unpacks memref.cast ops and returns the type.
406*b0a4e958SMatthias Springer   auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
407*b0a4e958SMatthias Springer 
408*b0a4e958SMatthias Springer   SmallVector<Type> result;
409*b0a4e958SMatthias Springer   for (int i = 0; i < numOperands; ++i) {
410*b0a4e958SMatthias Springer     // Get the type of the i-th operand of the first func.return ops.
411*b0a4e958SMatthias Springer     Type t = getSourceType(returnOps.front()->getOperand(i));
412*b0a4e958SMatthias Springer 
413*b0a4e958SMatthias Springer     // Check if all other func.return ops have a matching operand type.
414*b0a4e958SMatthias Springer     for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
415*b0a4e958SMatthias Springer       if (getSourceType(returnOps[j]->getOperand(i)) != t)
416*b0a4e958SMatthias Springer         t = Type();
417*b0a4e958SMatthias Springer 
418*b0a4e958SMatthias Springer     result.push_back(t);
419*b0a4e958SMatthias Springer   }
420*b0a4e958SMatthias Springer 
421*b0a4e958SMatthias Springer   return result;
422*b0a4e958SMatthias Springer }
423*b0a4e958SMatthias Springer 
424e07a7fd5SMatthias Springer /// Fold return values that are memref casts and update function return types.
425e07a7fd5SMatthias Springer ///
426e07a7fd5SMatthias Springer /// During FuncOp bufferization, the exact type of the returned memrefs (if any)
427e07a7fd5SMatthias Springer /// is not known yet. Therefore, the bufferization uses memref types with the
428e07a7fd5SMatthias Springer /// most generic layout map as function return types. After bufferizing the
429e07a7fd5SMatthias Springer /// entire function body, a more concise memref type can potentially be used for
430e07a7fd5SMatthias Springer /// the return type of the function.
43191c11574SAndrzej Warzyński static void foldMemRefCasts(func::FuncOp funcOp) {
432*b0a4e958SMatthias Springer   // There is nothing to do for bodiless ops.
43391c11574SAndrzej Warzyński   if (funcOp.getBody().empty())
434e07a7fd5SMatthias Springer     return;
435e07a7fd5SMatthias Springer 
436*b0a4e958SMatthias Springer   // Compute the common result types of all return ops.
437*b0a4e958SMatthias Springer   SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
438*b0a4e958SMatthias Springer   SmallVector<Type> resultTypes = getReturnTypes(returnOps);
439e07a7fd5SMatthias Springer 
440*b0a4e958SMatthias Springer   // Remove direct casts.
441*b0a4e958SMatthias Springer   for (func::ReturnOp returnOp : returnOps) {
442e07a7fd5SMatthias Springer     for (OpOperand &operand : returnOp->getOpOperands()) {
443*b0a4e958SMatthias Springer       // Bail if no common result type was found.
444*b0a4e958SMatthias Springer       if (resultTypes[operand.getOperandNumber()]) {
445*b0a4e958SMatthias Springer         operand.set(unpackCast(operand.get()));
446*b0a4e958SMatthias Springer       }
447e07a7fd5SMatthias Springer     }
448e07a7fd5SMatthias Springer   }
449e07a7fd5SMatthias Springer 
450*b0a4e958SMatthias Springer   // Fill in the missing result types that were not the same among all
451*b0a4e958SMatthias Springer   // func.return ops.
452*b0a4e958SMatthias Springer   for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
453*b0a4e958SMatthias Springer     if (resultTypes[i])
454*b0a4e958SMatthias Springer       continue;
455*b0a4e958SMatthias Springer     resultTypes[i] = funcOp.getFunctionType().getResult(i);
456*b0a4e958SMatthias Springer   }
457*b0a4e958SMatthias Springer 
458*b0a4e958SMatthias Springer   // Update the function type.
45991c11574SAndrzej Warzyński   auto newFuncType = FunctionType::get(
46091c11574SAndrzej Warzyński       funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
461e07a7fd5SMatthias Springer   funcOp.setType(newFuncType);
462e07a7fd5SMatthias Springer }
463e07a7fd5SMatthias Springer 
464f470f8cbSMatthias Springer LogicalResult
465f470f8cbSMatthias Springer mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
466ae05bd99SMatthias Springer                                      OneShotAnalysisState &state,
467ae05bd99SMatthias Springer                                      BufferizationStatistics *statistics) {
4687cdfc843SMatthias Springer   assert(state.getOptions().bufferizeFunctionBoundaries &&
469d6dab38aSMatthias Springer          "expected that function boundary bufferization is activated");
470faa9be75SMatthias Springer   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
471e07a7fd5SMatthias Springer 
472c271ba7fSMatthias Springer   // A list of non-circular functions in the order in which they are analyzed
473c271ba7fSMatthias Springer   // and bufferized.
47491c11574SAndrzej Warzyński   SmallVector<func::FuncOp> orderedFuncOps;
475c271ba7fSMatthias Springer   // A list of all other functions. I.e., functions that call each other
476c271ba7fSMatthias Springer   // recursively. For these, we analyze the function body but not the function
477c271ba7fSMatthias Springer   // boundary.
478c271ba7fSMatthias Springer   SmallVector<func::FuncOp> remainingFuncOps;
479e07a7fd5SMatthias Springer 
480e07a7fd5SMatthias Springer   // A mapping of FuncOps to their callers.
481e07a7fd5SMatthias Springer   FuncCallerMap callerMap;
482e07a7fd5SMatthias Springer 
483c271ba7fSMatthias Springer   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
484c271ba7fSMatthias Springer                                       remainingFuncOps, callerMap)))
485e07a7fd5SMatthias Springer     return failure();
486e07a7fd5SMatthias Springer 
487c271ba7fSMatthias Springer   // Analyze functions in order. Starting with functions that are not calling
488c271ba7fSMatthias Springer   // any other functions.
48991c11574SAndrzej Warzyński   for (func::FuncOp funcOp : orderedFuncOps) {
490060c8be5SMaya Amrami     if (!state.getOptions().isOpAllowed(funcOp))
491060c8be5SMaya Amrami       continue;
492060c8be5SMaya Amrami 
493e07a7fd5SMatthias Springer     // Now analyzing function.
494e07a7fd5SMatthias Springer     funcState.startFunctionAnalysis(funcOp);
495e07a7fd5SMatthias Springer 
496e07a7fd5SMatthias Springer     // Gather equivalence info for CallOps.
497cf2d374eSMatthias Springer     equivalenceAnalysis(funcOp, state, funcState);
498e07a7fd5SMatthias Springer 
499e07a7fd5SMatthias Springer     // Analyze funcOp.
500ae05bd99SMatthias Springer     if (failed(analyzeOp(funcOp, state, statistics)))
501e07a7fd5SMatthias Springer       return failure();
502e07a7fd5SMatthias Springer 
5033490aadfSMatthias Springer     // Run some extra function analyses.
504faa9be75SMatthias Springer     if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
505faa9be75SMatthias Springer         failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
5063490aadfSMatthias Springer       return failure();
5073490aadfSMatthias Springer 
508e07a7fd5SMatthias Springer     // Mark op as fully analyzed.
509e07a7fd5SMatthias Springer     funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
510e07a7fd5SMatthias Springer   }
511e07a7fd5SMatthias Springer 
512c271ba7fSMatthias Springer   // Analyze all other functions. All function boundary analyses are skipped.
513c271ba7fSMatthias Springer   for (func::FuncOp funcOp : remainingFuncOps) {
514c271ba7fSMatthias Springer     if (!state.getOptions().isOpAllowed(funcOp))
515c271ba7fSMatthias Springer       continue;
516c271ba7fSMatthias Springer 
517c271ba7fSMatthias Springer     // Gather equivalence info for CallOps.
518c271ba7fSMatthias Springer     equivalenceAnalysis(funcOp, state, funcState);
519c271ba7fSMatthias Springer 
520c271ba7fSMatthias Springer     // Analyze funcOp.
521c271ba7fSMatthias Springer     if (failed(analyzeOp(funcOp, state, statistics)))
522c271ba7fSMatthias Springer       return failure();
523c271ba7fSMatthias Springer 
524c271ba7fSMatthias Springer     // TODO: We currently skip all function argument analyses for functions
525c271ba7fSMatthias Springer     // that call each other circularly. These analyses do not support recursive
526c271ba7fSMatthias Springer     // calls yet. The `BufferizableOpInterface` implementations of `func`
527c271ba7fSMatthias Springer     // dialect ops return conservative results in the absence of analysis
528c271ba7fSMatthias Springer     // information.
529c271ba7fSMatthias Springer   }
530c271ba7fSMatthias Springer 
531e07a7fd5SMatthias Springer   return success();
532f470f8cbSMatthias Springer }
533f470f8cbSMatthias Springer 
534c7a9e5e5SPeiming Liu void mlir::bufferization::removeBufferizationAttributesInModule(
535c7a9e5e5SPeiming Liu     ModuleOp moduleOp) {
53691c11574SAndrzej Warzyński   moduleOp.walk([&](func::FuncOp op) {
537c7a9e5e5SPeiming Liu     for (BlockArgument bbArg : op.getArguments())
538c7a9e5e5SPeiming Liu       removeBufferizationAttributes(bbArg);
539c7a9e5e5SPeiming Liu   });
540c7a9e5e5SPeiming Liu }
541c7a9e5e5SPeiming Liu 
542f470f8cbSMatthias Springer LogicalResult mlir::bufferization::bufferizeModuleOp(
543ae05bd99SMatthias Springer     ModuleOp moduleOp, const OneShotBufferizationOptions &options,
5449cf96850SMaya Amrami     BufferizationStatistics *statistics) {
545f470f8cbSMatthias Springer   assert(options.bufferizeFunctionBoundaries &&
546f470f8cbSMatthias Springer          "expected that function boundary bufferization is activated");
547f470f8cbSMatthias Springer   IRRewriter rewriter(moduleOp.getContext());
548f470f8cbSMatthias Springer 
549c271ba7fSMatthias Springer   // A list of non-circular functions in the order in which they are analyzed
550c271ba7fSMatthias Springer   // and bufferized.
55191c11574SAndrzej Warzyński   SmallVector<func::FuncOp> orderedFuncOps;
552c271ba7fSMatthias Springer   // A list of all other functions. I.e., functions that call each other
553c271ba7fSMatthias Springer   // recursively. For these, we analyze the function body but not the function
554c271ba7fSMatthias Springer   // boundary.
555c271ba7fSMatthias Springer   SmallVector<func::FuncOp> remainingFuncOps;
556f470f8cbSMatthias Springer 
557f470f8cbSMatthias Springer   // A mapping of FuncOps to their callers.
558f470f8cbSMatthias Springer   FuncCallerMap callerMap;
559f470f8cbSMatthias Springer 
560c271ba7fSMatthias Springer   // Try to bufferize functions in calling order. I.e., first bufferize
561c271ba7fSMatthias Springer   // functions that do not call other functions. This allows us to infer
562c271ba7fSMatthias Springer   // accurate buffer types for function return values. Functions that call
563c271ba7fSMatthias Springer   // each other recursively are bufferized in an unspecified order at the end.
564c271ba7fSMatthias Springer   // We may use unnecessarily "complex" (in terms of layout map) buffer types.
565c271ba7fSMatthias Springer   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
566c271ba7fSMatthias Springer                                       remainingFuncOps, callerMap)))
567f470f8cbSMatthias Springer     return failure();
568c271ba7fSMatthias Springer   llvm::append_range(orderedFuncOps, remainingFuncOps);
569e07a7fd5SMatthias Springer 
570e07a7fd5SMatthias Springer   // Bufferize functions.
57191c11574SAndrzej Warzyński   for (func::FuncOp funcOp : orderedFuncOps) {
572e07a7fd5SMatthias Springer     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
573e07a7fd5SMatthias Springer     // would be invalidated.
5749d34c052SMatthias Springer 
57591c11574SAndrzej Warzyński     if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
5769d34c052SMatthias Springer       // This function was not analyzed and RaW conflicts were not resolved.
5779d34c052SMatthias Springer       // Buffer copies must be inserted before every write.
5789d34c052SMatthias Springer       OneShotBufferizationOptions updatedOptions = options;
5799d34c052SMatthias Springer       updatedOptions.copyBeforeWrite = true;
5809d34c052SMatthias Springer       if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
581e07a7fd5SMatthias Springer         return failure();
5829d34c052SMatthias Springer     } else {
5839d34c052SMatthias Springer       if (failed(bufferizeOp(funcOp, options, statistics)))
5849d34c052SMatthias Springer         return failure();
5859d34c052SMatthias Springer     }
5869d34c052SMatthias Springer 
587f287da8aSMatthias Springer     // Change buffer return types to more precise layout maps.
58875ef84bfSOleg Shyshkov     if (options.inferFunctionResultLayout)
589e07a7fd5SMatthias Springer       foldMemRefCasts(funcOp);
590e07a7fd5SMatthias Springer   }
591e07a7fd5SMatthias Springer 
5928f2d83daSMatthias Springer   // Bufferize all other ops.
593fa101214SRyan Holt   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
5948f2d83daSMatthias Springer     // Functions were already bufferized.
59591c11574SAndrzej Warzyński     if (isa<func::FuncOp>(&op))
5968f2d83daSMatthias Springer       continue;
5978f2d83daSMatthias Springer     if (failed(bufferizeOp(&op, options, statistics)))
5988f2d83daSMatthias Springer       return failure();
5998f2d83daSMatthias Springer   }
6008f2d83daSMatthias Springer 
601e07a7fd5SMatthias Springer   // Post-pass cleanup of function argument attributes.
602c7a9e5e5SPeiming Liu   removeBufferizationAttributesInModule(moduleOp);
603e07a7fd5SMatthias Springer 
604e07a7fd5SMatthias Springer   return success();
605e07a7fd5SMatthias Springer }
606f470f8cbSMatthias Springer 
607f470f8cbSMatthias Springer LogicalResult mlir::bufferization::runOneShotModuleBufferize(
608ae05bd99SMatthias Springer     ModuleOp moduleOp, const OneShotBufferizationOptions &options,
6099cf96850SMaya Amrami     BufferizationStatistics *statistics) {
610f470f8cbSMatthias Springer   assert(options.bufferizeFunctionBoundaries &&
611f470f8cbSMatthias Springer          "expected that function boundary bufferization is activated");
612f7dd9a32SMatthias Springer   assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
613f7dd9a32SMatthias Springer          "invalid combination of bufferization flags");
614f7dd9a32SMatthias Springer   if (!options.copyBeforeWrite) {
6159cf96850SMaya Amrami     if (options.noAnalysisFuncFilter.empty()) {
616ae05bd99SMatthias Springer       if (failed(insertTensorCopies(moduleOp, options, statistics)))
617f470f8cbSMatthias Springer         return failure();
618060c8be5SMaya Amrami     } else {
6199cf96850SMaya Amrami       // FuncOps whose names are specified in options.noAnalysisFuncFilter will
6209cf96850SMaya Amrami       // not be analyzed. Ops in these FuncOps will not be analyzed as well.
6219cf96850SMaya Amrami       OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
62291c11574SAndrzej Warzyński         auto func = dyn_cast<func::FuncOp>(op);
6239cf96850SMaya Amrami         if (!func)
62491c11574SAndrzej Warzyński           func = op->getParentOfType<func::FuncOp>();
6259cf96850SMaya Amrami         if (func)
6269cf96850SMaya Amrami           return llvm::is_contained(options.noAnalysisFuncFilter,
62791c11574SAndrzej Warzyński                                     func.getSymName());
6289cf96850SMaya Amrami         return false;
6299cf96850SMaya Amrami       };
630060c8be5SMaya Amrami       OneShotBufferizationOptions updatedOptions(options);
631060c8be5SMaya Amrami       updatedOptions.opFilter.denyOperation(analysisFilterFn);
632060c8be5SMaya Amrami       if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
633060c8be5SMaya Amrami         return failure();
634060c8be5SMaya Amrami     }
635f7dd9a32SMatthias Springer   }
636f470f8cbSMatthias Springer   if (options.testAnalysisOnly)
637f470f8cbSMatthias Springer     return success();
6389cf96850SMaya Amrami   if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
639f470f8cbSMatthias Springer     return failure();
640f470f8cbSMatthias Springer   return success();
641f470f8cbSMatthias Springer }
642