1 //===- TestDataLayoutPropagation.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 7 #include "mlir/Dialect/Affine/IR/AffineOps.h" 8 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 9 #include "mlir/Pass/Pass.h" 10 #include "mlir/Pass/PassManager.h" 11 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 12 13 using namespace mlir; 14 15 namespace { 16 struct TestDataLayoutPropagationPass 17 : public PassWrapper<TestDataLayoutPropagationPass, OperationPass<>> { 18 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass) 19 20 void getDependentDialects(DialectRegistry ®istry) const override { 21 registry.insert<affine::AffineDialect, linalg::LinalgDialect, 22 tensor::TensorDialect>(); 23 } 24 25 StringRef getArgument() const final { 26 return "test-linalg-data-layout-propagation"; 27 } 28 StringRef getDescription() const final { 29 return "Test data layout propagation"; 30 } 31 32 void runOnOperation() override { 33 MLIRContext *context = &getContext(); 34 RewritePatternSet patterns(context); 35 linalg::populateDataLayoutPropagationPatterns( 36 patterns, [](OpOperand *opOperand) { return true; }); 37 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 38 return signalPassFailure(); 39 } 40 }; 41 } // namespace 42 43 namespace mlir { 44 namespace test { 45 void registerTestDataLayoutPropagation() { 46 PassRegistration<TestDataLayoutPropagationPass>(); 47 } 48 } // namespace test 49 } // namespace mlir 50