xref: /llvm-project/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp (revision 4c48f016effde67d500fc95290096aec9f3bdb70)
//===- TestMultiBuffer.cpp - Test multi buffering -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8b1357fe6SThomas Raoux 
9b1357fe6SThomas Raoux #include "mlir/Dialect/Affine/IR/AffineOps.h"
10b1357fe6SThomas Raoux #include "mlir/Dialect/MemRef/Transforms/Passes.h"
11faafd26cSQuentin Colombet #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
12b1357fe6SThomas Raoux 
13b1357fe6SThomas Raoux #include "mlir/Dialect/MemRef/IR/MemRef.h"
14b1357fe6SThomas Raoux #include "mlir/Pass/Pass.h"
15b1357fe6SThomas Raoux 
16b1357fe6SThomas Raoux using namespace mlir;
17b1357fe6SThomas Raoux 
18b1357fe6SThomas Raoux namespace {
19b1357fe6SThomas Raoux struct TestMultiBufferingPass
2087d6bf37SRiver Riddle     : public PassWrapper<TestMultiBufferingPass, OperationPass<>> {
215e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiBufferingPass)
225e50dd04SRiver Riddle 
23b1357fe6SThomas Raoux   TestMultiBufferingPass() = default;
TestMultiBufferingPass__anon5303f5600111::TestMultiBufferingPass24b1357fe6SThomas Raoux   TestMultiBufferingPass(const TestMultiBufferingPass &pass)
25b1357fe6SThomas Raoux       : PassWrapper(pass) {}
getDependentDialects__anon5303f5600111::TestMultiBufferingPass26b1357fe6SThomas Raoux   void getDependentDialects(DialectRegistry &registry) const override {
27*4c48f016SMatthias Springer     registry.insert<affine::AffineDialect>();
28b1357fe6SThomas Raoux   }
getArgument__anon5303f5600111::TestMultiBufferingPass29b1357fe6SThomas Raoux   StringRef getArgument() const final { return "test-multi-buffering"; }
getDescription__anon5303f5600111::TestMultiBufferingPass30b1357fe6SThomas Raoux   StringRef getDescription() const final {
31b1357fe6SThomas Raoux     return "Test multi buffering transformation";
32b1357fe6SThomas Raoux   }
33b1357fe6SThomas Raoux   void runOnOperation() override;
34b1357fe6SThomas Raoux   Option<unsigned> multiplier{
35b1357fe6SThomas Raoux       *this, "multiplier",
36b1357fe6SThomas Raoux       llvm::cl::desc(
37b1357fe6SThomas Raoux           "Decide how many versions of the buffer should be created,"),
38b1357fe6SThomas Raoux       llvm::cl::init(2)};
39b1357fe6SThomas Raoux };
40b1357fe6SThomas Raoux 
runOnOperation()41b1357fe6SThomas Raoux void TestMultiBufferingPass::runOnOperation() {
42b1357fe6SThomas Raoux   SmallVector<memref::AllocOp> allocs;
4387d6bf37SRiver Riddle   getOperation()->walk(
44b1357fe6SThomas Raoux       [&allocs](memref::AllocOp alloc) { allocs.push_back(alloc); });
45b1357fe6SThomas Raoux   for (memref::AllocOp alloc : allocs)
46b1357fe6SThomas Raoux     (void)multiBuffer(alloc, multiplier);
47b1357fe6SThomas Raoux }
48b1357fe6SThomas Raoux } // namespace
49b1357fe6SThomas Raoux 
50b1357fe6SThomas Raoux namespace mlir {
51b1357fe6SThomas Raoux namespace test {
registerTestMultiBuffering()52b1357fe6SThomas Raoux void registerTestMultiBuffering() {
53b1357fe6SThomas Raoux   PassRegistration<TestMultiBufferingPass>();
54b1357fe6SThomas Raoux }
55b1357fe6SThomas Raoux } // namespace test
56b1357fe6SThomas Raoux } // namespace mlir
57