1 //===- TestProcessMultiIndexOpLowering.cpp --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/Arith.h" 10 #include "mlir/Dialect/Mesh/Transforms/Transforms.h" 11 #include "mlir/Dialect/Utils/IndexingUtils.h" 12 #include "mlir/IR/SymbolTable.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15 16 using namespace mlir; 17 18 namespace { 19 20 struct TestAllSliceOpLoweringPass 21 : public PassWrapper<TestAllSliceOpLoweringPass, OperationPass<>> { 22 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass) 23 24 void runOnOperation() override { 25 RewritePatternSet patterns(&getContext()); 26 SymbolTableCollection symbolTableCollection; 27 mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); 28 LogicalResult status = 29 applyPatternsGreedily(getOperation(), std::move(patterns)); 30 (void)status; 31 assert(succeeded(status) && "applyPatternsGreedily failed."); 32 } 33 void getDependentDialects(DialectRegistry ®istry) const override { 34 mesh::registerAllSliceOpLoweringDialects(registry); 35 } 36 StringRef getArgument() const final { 37 return "test-mesh-all-slice-op-lowering"; 38 } 39 StringRef getDescription() const final { 40 return "Test lowering of all-slice."; 41 } 42 }; 43 44 struct TestMultiIndexOpLoweringPass 45 : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> { 46 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass) 47 48 void runOnOperation() override { 49 RewritePatternSet patterns(&getContext()); 50 SymbolTableCollection symbolTableCollection; 51 mesh::populateProcessMultiIndexOpLoweringPatterns(patterns, 52 symbolTableCollection); 53 LogicalResult status = 54 applyPatternsGreedily(getOperation(), std::move(patterns)); 55 (void)status; 56 assert(succeeded(status) && "applyPatternsGreedily failed."); 57 } 58 void getDependentDialects(DialectRegistry ®istry) const override { 59 mesh::registerProcessMultiIndexOpLoweringDialects(registry); 60 } 61 StringRef getArgument() const final { 62 return "test-mesh-process-multi-index-op-lowering"; 63 } 64 StringRef getDescription() const final { 65 return "Test lowering of mesh.process_multi_index op."; 66 } 67 }; 68 69 } // namespace 70 71 namespace mlir { 72 namespace test { 73 void registerTestOpLoweringPasses() { 74 PassRegistration<TestAllSliceOpLoweringPass>(); 75 PassRegistration<TestMultiIndexOpLoweringPass>(); 76 } 77 } // namespace test 78 } // namespace mlir 79