1dc3258c6SBoian Petkantchin //===- TestProcessMultiIndexOpLowering.cpp --------------------------------===// 2dc3258c6SBoian Petkantchin // 3dc3258c6SBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4dc3258c6SBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information. 5dc3258c6SBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6dc3258c6SBoian Petkantchin // 7dc3258c6SBoian Petkantchin //===----------------------------------------------------------------------===// 8dc3258c6SBoian Petkantchin 9dc3258c6SBoian Petkantchin #include "mlir/Dialect/Arith/IR/Arith.h" 10dc3258c6SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Transforms.h" 11dc3258c6SBoian Petkantchin #include "mlir/Dialect/Utils/IndexingUtils.h" 12dc3258c6SBoian Petkantchin #include "mlir/IR/SymbolTable.h" 13dc3258c6SBoian Petkantchin #include "mlir/Pass/Pass.h" 14dc3258c6SBoian Petkantchin #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15dc3258c6SBoian Petkantchin 16dc3258c6SBoian Petkantchin using namespace mlir; 17dc3258c6SBoian Petkantchin 18dc3258c6SBoian Petkantchin namespace { 19dc3258c6SBoian Petkantchin 20dc3258c6SBoian Petkantchin struct TestAllSliceOpLoweringPass 21dc3258c6SBoian Petkantchin : public PassWrapper<TestAllSliceOpLoweringPass, OperationPass<>> { 22dc3258c6SBoian Petkantchin MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass) 23dc3258c6SBoian Petkantchin 24dc3258c6SBoian Petkantchin void runOnOperation() override { 25dc3258c6SBoian Petkantchin RewritePatternSet patterns(&getContext()); 26dc3258c6SBoian Petkantchin SymbolTableCollection symbolTableCollection; 27dc3258c6SBoian Petkantchin mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); 28dc3258c6SBoian Petkantchin LogicalResult status = 29*09dfc571SJacques Pienaar applyPatternsGreedily(getOperation(), std::move(patterns)); 30dc3258c6SBoian Petkantchin (void)status; 31*09dfc571SJacques Pienaar assert(succeeded(status) && "applyPatternsGreedily failed."); 32dc3258c6SBoian Petkantchin } 33dc3258c6SBoian Petkantchin void getDependentDialects(DialectRegistry ®istry) const override { 34dc3258c6SBoian Petkantchin mesh::registerAllSliceOpLoweringDialects(registry); 35dc3258c6SBoian Petkantchin } 36dc3258c6SBoian Petkantchin StringRef getArgument() const final { 37dc3258c6SBoian Petkantchin return "test-mesh-all-slice-op-lowering"; 38dc3258c6SBoian Petkantchin } 39dc3258c6SBoian Petkantchin StringRef getDescription() const final { 40dc3258c6SBoian Petkantchin return "Test lowering of all-slice."; 41dc3258c6SBoian Petkantchin } 42dc3258c6SBoian Petkantchin }; 43dc3258c6SBoian Petkantchin 44dc3258c6SBoian Petkantchin struct TestMultiIndexOpLoweringPass 45dc3258c6SBoian Petkantchin : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> { 46dc3258c6SBoian Petkantchin MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass) 47dc3258c6SBoian Petkantchin 48dc3258c6SBoian Petkantchin void runOnOperation() override { 49dc3258c6SBoian Petkantchin RewritePatternSet patterns(&getContext()); 50dc3258c6SBoian Petkantchin SymbolTableCollection symbolTableCollection; 51dc3258c6SBoian Petkantchin mesh::populateProcessMultiIndexOpLoweringPatterns(patterns, 52dc3258c6SBoian Petkantchin symbolTableCollection); 53dc3258c6SBoian Petkantchin LogicalResult status = 54*09dfc571SJacques Pienaar applyPatternsGreedily(getOperation(), std::move(patterns)); 55dc3258c6SBoian Petkantchin (void)status; 56*09dfc571SJacques Pienaar assert(succeeded(status) && "applyPatternsGreedily failed."); 57dc3258c6SBoian Petkantchin } 58dc3258c6SBoian Petkantchin void getDependentDialects(DialectRegistry ®istry) const override { 59dc3258c6SBoian Petkantchin mesh::registerProcessMultiIndexOpLoweringDialects(registry); 60dc3258c6SBoian Petkantchin } 61dc3258c6SBoian Petkantchin StringRef getArgument() const final { 62dc3258c6SBoian Petkantchin return "test-mesh-process-multi-index-op-lowering"; 63dc3258c6SBoian Petkantchin } 64dc3258c6SBoian Petkantchin StringRef getDescription() const final { 65dc3258c6SBoian Petkantchin return "Test lowering of mesh.process_multi_index op."; 66dc3258c6SBoian Petkantchin } 67dc3258c6SBoian Petkantchin }; 68dc3258c6SBoian Petkantchin 69dc3258c6SBoian Petkantchin } // namespace 70dc3258c6SBoian Petkantchin 71dc3258c6SBoian Petkantchin namespace mlir { 72dc3258c6SBoian Petkantchin namespace test { 73dc3258c6SBoian Petkantchin void registerTestOpLoweringPasses() { 74dc3258c6SBoian Petkantchin PassRegistration<TestAllSliceOpLoweringPass>(); 75dc3258c6SBoian Petkantchin PassRegistration<TestMultiIndexOpLoweringPass>(); 76dc3258c6SBoian Petkantchin } 77dc3258c6SBoian Petkantchin } // namespace test 78dc3258c6SBoian Petkantchin } // namespace mlir 79