xref: /llvm-project/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestOpLowering.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
11 #include "mlir/Dialect/Utils/IndexingUtils.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 
20 struct TestAllSliceOpLoweringPass
21     : public PassWrapper<TestAllSliceOpLoweringPass, OperationPass<>> {
22   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass)
23 
24   void runOnOperation() override {
25     RewritePatternSet patterns(&getContext());
26     SymbolTableCollection symbolTableCollection;
27     mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
28     LogicalResult status =
29         applyPatternsGreedily(getOperation(), std::move(patterns));
30     (void)status;
31     assert(succeeded(status) && "applyPatternsGreedily failed.");
32   }
33   void getDependentDialects(DialectRegistry &registry) const override {
34     mesh::registerAllSliceOpLoweringDialects(registry);
35   }
36   StringRef getArgument() const final {
37     return "test-mesh-all-slice-op-lowering";
38   }
39   StringRef getDescription() const final {
40     return "Test lowering of all-slice.";
41   }
42 };
43 
44 struct TestMultiIndexOpLoweringPass
45     : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
46   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
47 
48   void runOnOperation() override {
49     RewritePatternSet patterns(&getContext());
50     SymbolTableCollection symbolTableCollection;
51     mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
52                                                       symbolTableCollection);
53     LogicalResult status =
54         applyPatternsGreedily(getOperation(), std::move(patterns));
55     (void)status;
56     assert(succeeded(status) && "applyPatternsGreedily failed.");
57   }
58   void getDependentDialects(DialectRegistry &registry) const override {
59     mesh::registerProcessMultiIndexOpLoweringDialects(registry);
60   }
61   StringRef getArgument() const final {
62     return "test-mesh-process-multi-index-op-lowering";
63   }
64   StringRef getDescription() const final {
65     return "Test lowering of mesh.process_multi_index op.";
66   }
67 };
68 
69 } // namespace
70 
71 namespace mlir {
72 namespace test {
73 void registerTestOpLoweringPasses() {
74   PassRegistration<TestAllSliceOpLoweringPass>();
75   PassRegistration<TestMultiIndexOpLoweringPass>();
76 }
77 } // namespace test
78 } // namespace mlir
79