xref: /llvm-project/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp (revision 067d2779fcfc62dd429177f350b8cefe49b65b51)
//===- TestTensorCopyInsertion.cpp - Bufferization Analysis -----*- c++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8c1fef4e8SMatthias Springer 
9c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
10c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
11c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
12c1fef4e8SMatthias Springer #include "mlir/Pass/Pass.h"
13c1fef4e8SMatthias Springer 
14c1fef4e8SMatthias Springer using namespace mlir;
15c1fef4e8SMatthias Springer 
16c1fef4e8SMatthias Springer namespace {
17c1fef4e8SMatthias Springer /// This pass runs One-Shot Analysis and inserts copies for all OpOperands that
18c1fef4e8SMatthias Springer /// were decided to bufferize out-of-place. After running this pass, a
19c1fef4e8SMatthias Springer /// bufferization can write to buffers directly (without making copies) and no
20c1fef4e8SMatthias Springer /// longer has to care about potential read-after-write conflicts.
21c1fef4e8SMatthias Springer ///
22c1fef4e8SMatthias Springer /// Note: By default, all newly inserted tensor copies/allocs (i.e., newly
23c1fef4e8SMatthias Springer /// created `bufferization.alloc_tensor` ops) that do not escape block are
24c1fef4e8SMatthias Springer /// annotated with `escape = false`. If `create-allocs` is unset, all newly
25c1fef4e8SMatthias Springer /// inserted tensor copies/allocs are annotated with `escape = true`. In that
26c1fef4e8SMatthias Springer /// case, they are not getting deallocated when bufferizing the IR.
27c1fef4e8SMatthias Springer struct TestTensorCopyInsertionPass
28c1fef4e8SMatthias Springer     : public PassWrapper<TestTensorCopyInsertionPass, OperationPass<ModuleOp>> {
29c1fef4e8SMatthias Springer   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorCopyInsertionPass)
30c1fef4e8SMatthias Springer 
31c1fef4e8SMatthias Springer   TestTensorCopyInsertionPass() = default;
TestTensorCopyInsertionPass__anonfd3215680111::TestTensorCopyInsertionPass32c1fef4e8SMatthias Springer   TestTensorCopyInsertionPass(const TestTensorCopyInsertionPass &pass)
33c1fef4e8SMatthias Springer       : PassWrapper(pass) {}
34c1fef4e8SMatthias Springer 
getDependentDialects__anonfd3215680111::TestTensorCopyInsertionPass35c1fef4e8SMatthias Springer   void getDependentDialects(DialectRegistry &registry) const override {
36c1fef4e8SMatthias Springer     registry.insert<bufferization::BufferizationDialect>();
37c1fef4e8SMatthias Springer   }
getArgument__anonfd3215680111::TestTensorCopyInsertionPass38c1fef4e8SMatthias Springer   StringRef getArgument() const final { return "test-tensor-copy-insertion"; }
getDescription__anonfd3215680111::TestTensorCopyInsertionPass39c1fef4e8SMatthias Springer   StringRef getDescription() const final {
40c1fef4e8SMatthias Springer     return "Module pass to test Tensor Copy Insertion";
41c1fef4e8SMatthias Springer   }
42c1fef4e8SMatthias Springer 
runOnOperation__anonfd3215680111::TestTensorCopyInsertionPass43c1fef4e8SMatthias Springer   void runOnOperation() override {
44c1fef4e8SMatthias Springer     bufferization::OneShotBufferizationOptions options;
456bf043e7SMartin Erhart     options.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
46c1fef4e8SMatthias Springer     options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
47*067d2779Sian Bearman     if (mustInferMemorySpace) {
48*067d2779Sian Bearman       options.defaultMemorySpaceFn =
49*067d2779Sian Bearman           [](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
50*067d2779Sian Bearman     }
51c1fef4e8SMatthias Springer     if (failed(bufferization::insertTensorCopies(getOperation(), options)))
52c1fef4e8SMatthias Springer       signalPassFailure();
53c1fef4e8SMatthias Springer   }
54c1fef4e8SMatthias Springer 
556bf043e7SMartin Erhart   Option<bool> allowReturnAllocsFromLoops{
566bf043e7SMartin Erhart       *this, "allow-return-allocs-from-loops",
576bf043e7SMartin Erhart       llvm::cl::desc("Allows returning/yielding new allocations from a loop."),
58c1fef4e8SMatthias Springer       llvm::cl::init(false)};
59c1fef4e8SMatthias Springer   Option<bool> bufferizeFunctionBoundaries{
60c1fef4e8SMatthias Springer       *this, "bufferize-function-boundaries",
61c1fef4e8SMatthias Springer       llvm::cl::desc("Bufferize function boundaries."), llvm::cl::init(false)};
62c1fef4e8SMatthias Springer   Option<bool> mustInferMemorySpace{
63c1fef4e8SMatthias Springer       *this, "must-infer-memory-space",
64c1fef4e8SMatthias Springer       llvm::cl::desc(
65c1fef4e8SMatthias Springer           "The memory space of an memref types must always be inferred. If "
66c1fef4e8SMatthias Springer           "unset, a default memory space of 0 is used otherwise."),
67c1fef4e8SMatthias Springer       llvm::cl::init(false)};
68c1fef4e8SMatthias Springer };
69c1fef4e8SMatthias Springer } // namespace
70c1fef4e8SMatthias Springer 
71c1fef4e8SMatthias Springer namespace mlir::test {
registerTestTensorCopyInsertionPass()72c1fef4e8SMatthias Springer void registerTestTensorCopyInsertionPass() {
73c1fef4e8SMatthias Springer   PassRegistration<TestTensorCopyInsertionPass>();
74c1fef4e8SMatthias Springer }
75c1fef4e8SMatthias Springer } // namespace mlir::test
76