//===- OneShotModuleBufferize.cpp - Bufferize across Func. Boundaries ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
//   be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on the results of
// `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...]  : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis
// conservatively assumes that each tensor OpOperand is both read and written,
// unless the corresponding bbArg carries an explicit `bufferization.access`
// argument attribute ("read", "write", "read-write" or "none"), which
// `funcOpBbArgReadWriteAnalysis` honors even for bodiless functions.
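//
// Example (a sketch; the attribute spelling follows
// `BufferizationDialect::kBufferAccessAttrName`): an external function whose
// tensor argument is marked as read-only, so that callers do not have to
// assume a conservative read + write:
// ```
// func.func private @ext(tensor<?xf32> {bufferization.access = "read"})
//     -> tensor<?xf32>
// ```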

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
  auto *result = state.getExtension<FuncAnalysisState>();
  if (result)
    return *result;
  return state.addExtension<FuncAnalysisState>();
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return cast<IntegerAttr>(a).getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}
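
// Example of the resulting annotation (a sketch of the printed form): the
// i-th entry is the index of the bbArg that is equivalent to the i-th
// returned value, or -1 if no equivalent bbArg was found:
// ```
// return %a, %b {__equivalent_func_args__ = [0, -1]}
//     : tensor<?xf32>, tensor<?xf32>
// ```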

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  if (funcOp.getBody().empty()) {
    // No function body available. Conservatively assume that every tensor
    // return value may alias with any tensor bbArg.
    FunctionType type = funcOp.getFunctionType();
    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
        int64_t bbArgIdx = inputIt.index();
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
      }
    }
    return success();
  }

  // Find all func.return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  assert(!returnOps.empty() && "expected at least one ReturnOp");

  // Build alias sets. Merge all aliases from all func.return ops.
  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (isa<RankedTensorType>(bbArg.getType())) {
      int64_t bbArgIdx = bbArg.getArgNumber();
      // Store aliases in a set, so that we don't add the same alias twice.
      SetVector<int64_t> aliases;
      for (func::ReturnOp returnOp : returnOps) {
        for (OpOperand &returnVal : returnOp->getOpOperands()) {
          if (isa<RankedTensorType>(returnVal.get().getType())) {
            int64_t returnIdx = returnVal.getOperandNumber();
            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
              aliases.insert(returnIdx);
          }
        }
      }
      for (int64_t alias : aliases)
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
    }
  }

  // Build equivalence sets.
  // Helper function that finds an equivalent block argument index for the
  // given OpOperand. Return std::nullopt if no equivalent block argument could
  // be found.
  auto findEquivalentBlockArgIdx =
      [&](OpOperand &opOperand) -> std::optional<int64_t> {
    Value v = opOperand.get();
    if (!isa<TensorType>(v.getType()))
      return std::nullopt;
    for (BlockArgument bbArg : funcOp.getArguments()) {
      if (isa<RankedTensorType>(bbArg.getType())) {
        if (state.areEquivalentBufferizedValues(v, bbArg)) {
          if (state.getOptions().testAnalysisOnly)
            annotateEquivalentReturnBbArg(opOperand, bbArg);
          return bbArg.getArgNumber();
        }
      }
    }
    return std::nullopt;
  };

  int64_t numResults = returnOps.front()->getNumOperands();
  for (int64_t i = 0; i < numResults; ++i) {
    // Find the equivalent block argument index for the i-th operand of the
    // first func.return op.
    std::optional<int64_t> maybeEquiv =
        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
    if (!maybeEquiv.has_value())
      continue;
    int64_t bbArgIdx = *maybeEquiv;
    bool allEquiv = true;

    // Check if all other func.return ops have the same equivalent block
    // argument for the i-th operand. In contrast to aliasing information,
    // which is just "merged", equivalence information must match across all
    // func.return ops.
    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
      std::optional<int64_t> maybeEquiv =
          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
      if (maybeEquiv != bbArgIdx) {
        allEquiv = false;
        break;
      }
    }

    // All func.return ops have the same equivalent block argument for the i-th
    // operand.
    if (allEquiv)
      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
  }

  return success();
}
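
// Example: equivalence must hold on every return path, while aliasing is
// merged across paths. In the sketch below (hypothetical IR), result #0
// aliases both %t and %u, but it is equivalent to neither, because the two
// func.return ops disagree on the equivalent bbArg:
// ```
// func.func @f(%c : i1, %t : tensor<5xf32>, %u : tensor<5xf32>)
//     -> tensor<5xf32> {
//   cf.cond_br %c, ^bb1, ^bb2
// ^bb1:
//   func.return %t : tensor<5xf32>
// ^bb2:
//   func.return %u : tensor<5xf32>
// }
// ```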

static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
                                  bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
                    accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
    if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
            idx, BufferizationDialect::kBufferAccessAttrName)) {
      // Buffer access behavior is specified on the function. Skip the analysis.
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
    } else if (funcOp.getBody().empty()) {
      // If the function has no body, conservatively assume that all args are
      // read + written.
      isRead = true;
      isWritten = true;
    } else {
      // Analyze the body of the function.
      BlockArgument bbArg = funcOp.getArgument(idx);
      isRead = state.isValueRead(bbArg);
      isWritten = state.isValueWritten(bbArg);
    }

    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(idx);
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(idx);
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was
/// already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
                                OneShotAnalysisState &state,
                                FuncAnalysisState &funcState) {
  funcOp->walk([&](func::CallOp callOp) {
    func::FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called func::FuncOp");

    // No equivalence info available for the called function.
    if (!funcState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbargIdx = it.second;
      if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
        continue;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      state.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}
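
// Example (sketch): if `@callee` was already analyzed and
// `equivalentFuncArgs[@callee]` maps result #0 to bbArg #0, then after
// ```
// %r = call @callee(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
// ```
// `%r` and `%t` are merged into one equivalence class, provided that the
// operand `%t` bufferizes in-place.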

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      llvm::IsaPred<TensorType>) ||
         llvm::any_of(funcOp.getFunctionType().getResults(),
                      llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`.
///
/// Store the map of FuncOp to all its callers in `callerMap`.
///
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions that call it (i.e., the parents of
  // all func::CallOps that reference it).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of distinct functions (with tensors in their
  // signatures) that it calls.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
    // Collect function calls and populate the caller map.
    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      // If the called function does not have any tensors in its signature,
      // then it is not necessary to bufferize the callee before the caller.
      if (!hasTensorSignature(calledFunction))
        return WalkResult::skip();

      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();

  // Iteratively remove functions that do not call any of the functions
  // remaining in `numberCallOpsContainedInFuncOp` and add them to the ordered
  // list.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      break;
    orderedFuncOps.push_back(it->getFirst());
    for (func::FuncOp caller : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[caller]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }

  // Put all other functions in the list of remaining functions. These are
  // functions that call each other circularly.
  for (auto it : numberCallOpsContainedInFuncOp)
    remainingFuncOps.push_back(it.first);

  return success();
}
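
// Example (sketch, assuming all signatures involve tensors): repeatedly
// peeling off functions with a call count of zero yields a callee-first
// order; mutually recursive functions never reach a count of zero and end up
// in `remainingFuncOps` instead:
// ```
// func.func @c(...) { ... }             // calls nothing -> count 0
// func.func @b(...) { ... call @c ... } // calls @c      -> count 1
// func.func @a(...) { ... call @b ... } // calls @b      -> count 1
// // -> orderedFuncOps = [@c, @b, @a], remainingFuncOps = []
// ```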

/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
static Value unpackCast(Value v) {
  auto castOp = v.getDefiningOp<memref::CastOp>();
  if (!castOp)
    return v;
  return castOp.getSource();
}

/// Helper function that returns the return types (skipping casts) of the given
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the type of the i-th operand is not the same for all
/// func.return ops, then the i-th returned type is an "empty" type.
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
  assert(!returnOps.empty() && "expected at least one ReturnOp");
  int numOperands = returnOps.front()->getNumOperands();

  // Helper function that unpacks memref.cast ops and returns the type.
  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };

  SmallVector<Type> result;
  for (int i = 0; i < numOperands; ++i) {
    // Get the type of the i-th operand of the first func.return op.
    Type t = getSourceType(returnOps.front()->getOperand(i));

    // Check if all other func.return ops have a matching operand type.
    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
      if (getSourceType(returnOps[j]->getOperand(i)) != t)
        t = Type();

    result.push_back(t);
  }

  return result;
}
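
// Example (sketch): if the i-th operands of all return ops trace back
// (through memref.cast) to the same source type, that type is kept;
// otherwise the slot is left as a null `Type`:
// ```
// %c = memref.cast %m : memref<5xf32> to memref<?xf32>
// func.return %c : memref<?xf32>   // unpacked type: memref<5xf32>
// ...
// func.return %m : memref<5xf32>   // unpacked type: memref<5xf32>
// // -> getReturnTypes() yields memref<5xf32> for this position.
// ```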

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if
/// any) is not known yet. Therefore, the bufferization uses memref types with
/// the most generic layout map as function return types. After bufferizing the
/// entire function body, a more precise memref type can potentially be used
/// for the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
  // There is nothing to do for bodiless functions.
  if (funcOp.getBody().empty())
    return;

  // Compute the common result types of all return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  SmallVector<Type> resultTypes = getReturnTypes(returnOps);

  // Remove direct casts.
  for (func::ReturnOp returnOp : returnOps) {
    for (OpOperand &operand : returnOp->getOpOperands()) {
      // Skip operands for which no common result type was found.
      if (resultTypes[operand.getOperandNumber()]) {
        operand.set(unpackCast(operand.get()));
      }
    }
  }

  // Fill in the missing result types that were not the same among all
  // func.return ops.
  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
    if (resultTypes[i])
      continue;
    resultTypes[i] = funcOp.getFunctionType().getResult(i);
  }

  // Update the function type.
  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}
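
// Example (sketch): after bufferization, @f initially returns the most
// generic layout; folding the cast tightens the function type:
// ```
// // Before:
// func.func @f() -> memref<?xf32, strided<[?], offset: ?>> {
//   ...
//   %c = memref.cast %m
//       : memref<5xf32> to memref<?xf32, strided<[?], offset: ?>>
//   func.return %c : memref<?xf32, strided<[?], offset: ?>>
// }
// // After foldMemRefCasts:
// func.func @f() -> memref<5xf32> {
//   ...
//   func.return %m : memref<5xf32>
// }
// ```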

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
                                     OneShotAnalysisState &state,
                                     BufferizationStatistics *statistics) {
  assert(state.getOptions().bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions, i.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap)))
    return failure();

  // Analyze functions in order, starting with functions that do not call any
  // other functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Mark the function as being analyzed.
    funcState.startFunctionAnalysis(funcOp);

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, state, funcState);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();

    // Mark the function as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }

  // Analyze all other functions. All function boundary analyses are skipped.
  for (func::FuncOp funcOp : remainingFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, state, funcState);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // TODO: We currently skip all function argument analyses for functions
    // that call each other circularly. These analyses do not support recursive
    // calls yet. The `BufferizableOpInterface` implementations of `func`
    // dialect ops return conservative results in the absence of analysis
    // information.
  }

  return success();
}

void mlir::bufferization::removeBufferizationAttributesInModule(
    ModuleOp moduleOp) {
  moduleOp.walk([&](func::FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  });
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions, i.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  // Try to bufferize functions in calling order, i.e., first bufferize
  // functions that do not call other functions. This allows us to infer
  // accurate buffer types for function return values. Functions that call
  // each other recursively are bufferized in an unspecified order at the end;
  // for those, we may end up with unnecessarily "complex" (in terms of layout
  // map) buffer types.
  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap)))
    return failure();
  llvm::append_range(orderedFuncOps, remainingFuncOps);

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here, but we cannot do so
    // because the alias info would be invalidated.

    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed and RaW conflicts were not resolved.
      // Buffer copies must be inserted before every write.
      OneShotBufferizationOptions updatedOptions = options;
      updatedOptions.copyBeforeWrite = true;
      if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
        return failure();
    } else {
      if (failed(bufferizeOp(funcOp, options, statistics)))
        return failure();
    }

    // Change buffer return types to more precise layout maps.
    if (options.inferFunctionResultLayout)
      foldMemRefCasts(funcOp);
  }

  // Bufferize all other ops.
  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
    // Functions were already bufferized.
    if (isa<func::FuncOp>(&op))
      continue;
    if (failed(bufferizeOp(&op, options, statistics)))
      return failure();
  }

  // Post-pass cleanup of function argument attributes.
  removeBufferizationAttributesInModule(moduleOp);

  return success();
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    if (options.noAnalysisFuncFilter.empty()) {
      if (failed(insertTensorCopies(moduleOp, options, statistics)))
        return failure();
    } else {
      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
      // not be analyzed. Ops within those FuncOps will not be analyzed either.
      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
        auto func = dyn_cast<func::FuncOp>(op);
        if (!func)
          func = op->getParentOfType<func::FuncOp>();
        if (func)
          return llvm::is_contained(options.noAnalysisFuncFilter,
                                    func.getSymName());
        return false;
      };
      OneShotBufferizationOptions updatedOptions(options);
      updatedOptions.opFilter.denyOperation(analysisFilterFn);
      if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
        return failure();
    }
  }
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
    return failure();
  return success();
}
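
// Usage sketch (hypothetical caller, e.g. inside a pass; assumes the
// `statistics` parameter defaults to nullptr as declared in
// OneShotModuleBufferize.h):
//
//   OneShotBufferizationOptions options;
//   options.bufferizeFunctionBoundaries = true;
//   if (failed(runOneShotModuleBufferize(moduleOp, options)))
//     signalPassFailure();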
642