xref: /llvm-project/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- TestOpLowering.cpp -------------------------------------------------===//
2dc3258c6SBoian Petkantchin //
3dc3258c6SBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4dc3258c6SBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information.
5dc3258c6SBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6dc3258c6SBoian Petkantchin //
7dc3258c6SBoian Petkantchin //===----------------------------------------------------------------------===//
8dc3258c6SBoian Petkantchin 
9dc3258c6SBoian Petkantchin #include "mlir/Dialect/Arith/IR/Arith.h"
10dc3258c6SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
11dc3258c6SBoian Petkantchin #include "mlir/Dialect/Utils/IndexingUtils.h"
12dc3258c6SBoian Petkantchin #include "mlir/IR/SymbolTable.h"
13dc3258c6SBoian Petkantchin #include "mlir/Pass/Pass.h"
14dc3258c6SBoian Petkantchin #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15dc3258c6SBoian Petkantchin 
16dc3258c6SBoian Petkantchin using namespace mlir;
17dc3258c6SBoian Petkantchin 
18dc3258c6SBoian Petkantchin namespace {
19dc3258c6SBoian Petkantchin 
20dc3258c6SBoian Petkantchin struct TestAllSliceOpLoweringPass
21dc3258c6SBoian Petkantchin     : public PassWrapper<TestAllSliceOpLoweringPass, OperationPass<>> {
22dc3258c6SBoian Petkantchin   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass)
23dc3258c6SBoian Petkantchin 
24dc3258c6SBoian Petkantchin   void runOnOperation() override {
25dc3258c6SBoian Petkantchin     RewritePatternSet patterns(&getContext());
26dc3258c6SBoian Petkantchin     SymbolTableCollection symbolTableCollection;
27dc3258c6SBoian Petkantchin     mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
28dc3258c6SBoian Petkantchin     LogicalResult status =
29*09dfc571SJacques Pienaar         applyPatternsGreedily(getOperation(), std::move(patterns));
30dc3258c6SBoian Petkantchin     (void)status;
31*09dfc571SJacques Pienaar     assert(succeeded(status) && "applyPatternsGreedily failed.");
32dc3258c6SBoian Petkantchin   }
33dc3258c6SBoian Petkantchin   void getDependentDialects(DialectRegistry &registry) const override {
34dc3258c6SBoian Petkantchin     mesh::registerAllSliceOpLoweringDialects(registry);
35dc3258c6SBoian Petkantchin   }
36dc3258c6SBoian Petkantchin   StringRef getArgument() const final {
37dc3258c6SBoian Petkantchin     return "test-mesh-all-slice-op-lowering";
38dc3258c6SBoian Petkantchin   }
39dc3258c6SBoian Petkantchin   StringRef getDescription() const final {
40dc3258c6SBoian Petkantchin     return "Test lowering of all-slice.";
41dc3258c6SBoian Petkantchin   }
42dc3258c6SBoian Petkantchin };
43dc3258c6SBoian Petkantchin 
44dc3258c6SBoian Petkantchin struct TestMultiIndexOpLoweringPass
45dc3258c6SBoian Petkantchin     : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
46dc3258c6SBoian Petkantchin   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
47dc3258c6SBoian Petkantchin 
48dc3258c6SBoian Petkantchin   void runOnOperation() override {
49dc3258c6SBoian Petkantchin     RewritePatternSet patterns(&getContext());
50dc3258c6SBoian Petkantchin     SymbolTableCollection symbolTableCollection;
51dc3258c6SBoian Petkantchin     mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
52dc3258c6SBoian Petkantchin                                                       symbolTableCollection);
53dc3258c6SBoian Petkantchin     LogicalResult status =
54*09dfc571SJacques Pienaar         applyPatternsGreedily(getOperation(), std::move(patterns));
55dc3258c6SBoian Petkantchin     (void)status;
56*09dfc571SJacques Pienaar     assert(succeeded(status) && "applyPatternsGreedily failed.");
57dc3258c6SBoian Petkantchin   }
58dc3258c6SBoian Petkantchin   void getDependentDialects(DialectRegistry &registry) const override {
59dc3258c6SBoian Petkantchin     mesh::registerProcessMultiIndexOpLoweringDialects(registry);
60dc3258c6SBoian Petkantchin   }
61dc3258c6SBoian Petkantchin   StringRef getArgument() const final {
62dc3258c6SBoian Petkantchin     return "test-mesh-process-multi-index-op-lowering";
63dc3258c6SBoian Petkantchin   }
64dc3258c6SBoian Petkantchin   StringRef getDescription() const final {
65dc3258c6SBoian Petkantchin     return "Test lowering of mesh.process_multi_index op.";
66dc3258c6SBoian Petkantchin   }
67dc3258c6SBoian Petkantchin };
68dc3258c6SBoian Petkantchin 
69dc3258c6SBoian Petkantchin } // namespace
70dc3258c6SBoian Petkantchin 
71dc3258c6SBoian Petkantchin namespace mlir {
72dc3258c6SBoian Petkantchin namespace test {
73dc3258c6SBoian Petkantchin void registerTestOpLoweringPasses() {
74dc3258c6SBoian Petkantchin   PassRegistration<TestAllSliceOpLoweringPass>();
75dc3258c6SBoian Petkantchin   PassRegistration<TestMultiIndexOpLoweringPass>();
76dc3258c6SBoian Petkantchin }
77dc3258c6SBoian Petkantchin } // namespace test
78dc3258c6SBoian Petkantchin } // namespace mlir
79