//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include <optional>

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
#define GEN_PASS_DEF_ONESHOTBUFFERIZE
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferize"

using namespace mlir;
using namespace mlir::bufferization;

namespace {

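/// Translate the string form of a layout map pass option into the
/// corresponding LayoutMapOption enum value.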
static LayoutMapOption parseLayoutMapOption(const std::string &s) {
  if (s == "fully-dynamic-layout-map")
    return LayoutMapOption::FullyDynamicLayoutMap;
  if (s == "identity-layout-map")
    return LayoutMapOption::IdentityLayoutMap;
  if (s == "infer-layout-map")
    return LayoutMapOption::InferLayoutMap;
  llvm_unreachable("invalid layout map option");
}

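/// Translate the string form of the analysis heuristic pass option into the
/// corresponding AnalysisHeuristic enum value.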
static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  if (s == "bottom-up")
    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
  if (s == "top-down")
    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
  if (s == "bottom-up-from-terminators")
    return OneShotBufferizationOptions::AnalysisHeuristic::
        BottomUpFromTerminators;
  if (s == "fuzzer")
    return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer;
  llvm_unreachable("invalid analysis-heuristic option");
}

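/// The One-Shot Bufferize pass. It bufferizes the operation it is run on
/// (including nested operations), using either options constructed from the
/// pass options or an explicitly provided OneShotBufferizationOptions struct.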
struct OneShotBufferizePass
    : public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() = default;

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
      opt.copyBeforeWrite = copyBeforeWrite;
      opt.dumpAliasSets = dumpAliasSets;
      opt.setFunctionBoundaryTypeConversion(
          parseLayoutMapOption(functionBoundaryTypeConversion));

      if (mustInferMemorySpace && useEncodingForMemorySpace) {
        emitError(getOperation()->getLoc())
            << "only one of 'must-infer-memory-space' and "
               "'use-encoding-for-memory-space' is allowed in "
            << getArgument();
        return signalPassFailure();
      }

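      // When inferring memory spaces, leave the default memory space unset:
      // a defaultMemorySpaceFn that returns std::nullopt means that the memory
      // space of each tensor must be inferred from its context.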
      if (mustInferMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          return std::nullopt;
        };
      }

      if (useEncodingForMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          if (auto rtt = dyn_cast<RankedTensorType>(t))
            return rtt.getEncoding();
          return std::nullopt;
        };
      }

      opt.printConflicts = printConflicts;
      opt.bufferAlignment = bufferAlignment;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.checkParallelRegions = checkParallelRegions;
      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

      // Configure type converter.
      LayoutMapOption unknownTypeConversionOption =
          parseLayoutMapOption(unknownTypeConversion);
      if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'infer-layout-map' is not a valid value for "
                  "'unknown-type-conversion'");
        return signalPassFailure();
      }
      opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
        auto tensorType = cast<TensorType>(value.getType());
        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(unknownTypeConversionOption ==
                   LayoutMapOption::FullyDynamicLayoutMap &&
               "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // Filter may be specified via options.
        if (this->dialectFilter.hasValue())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: all ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
      // These two flags do not make sense together: "copy-before-write"
      // indicates that copies should be inserted before every memory write,
      // but "test-analysis-only" indicates that only the analysis should be
      // tested. (I.e., no IR is bufferized.)
      emitError(UnknownLoc::get(&getContext()),
                "Invalid option: 'copy-before-write' cannot be used with "
                "'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.printConflicts && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    BufferizationStatistics statistics;
    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
        signalPassFailure();
        return;
      }
    } else {
      if (!opt.noAnalysisFuncFilter.empty()) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'no-analysis-func-filter' requires "
                  "'bufferize-function-boundaries'");
        return signalPassFailure();
      }
      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
        signalPassFailure();
        return;
      }
    }

    // Set pass statistics.
    this->numBufferAlloc = statistics.numBufferAlloc;
    this->numTensorInPlace = statistics.numTensorInPlace;
    this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
  }

private:
  std::optional<OneShotBufferizationOptions> options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

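// Example: registering the pass in a pass pipeline (a sketch; "pm" is assumed
// to be a module-level PassManager owned by the caller):
//
//   OneShotBufferizationOptions options;
//   options.bufferizeFunctionBoundaries = true;
//   pm.addPass(bufferization::createOneShotBufferizePass(options));
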
//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist), analysisState(options), statistics(statistics) {
    setListener(this);
  }

protected:
  void notifyOperationErased(Operation *op) override {
    erasedOps.insert(op);
    // Erase if present.
    toMemrefOps.erase(op);
  }

  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
    // We only care about newly created ops.
    if (previous.isSet())
      return;

    erasedOps.erase(op);

    // Gather statistics about allocs.
    if (statistics) {
      if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
        statistics->numBufferAlloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
    }

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op))
      return;

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// Bufferization statistics for debugging.
  BufferizationStatistics *statistics;
};
} // namespace

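/// Bufferize the given op (and its nested ops) with the given options: collect
/// all bufferizable ops with tensor semantics into a worklist, rewrite them
/// one by one through their BufferizableOpInterface implementation, and then
/// clean up the remaining to_memref / to_tensor materializations.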
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         BufferizationStatistics *statistics) {
  if (options.copyBeforeWrite) {
    AnalysisState state(options);
    if (failed(insertTensorCopies(op, state)))
      return failure();
  }

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We currently lack canonicalization patterns that would fold
  // such layout maps away (or canonicalize them to more precise layouts).
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (options.isOpAllowed(op) && hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, statistics);
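  // Note: The worklist may grow while it is being processed because the
  // rewriter's listener appends newly created ops with tensor semantics, so
  // iterate by index rather than with iterators.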
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Check for unsupported unstructured control flow.
    if (!bufferizableOp.supportsUnstructuredControlFlow())
      for (Region &r : nextOp->getRegions())
        if (r.getBlocks().size() > 1)
          return nextOp->emitOpError(
              "op or BufferizableOpInterface implementation does not support "
              "unstructured control flow, but at least one region has multiple "
              "blocks");

    // Bufferize the op.
    LLVM_DEBUG(llvm::dbgs()
               << "//===-------------------------------------------===//\n"
               << "IR after bufferizing: " << nextOp->getName() << "\n");
    rewriter.setInsertionPoint(nextOp);
    if (failed(bufferizableOp.bufferize(rewriter, options))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "failed to bufferize\n"
                 << "//===-------------------------------------------===//\n");
      return nextOp->emitError("failed to bufferize op");
    }
    LLVM_DEBUG(llvm::dbgs()
               << *op
               << "\n//===-------------------------------------------===//\n");
  }

  // Return early if the top-level op is entirely gone.
  if (erasedOps.contains(op))
    return success();

  // Fold all to_memref(to_tensor(x)) pairs.
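  // For example (a sketch; types and exact assembly syntax omitted):
  //   %t = bufferization.to_tensor %m
  //   %m2 = bufferization.to_memref %t
  // folds such that uses of %m2 are replaced with %m, going through a cast
  // (or copy) if the two memref types differ.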
  for (Operation *op : toMemrefOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(
        rewriter, cast<ToMemrefOp>(op), options);
  }

  // Remove all dead to_tensor ops.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && isMemoryEffectFree(op))
      continue;
    // ToTensorOps/ToMemrefOps are allowed in the output.
    if (isa<ToTensorOp, ToMemrefOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

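/// Bufferize the signature of the given block: convert all tensor block
/// arguments to memrefs, materialize to_tensor ops for their existing uses,
/// and rewrite all branch-like predecessors so that they forward memref
/// operands (inserting casts where the bufferized types differ).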
LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
  if (!bufferizableOp)
    return failure();

  // Compute the new signature.
  SmallVector<Type> newTypes;
  for (BlockArgument &bbArg : block->getArguments()) {
    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
    if (!tensorType) {
      newTypes.push_back(bbArg.getType());
      continue;
    }

    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(bbArg, options);
    if (failed(memrefType))
      return failure();
    newTypes.push_back(*memrefType);
  }

  // Change the type of all block arguments.
  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
    if (bbArg.getType() == type)
      continue;

    // Collect all uses of the bbArg.
    SmallVector<OpOperand *> bbArgUses;
    for (OpOperand &use : bbArg.getUses())
      bbArgUses.push_back(&use);

    Type tensorType = bbArg.getType();
    // Change the bbArg type to memref.
    bbArg.setType(type);

    // Replace all uses of the original tensor bbArg.
    rewriter.setInsertionPointToStart(block);
    if (!bbArgUses.empty()) {
      Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
          bbArg.getLoc(), tensorType, bbArg);
      for (OpOperand *use : bbArgUses)
        use->set(toTensorOp);
    }
  }

  // Bufferize callers of the block.
  for (Operation *op : block->getUsers()) {
    auto branchOp = dyn_cast<BranchOpInterface>(op);
    if (!branchOp)
      return op->emitOpError("cannot bufferize ops with block references that "
                             "do not implement BranchOpInterface");

    auto it = llvm::find(op->getSuccessors(), block);
    assert(it != op->getSuccessors().end() && "could not find successor");
    int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);

    SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
    SmallVector<Value> newOperands;
    for (auto [operand, type] :
         llvm::zip(operands.getForwardedOperands(), newTypes)) {
      if (operand.getType() == type) {
        // Not a tensor type. Nothing to do for this operand.
        newOperands.push_back(operand);
        continue;
      }
      FailureOr<BaseMemRefType> operandBufferType =
          bufferization::getBufferType(operand, options);
      if (failed(operandBufferType))
        return failure();
      rewriter.setInsertionPointAfterValue(operand);
      Value bufferizedOperand = rewriter.create<bufferization::ToMemrefOp>(
          operand.getLoc(), *operandBufferType, operand);
      // A cast is needed if the operand and the block argument have different
      // bufferized types.
      if (type != *operandBufferType)
        bufferizedOperand = rewriter.create<memref::CastOp>(
            operand.getLoc(), type, bufferizedOperand);
      newOperands.push_back(bufferizedOperand);
    }
    operands.getMutableForwardedOperands().assign(newOperands);
  }

  return success();
}