10f297cadSHanhan Wang //===- TestDataLayoutPropagation.cpp --------------------------------------===// 20f297cadSHanhan Wang // 30f297cadSHanhan Wang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 40f297cadSHanhan Wang // See https://llvm.org/LICENSE.txt for license information. 50f297cadSHanhan Wang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60f297cadSHanhan Wang 70f297cadSHanhan Wang #include "mlir/Dialect/Affine/IR/AffineOps.h" 80f297cadSHanhan Wang #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 90f297cadSHanhan Wang #include "mlir/Pass/Pass.h" 100f297cadSHanhan Wang #include "mlir/Pass/PassManager.h" 110f297cadSHanhan Wang #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 120f297cadSHanhan Wang 130f297cadSHanhan Wang using namespace mlir; 140f297cadSHanhan Wang 150f297cadSHanhan Wang namespace { 160f297cadSHanhan Wang struct TestDataLayoutPropagationPass 170f297cadSHanhan Wang : public PassWrapper<TestDataLayoutPropagationPass, OperationPass<>> { 180f297cadSHanhan Wang MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass) 190f297cadSHanhan Wang 200f297cadSHanhan Wang void getDependentDialects(DialectRegistry ®istry) const override { 214c48f016SMatthias Springer registry.insert<affine::AffineDialect, linalg::LinalgDialect, 224c48f016SMatthias Springer tensor::TensorDialect>(); 230f297cadSHanhan Wang } 240f297cadSHanhan Wang 250f297cadSHanhan Wang StringRef getArgument() const final { 260f297cadSHanhan Wang return "test-linalg-data-layout-propagation"; 270f297cadSHanhan Wang } 280f297cadSHanhan Wang StringRef getDescription() const final { 290f297cadSHanhan Wang return "Test data layout propagation"; 300f297cadSHanhan Wang } 310f297cadSHanhan Wang 320f297cadSHanhan Wang void runOnOperation() override { 330f297cadSHanhan Wang MLIRContext *context = &getContext(); 340f297cadSHanhan Wang RewritePatternSet patterns(context); 35b4563ee1SQuinn Dawkins linalg::populateDataLayoutPropagationPatterns( 3604fc471fSHan-Chung Wang patterns, [](OpOperand *opOperand) { return true; }); 37*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 380f297cadSHanhan Wang return signalPassFailure(); 390f297cadSHanhan Wang } 400f297cadSHanhan Wang }; 410f297cadSHanhan Wang } // namespace 420f297cadSHanhan Wang 430f297cadSHanhan Wang namespace mlir { 440f297cadSHanhan Wang namespace test { 450f297cadSHanhan Wang void registerTestDataLayoutPropagation() { 460f297cadSHanhan Wang PassRegistration<TestDataLayoutPropagationPass>(); 470f297cadSHanhan Wang } 480f297cadSHanhan Wang } // namespace test 490f297cadSHanhan Wang } // namespace mlir 50