xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (revision bfefa15cc18f6f4b0b07849c619989c1a8c5aef9)
1f89bb3c0SAlexander Belyaev //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2f89bb3c0SAlexander Belyaev //
3f89bb3c0SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f89bb3c0SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5f89bb3c0SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f89bb3c0SAlexander Belyaev //
7f89bb3c0SAlexander Belyaev //===----------------------------------------------------------------------===//
8f89bb3c0SAlexander Belyaev 
967d0d7acSMichele Scuttari #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10f89bb3c0SAlexander Belyaev 
117a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14d2dacde5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
1628b2f792SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1723aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
18eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
19a6236898SMatthias Springer #include "mlir/IR/Diagnostics.h"
20f89bb3c0SAlexander Belyaev #include "mlir/IR/Operation.h"
216ecebb49SMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h"
22fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h"
23d2dacde5SMatthias Springer #include "mlir/Pass/PassManager.h"
24d2dacde5SMatthias Springer #include "mlir/Transforms/Passes.h"
25a1fe1f5fSKazu Hirata #include <optional>
26f89bb3c0SAlexander Belyaev 
2767d0d7acSMichele Scuttari namespace mlir {
2867d0d7acSMichele Scuttari namespace bufferization {
2967d0d7acSMichele Scuttari #define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
3067d0d7acSMichele Scuttari #define GEN_PASS_DEF_ONESHOTBUFFERIZE
3167d0d7acSMichele Scuttari #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
3267d0d7acSMichele Scuttari } // namespace bufferization
3367d0d7acSMichele Scuttari } // namespace mlir
3467d0d7acSMichele Scuttari 
35df23ede2SMatthias Springer #define DEBUG_TYPE "bufferize"
36df23ede2SMatthias Springer 
37f89bb3c0SAlexander Belyaev using namespace mlir;
38f89bb3c0SAlexander Belyaev using namespace mlir::bufferization;
39f89bb3c0SAlexander Belyaev 
40f89bb3c0SAlexander Belyaev namespace {
41d2dacde5SMatthias Springer 
42c780184aSLorenzo Chelini static LayoutMapOption parseLayoutMapOption(const std::string &s) {
43f287da8aSMatthias Springer   if (s == "fully-dynamic-layout-map")
44c780184aSLorenzo Chelini     return LayoutMapOption::FullyDynamicLayoutMap;
45f287da8aSMatthias Springer   if (s == "identity-layout-map")
46c780184aSLorenzo Chelini     return LayoutMapOption::IdentityLayoutMap;
47f287da8aSMatthias Springer   if (s == "infer-layout-map")
48c780184aSLorenzo Chelini     return LayoutMapOption::InferLayoutMap;
49f287da8aSMatthias Springer   llvm_unreachable("invalid layout map option");
50f287da8aSMatthias Springer }
51f287da8aSMatthias Springer 
521b99f3a2SMatthias Springer static OneShotBufferizationOptions::AnalysisHeuristic
531b99f3a2SMatthias Springer parseHeuristicOption(const std::string &s) {
541b99f3a2SMatthias Springer   if (s == "bottom-up")
551b99f3a2SMatthias Springer     return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
561b99f3a2SMatthias Springer   if (s == "top-down")
571b99f3a2SMatthias Springer     return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
5835d3b343SMatthias Springer   if (s == "bottom-up-from-terminators")
5935d3b343SMatthias Springer     return OneShotBufferizationOptions::AnalysisHeuristic::
6035d3b343SMatthias Springer         BottomUpFromTerminators;
6135d3b343SMatthias Springer   if (s == "fuzzer")
6235d3b343SMatthias Springer     return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer;
631b99f3a2SMatthias Springer   llvm_unreachable("invalid analysisheuristic option");
641b99f3a2SMatthias Springer }
651b99f3a2SMatthias Springer 
66d2dacde5SMatthias Springer struct OneShotBufferizePass
6767d0d7acSMichele Scuttari     : public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
680666e50eSMehdi Amini   OneShotBufferizePass() = default;
69d2dacde5SMatthias Springer 
709597b16aSMatthias Springer   explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
71d2dacde5SMatthias Springer       : options(options) {}
72d2dacde5SMatthias Springer 
73d2dacde5SMatthias Springer   void getDependentDialects(DialectRegistry &registry) const override {
74c076fa1cSMatthias Springer     registry
75c076fa1cSMatthias Springer         .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
76d2dacde5SMatthias Springer   }
77d2dacde5SMatthias Springer 
78d2dacde5SMatthias Springer   void runOnOperation() override {
799597b16aSMatthias Springer     OneShotBufferizationOptions opt;
80d2dacde5SMatthias Springer     if (!options) {
81d2dacde5SMatthias Springer       // Make new bufferization options if none were provided when creating the
82d2dacde5SMatthias Springer       // pass.
836bf043e7SMartin Erhart       opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
84d2dacde5SMatthias Springer       opt.allowUnknownOps = allowUnknownOps;
85d2dacde5SMatthias Springer       opt.analysisFuzzerSeed = analysisFuzzerSeed;
861b99f3a2SMatthias Springer       opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
87f7dd9a32SMatthias Springer       opt.copyBeforeWrite = copyBeforeWrite;
88bb9d1b55SMatthias Springer       opt.dumpAliasSets = dumpAliasSets;
8975ef84bfSOleg Shyshkov       opt.setFunctionBoundaryTypeConversion(
9075ef84bfSOleg Shyshkov           parseLayoutMapOption(functionBoundaryTypeConversion));
91ced2fc78SChristopher Bate 
92ced2fc78SChristopher Bate       if (mustInferMemorySpace && useEncodingForMemorySpace) {
93ced2fc78SChristopher Bate         emitError(getOperation()->getLoc())
94ced2fc78SChristopher Bate             << "only one of 'must-infer-memory-space' and "
95ced2fc78SChristopher Bate                "'use-encoding-for-memory-space' are allowed in "
96ced2fc78SChristopher Bate             << getArgument();
97ced2fc78SChristopher Bate         return signalPassFailure();
98ced2fc78SChristopher Bate       }
99ced2fc78SChristopher Bate 
100067d2779Sian Bearman       if (mustInferMemorySpace) {
101067d2779Sian Bearman         opt.defaultMemorySpaceFn =
102067d2779Sian Bearman             [](TensorType t) -> std::optional<Attribute> {
103067d2779Sian Bearman           return std::nullopt;
104067d2779Sian Bearman         };
105067d2779Sian Bearman       }
106ced2fc78SChristopher Bate 
107ced2fc78SChristopher Bate       if (useEncodingForMemorySpace) {
108ced2fc78SChristopher Bate         opt.defaultMemorySpaceFn =
109ced2fc78SChristopher Bate             [](TensorType t) -> std::optional<Attribute> {
110ced2fc78SChristopher Bate           if (auto rtt = dyn_cast<RankedTensorType>(t))
111ced2fc78SChristopher Bate             return rtt.getEncoding();
112ced2fc78SChristopher Bate           return std::nullopt;
113ced2fc78SChristopher Bate         };
114ced2fc78SChristopher Bate       }
115ced2fc78SChristopher Bate 
116d2dacde5SMatthias Springer       opt.printConflicts = printConflicts;
11770334081SSimon Camphausen       opt.bufferAlignment = bufferAlignment;
118d2dacde5SMatthias Springer       opt.testAnalysisOnly = testAnalysisOnly;
119d6dab38aSMatthias Springer       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
120d5863721SMax191       opt.checkParallelRegions = checkParallelRegions;
1219cf96850SMaya Amrami       opt.noAnalysisFuncFilter = noAnalysisFuncFilter;
122d2dacde5SMatthias Springer 
123606f7c8fSMatthias Springer       // Configure type converter.
124c780184aSLorenzo Chelini       LayoutMapOption unknownTypeConversionOption =
125606f7c8fSMatthias Springer           parseLayoutMapOption(unknownTypeConversion);
126a6236898SMatthias Springer       if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
127a6236898SMatthias Springer         emitError(UnknownLoc::get(&getContext()),
128a6236898SMatthias Springer                   "Invalid option: 'infer-layout-map' is not a valid value for "
129a6236898SMatthias Springer                   "'unknown-type-conversion'");
130a6236898SMatthias Springer         return signalPassFailure();
131a6236898SMatthias Springer       }
1329bb63374SLei Zhang       opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
133606f7c8fSMatthias Springer                                        const BufferizationOptions &options) {
1345550c821STres Popp         auto tensorType = cast<TensorType>(value.getType());
135c780184aSLorenzo Chelini         if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
136606f7c8fSMatthias Springer           return bufferization::getMemRefTypeWithStaticIdentityLayout(
137606f7c8fSMatthias Springer               tensorType, memorySpace);
138c780184aSLorenzo Chelini         assert(unknownTypeConversionOption ==
139c780184aSLorenzo Chelini                    LayoutMapOption::FullyDynamicLayoutMap &&
140606f7c8fSMatthias Springer                "invalid layout map option");
141606f7c8fSMatthias Springer         return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
142606f7c8fSMatthias Springer                                                                   memorySpace);
143606f7c8fSMatthias Springer       };
144606f7c8fSMatthias Springer 
145606f7c8fSMatthias Springer       // Configure op filter.
146b7f93c28SJeff Niu       OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
147d2dacde5SMatthias Springer         // Filter may be specified via options.
148d2dacde5SMatthias Springer         if (this->dialectFilter.hasValue())
149ad44495aSjacquesguan           return llvm::is_contained(this->dialectFilter,
150ad44495aSjacquesguan                                     op->getDialect()->getNamespace());
151d2dacde5SMatthias Springer         // No filter specified: All other ops are allowed.
152d2dacde5SMatthias Springer         return true;
153d2dacde5SMatthias Springer       };
1541534177fSMatthias Springer       opt.opFilter.allowOperation(filterFn);
155d2dacde5SMatthias Springer     } else {
156d2dacde5SMatthias Springer       opt = *options;
157d2dacde5SMatthias Springer     }
158d2dacde5SMatthias Springer 
159a6236898SMatthias Springer     if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
160a6236898SMatthias Springer       // These two flags do not make sense together: "copy-before-write"
161a6236898SMatthias Springer       // indicates that copies should be inserted before every memory write,
162a6236898SMatthias Springer       // but "test-analysis-only" indicates that only the analysis should be
163a6236898SMatthias Springer       // tested. (I.e., no IR is bufferized.)
164a6236898SMatthias Springer       emitError(UnknownLoc::get(&getContext()),
165a6236898SMatthias Springer                 "Invalid option: 'copy-before-write' cannot be used with "
166a6236898SMatthias Springer                 "'test-analysis-only'");
167a6236898SMatthias Springer       return signalPassFailure();
168a6236898SMatthias Springer     }
169a6236898SMatthias Springer 
170a6236898SMatthias Springer     if (opt.printConflicts && !opt.testAnalysisOnly) {
171a6236898SMatthias Springer       emitError(
172a6236898SMatthias Springer           UnknownLoc::get(&getContext()),
173a6236898SMatthias Springer           "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
174a6236898SMatthias Springer       return signalPassFailure();
175a6236898SMatthias Springer     }
176a6236898SMatthias Springer 
177a6236898SMatthias Springer     if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
178a6236898SMatthias Springer       emitError(
179a6236898SMatthias Springer           UnknownLoc::get(&getContext()),
180a6236898SMatthias Springer           "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
181a6236898SMatthias Springer       return signalPassFailure();
182a6236898SMatthias Springer     }
183a6236898SMatthias Springer 
184ae05bd99SMatthias Springer     BufferizationStatistics statistics;
185d2dacde5SMatthias Springer     ModuleOp moduleOp = getOperation();
186d6dab38aSMatthias Springer     if (opt.bufferizeFunctionBoundaries) {
1879cf96850SMaya Amrami       if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
188e07a7fd5SMatthias Springer         signalPassFailure();
189e07a7fd5SMatthias Springer         return;
190e07a7fd5SMatthias Springer       }
191e07a7fd5SMatthias Springer     } else {
192a6236898SMatthias Springer       if (!opt.noAnalysisFuncFilter.empty()) {
193a6236898SMatthias Springer         emitError(UnknownLoc::get(&getContext()),
194a6236898SMatthias Springer                   "Invalid option: 'no-analysis-func-filter' requires "
195a6236898SMatthias Springer                   "'bufferize-function-boundaries'");
196a6236898SMatthias Springer         return signalPassFailure();
197a6236898SMatthias Springer       }
198ae05bd99SMatthias Springer       if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
199d2dacde5SMatthias Springer         signalPassFailure();
200d2dacde5SMatthias Springer         return;
201d2dacde5SMatthias Springer       }
202e07a7fd5SMatthias Springer     }
203d2dacde5SMatthias Springer 
204ae05bd99SMatthias Springer     // Set pass statistics.
205ae05bd99SMatthias Springer     this->numBufferAlloc = statistics.numBufferAlloc;
206ae05bd99SMatthias Springer     this->numTensorInPlace = statistics.numTensorInPlace;
207ae05bd99SMatthias Springer     this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
208d2dacde5SMatthias Springer   }
209d2dacde5SMatthias Springer 
210d2dacde5SMatthias Springer private:
2110a81ace0SKazu Hirata   std::optional<OneShotBufferizationOptions> options;
212d2dacde5SMatthias Springer };
213f89bb3c0SAlexander Belyaev } // namespace
214f89bb3c0SAlexander Belyaev 
215d2dacde5SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
216d2dacde5SMatthias Springer   return std::make_unique<OneShotBufferizePass>();
217d2dacde5SMatthias Springer }
218d2dacde5SMatthias Springer 
219d2dacde5SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
2209597b16aSMatthias Springer     const OneShotBufferizationOptions &options) {
221d2dacde5SMatthias Springer   return std::make_unique<OneShotBufferizePass>(options);
222d2dacde5SMatthias Springer }
223d2dacde5SMatthias Springer 
22449e37000SMatthias Springer //===----------------------------------------------------------------------===//
22549e37000SMatthias Springer // BufferizableOpInterface-based Bufferization
22649e37000SMatthias Springer //===----------------------------------------------------------------------===//
22749e37000SMatthias Springer 
228d820acddSMatthias Springer namespace {
229d820acddSMatthias Springer /// A rewriter that keeps track of extra information during bufferization.
230c6532830SMatthias Springer class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
231d820acddSMatthias Springer public:
232d820acddSMatthias Springer   BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
233d820acddSMatthias Springer                         DenseSet<Operation *> &toMemrefOps,
234b3ebe3beSMatthias Springer                         SmallVector<Operation *> &worklist,
2352f0a634cSMatthias Springer                         const BufferizationOptions &options,
236ae05bd99SMatthias Springer                         BufferizationStatistics *statistics)
237d820acddSMatthias Springer       : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
2389d34c052SMatthias Springer         worklist(worklist), analysisState(options), statistics(statistics) {
239c6532830SMatthias Springer     setListener(this);
240c6532830SMatthias Springer   }
241d820acddSMatthias Springer 
242d820acddSMatthias Springer protected:
243914e6074SMatthias Springer   void notifyOperationErased(Operation *op) override {
244d820acddSMatthias Springer     erasedOps.insert(op);
2459785eb1bSMatthias Springer     // Erase if present.
2469785eb1bSMatthias Springer     toMemrefOps.erase(op);
247d820acddSMatthias Springer   }
248d820acddSMatthias Springer 
2495cc0f76dSMatthias Springer   void notifyOperationInserted(Operation *op, InsertPoint previous) override {
2505cc0f76dSMatthias Springer     // We only care about newly created ops.
2515cc0f76dSMatthias Springer     if (previous.isSet())
2525cc0f76dSMatthias Springer       return;
2535cc0f76dSMatthias Springer 
254b3ebe3beSMatthias Springer     erasedOps.erase(op);
255d820acddSMatthias Springer 
2566bf043e7SMartin Erhart     // Gather statistics about allocs.
257ae05bd99SMatthias Springer     if (statistics) {
2586bf043e7SMartin Erhart       if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
259ae05bd99SMatthias Springer         statistics->numBufferAlloc += static_cast<int64_t>(
260ae05bd99SMatthias Springer             sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
261ae05bd99SMatthias Springer     }
262ae05bd99SMatthias Springer 
263d820acddSMatthias Springer     // Keep track of to_memref ops.
264d820acddSMatthias Springer     if (isa<ToMemrefOp>(op)) {
265d820acddSMatthias Springer       toMemrefOps.insert(op);
266d820acddSMatthias Springer       return;
267d820acddSMatthias Springer     }
268d820acddSMatthias Springer 
269d820acddSMatthias Springer     // Skip to_tensor ops.
270d820acddSMatthias Springer     if (isa<ToTensorOp>(op))
271d820acddSMatthias Springer       return;
272d820acddSMatthias Springer 
2732f0a634cSMatthias Springer     // Skip non-tensor ops.
2742f0a634cSMatthias Springer     if (!hasTensorSemantics(op))
2752f0a634cSMatthias Springer       return;
2762f0a634cSMatthias Springer 
277b3ebe3beSMatthias Springer     // Skip ops that are not allowed to be bufferized.
278b3ebe3beSMatthias Springer     auto const &options = analysisState.getOptions();
2799d34c052SMatthias Springer     if (!options.isOpAllowed(op))
2802f0a634cSMatthias Springer       return;
2812f0a634cSMatthias Springer 
282b3ebe3beSMatthias Springer     // Add op to worklist.
283b3ebe3beSMatthias Springer     worklist.push_back(op);
284d820acddSMatthias Springer   }
285d820acddSMatthias Springer 
286d820acddSMatthias Springer private:
287d820acddSMatthias Springer   /// A set of all erased ops.
288d820acddSMatthias Springer   DenseSet<Operation *> &erasedOps;
289d820acddSMatthias Springer 
290d820acddSMatthias Springer   /// A set of all to_memref ops.
291d820acddSMatthias Springer   DenseSet<Operation *> &toMemrefOps;
292d820acddSMatthias Springer 
293b3ebe3beSMatthias Springer   /// The worklist of ops to be bufferized.
294b3ebe3beSMatthias Springer   SmallVector<Operation *> &worklist;
2952f0a634cSMatthias Springer 
296b3ebe3beSMatthias Springer   /// The analysis state. Used for debug assertions and access to the
297b3ebe3beSMatthias Springer   /// bufferization options.
298b3ebe3beSMatthias Springer   const AnalysisState analysisState;
299b3ebe3beSMatthias Springer 
300ae05bd99SMatthias Springer   /// Bufferization statistics for debugging.
301ae05bd99SMatthias Springer   BufferizationStatistics *statistics;
302d820acddSMatthias Springer };
303d820acddSMatthias Springer } // namespace
304d820acddSMatthias Springer 
3052f0a634cSMatthias Springer LogicalResult bufferization::bufferizeOp(Operation *op,
306b3ebe3beSMatthias Springer                                          const BufferizationOptions &options,
307ae05bd99SMatthias Springer                                          BufferizationStatistics *statistics) {
3089d34c052SMatthias Springer   if (options.copyBeforeWrite) {
309b3ebe3beSMatthias Springer     AnalysisState state(options);
310b3ebe3beSMatthias Springer     if (failed(insertTensorCopies(op, state)))
311b3ebe3beSMatthias Springer       return failure();
312b3ebe3beSMatthias Springer   }
313b3ebe3beSMatthias Springer 
314d820acddSMatthias Springer   // Keep track of to_memref ops.
315d820acddSMatthias Springer   DenseSet<Operation *> toMemrefOps;
316d820acddSMatthias Springer   op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
317d820acddSMatthias Springer 
318d820acddSMatthias Springer   // Gather all bufferizable ops in top-to-bottom order.
31976b16010SMatthias Springer   //
320d820acddSMatthias Springer   // We should ideally know the exact memref type of all operands when
321d820acddSMatthias Springer   // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
322f287da8aSMatthias Springer   // Otherwise, we have to use a memref type with a fully dynamic layout map to
323f287da8aSMatthias Springer   // avoid copies. We are currently missing patterns for layout maps to
324f287da8aSMatthias Springer   // canonicalize away (or canonicalize to more precise layouts).
325d820acddSMatthias Springer   SmallVector<Operation *> worklist;
326ba9d886dSMatthias Springer   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
327db604911SBenjamin Kramer     if (options.isOpAllowed(op) && hasTensorSemantics(op))
328d820acddSMatthias Springer       worklist.push_back(op);
329d820acddSMatthias Springer   });
3306fc753adSMatthias Springer 
331d820acddSMatthias Springer   // Keep track of all erased ops.
332d820acddSMatthias Springer   DenseSet<Operation *> erasedOps;
3337a1579acSMatthias Springer 
334d820acddSMatthias Springer   // Bufferize all ops.
335d820acddSMatthias Springer   BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
3369d34c052SMatthias Springer                                  worklist, options, statistics);
337d820acddSMatthias Springer   for (unsigned i = 0; i < worklist.size(); ++i) {
338df23ede2SMatthias Springer     Operation *nextOp = worklist[i];
339d820acddSMatthias Springer     // Skip ops that were erased.
340df23ede2SMatthias Springer     if (erasedOps.contains(nextOp))
341d820acddSMatthias Springer       continue;
3429785eb1bSMatthias Springer     // Skip ops that are not bufferizable or not allowed.
343df23ede2SMatthias Springer     auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
344d820acddSMatthias Springer     if (!bufferizableOp)
345d820acddSMatthias Springer       continue;
3469785eb1bSMatthias Springer     // Skip ops that no longer have tensor semantics.
347df23ede2SMatthias Springer     if (!hasTensorSemantics(nextOp))
348d820acddSMatthias Springer       continue;
349061aa2e3SMatthias Springer     // Check for unsupported unstructured control flow.
350061aa2e3SMatthias Springer     if (!bufferizableOp.supportsUnstructuredControlFlow())
351061aa2e3SMatthias Springer       for (Region &r : nextOp->getRegions())
352061aa2e3SMatthias Springer         if (r.getBlocks().size() > 1)
353061aa2e3SMatthias Springer           return nextOp->emitOpError(
354061aa2e3SMatthias Springer               "op or BufferizableOpInterface implementation does not support "
355061aa2e3SMatthias Springer               "unstructured control flow, but at least one region has multiple "
356061aa2e3SMatthias Springer               "blocks");
357061aa2e3SMatthias Springer 
358d820acddSMatthias Springer     // Bufferize the op.
359df23ede2SMatthias Springer     LLVM_DEBUG(llvm::dbgs()
360df23ede2SMatthias Springer                << "//===-------------------------------------------===//\n"
361df23ede2SMatthias Springer                << "IR after bufferizing: " << nextOp->getName() << "\n");
362df23ede2SMatthias Springer     rewriter.setInsertionPoint(nextOp);
363df23ede2SMatthias Springer     if (failed(bufferizableOp.bufferize(rewriter, options))) {
364df23ede2SMatthias Springer       LLVM_DEBUG(llvm::dbgs()
365df23ede2SMatthias Springer                  << "failed to bufferize\n"
366df23ede2SMatthias Springer                  << "//===-------------------------------------------===//\n");
367df23ede2SMatthias Springer       return nextOp->emitError("failed to bufferize op");
368df23ede2SMatthias Springer     }
369df23ede2SMatthias Springer     LLVM_DEBUG(llvm::dbgs()
370df23ede2SMatthias Springer                << *op
371df23ede2SMatthias Springer                << "\n//===-------------------------------------------===//\n");
372d820acddSMatthias Springer   }
373d820acddSMatthias Springer 
374fa101214SRyan Holt   // Return early if the top-level op is entirely gone.
375fa101214SRyan Holt   if (erasedOps.contains(op))
376fa101214SRyan Holt     return success();
377fa101214SRyan Holt 
378d820acddSMatthias Springer   // Fold all to_memref(to_tensor(x)) pairs.
379d820acddSMatthias Springer   for (Operation *op : toMemrefOps) {
380d820acddSMatthias Springer     rewriter.setInsertionPoint(op);
381c515c780SMatthias Gehre     (void)bufferization::foldToMemrefToTensorPair(
382c515c780SMatthias Gehre         rewriter, cast<ToMemrefOp>(op), options);
383d820acddSMatthias Springer   }
384d820acddSMatthias Springer 
385199f368eSMatthias Springer   // Remove all dead to_tensor ops.
386199f368eSMatthias Springer   op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
387199f368eSMatthias Springer     if (toTensorOp->getUses().empty()) {
388199f368eSMatthias Springer       rewriter.eraseOp(toTensorOp);
389199f368eSMatthias Springer       return WalkResult::skip();
390199f368eSMatthias Springer     }
391199f368eSMatthias Springer     return WalkResult::advance();
392199f368eSMatthias Springer   });
393199f368eSMatthias Springer 
394d820acddSMatthias Springer   /// Check the result of bufferization. Return an error if an op was not
395d820acddSMatthias Springer   /// bufferized, unless partial bufferization is allowed.
396b55d55ecSMatthias Springer   if (options.allowUnknownOps)
397d820acddSMatthias Springer     return success();
398d820acddSMatthias Springer 
399d820acddSMatthias Springer   for (Operation *op : worklist) {
400d820acddSMatthias Springer     // Skip ops that are entirely gone.
401d820acddSMatthias Springer     if (erasedOps.contains(op))
402d820acddSMatthias Springer       continue;
403d820acddSMatthias Springer     // Ops that no longer have tensor semantics (because they were updated
404d820acddSMatthias Springer     // in-place) are allowed.
405d820acddSMatthias Springer     if (!hasTensorSemantics(op))
406d820acddSMatthias Springer       continue;
407d820acddSMatthias Springer     // Continue ops that are not allowed.
408d820acddSMatthias Springer     if (!options.isOpAllowed(op))
409d820acddSMatthias Springer       continue;
410d820acddSMatthias Springer     // Ops without any uses and no side effects will fold away.
411fc367dfaSMahesh Ravishankar     if (op->getUses().empty() && isMemoryEffectFree(op))
412d820acddSMatthias Springer       continue;
413b3ebe3beSMatthias Springer     // ToTensorOps/ToMemrefOps are allowed in the output.
414b3ebe3beSMatthias Springer     if (isa<ToTensorOp, ToMemrefOp>(op))
415b3ebe3beSMatthias Springer       continue;
416d820acddSMatthias Springer     return op->emitError("op was not bufferized");
417d820acddSMatthias Springer   }
41805e0495fSMatthias Springer 
41905e0495fSMatthias Springer   return success();
4207a1579acSMatthias Springer }
421daf18108SMatthias Springer 
422a88732d9SMatthias Springer LogicalResult
423a88732d9SMatthias Springer bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
424a88732d9SMatthias Springer                                        const BufferizationOptions &options) {
425a88732d9SMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
426a88732d9SMatthias Springer   auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
427a88732d9SMatthias Springer   if (!bufferizableOp)
428a88732d9SMatthias Springer     return failure();
429a88732d9SMatthias Springer 
430a88732d9SMatthias Springer   // Compute the new signature.
431a88732d9SMatthias Springer   SmallVector<Type> newTypes;
432a88732d9SMatthias Springer   for (BlockArgument &bbArg : block->getArguments()) {
433a88732d9SMatthias Springer     auto tensorType = dyn_cast<TensorType>(bbArg.getType());
434a88732d9SMatthias Springer     if (!tensorType) {
435a88732d9SMatthias Springer       newTypes.push_back(bbArg.getType());
436a88732d9SMatthias Springer       continue;
437a88732d9SMatthias Springer     }
438a88732d9SMatthias Springer 
439a88732d9SMatthias Springer     FailureOr<BaseMemRefType> memrefType =
440a88732d9SMatthias Springer         bufferization::getBufferType(bbArg, options);
441a88732d9SMatthias Springer     if (failed(memrefType))
442a88732d9SMatthias Springer       return failure();
443a88732d9SMatthias Springer     newTypes.push_back(*memrefType);
444a88732d9SMatthias Springer   }
445a88732d9SMatthias Springer 
446a88732d9SMatthias Springer   // Change the type of all block arguments.
447a88732d9SMatthias Springer   for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
448a88732d9SMatthias Springer     if (bbArg.getType() == type)
449a88732d9SMatthias Springer       continue;
450a88732d9SMatthias Springer 
451a88732d9SMatthias Springer     // Collect all uses of the bbArg.
452a88732d9SMatthias Springer     SmallVector<OpOperand *> bbArgUses;
453a88732d9SMatthias Springer     for (OpOperand &use : bbArg.getUses())
454a88732d9SMatthias Springer       bbArgUses.push_back(&use);
455a88732d9SMatthias Springer 
456*bfefa15cSYi Zhang     Type tensorType = bbArg.getType();
457a88732d9SMatthias Springer     // Change the bbArg type to memref.
458a88732d9SMatthias Springer     bbArg.setType(type);
459a88732d9SMatthias Springer 
460a88732d9SMatthias Springer     // Replace all uses of the original tensor bbArg.
461a88732d9SMatthias Springer     rewriter.setInsertionPointToStart(block);
462a88732d9SMatthias Springer     if (!bbArgUses.empty()) {
463*bfefa15cSYi Zhang       Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
464*bfefa15cSYi Zhang           bbArg.getLoc(), tensorType, bbArg);
465a88732d9SMatthias Springer       for (OpOperand *use : bbArgUses)
466a88732d9SMatthias Springer         use->set(toTensorOp);
467a88732d9SMatthias Springer     }
468a88732d9SMatthias Springer   }
469a88732d9SMatthias Springer 
4706ecebb49SMatthias Springer   // Bufferize callers of the block.
4716ecebb49SMatthias Springer   for (Operation *op : block->getUsers()) {
4726ecebb49SMatthias Springer     auto branchOp = dyn_cast<BranchOpInterface>(op);
4736ecebb49SMatthias Springer     if (!branchOp)
4746ecebb49SMatthias Springer       return op->emitOpError("cannot bufferize ops with block references that "
4756ecebb49SMatthias Springer                              "do not implement BranchOpInterface");
4766ecebb49SMatthias Springer 
4776ecebb49SMatthias Springer     auto it = llvm::find(op->getSuccessors(), block);
4786ecebb49SMatthias Springer     assert(it != op->getSuccessors().end() && "could find successor");
4796ecebb49SMatthias Springer     int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);
4806ecebb49SMatthias Springer 
4816ecebb49SMatthias Springer     SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
4826ecebb49SMatthias Springer     SmallVector<Value> newOperands;
4836ecebb49SMatthias Springer     for (auto [operand, type] :
4846ecebb49SMatthias Springer          llvm::zip(operands.getForwardedOperands(), newTypes)) {
4856ecebb49SMatthias Springer       if (operand.getType() == type) {
4866ecebb49SMatthias Springer         // Not a tensor type. Nothing to do for this operand.
4876ecebb49SMatthias Springer         newOperands.push_back(operand);
4886ecebb49SMatthias Springer         continue;
4896ecebb49SMatthias Springer       }
4906ecebb49SMatthias Springer       FailureOr<BaseMemRefType> operandBufferType =
4916ecebb49SMatthias Springer           bufferization::getBufferType(operand, options);
4926ecebb49SMatthias Springer       if (failed(operandBufferType))
4936ecebb49SMatthias Springer         return failure();
4946ecebb49SMatthias Springer       rewriter.setInsertionPointAfterValue(operand);
4956ecebb49SMatthias Springer       Value bufferizedOperand = rewriter.create<bufferization::ToMemrefOp>(
4966ecebb49SMatthias Springer           operand.getLoc(), *operandBufferType, operand);
4976ecebb49SMatthias Springer       // A cast is needed if the operand and the block argument have different
4986ecebb49SMatthias Springer       // bufferized types.
4996ecebb49SMatthias Springer       if (type != *operandBufferType)
5006ecebb49SMatthias Springer         bufferizedOperand = rewriter.create<memref::CastOp>(
5016ecebb49SMatthias Springer             operand.getLoc(), type, bufferizedOperand);
5026ecebb49SMatthias Springer       newOperands.push_back(bufferizedOperand);
5036ecebb49SMatthias Springer     }
5046ecebb49SMatthias Springer     operands.getMutableForwardedOperands().assign(newOperands);
5056ecebb49SMatthias Springer   }
5066ecebb49SMatthias Springer 
507a88732d9SMatthias Springer   return success();
508a88732d9SMatthias Springer }
509