xref: /llvm-project/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
10f297cadSHanhan Wang //===- TestDataLayoutPropagation.cpp --------------------------------------===//
20f297cadSHanhan Wang //
30f297cadSHanhan Wang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40f297cadSHanhan Wang // See https://llvm.org/LICENSE.txt for license information.
50f297cadSHanhan Wang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60f297cadSHanhan Wang 
70f297cadSHanhan Wang #include "mlir/Dialect/Affine/IR/AffineOps.h"
80f297cadSHanhan Wang #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
90f297cadSHanhan Wang #include "mlir/Pass/Pass.h"
100f297cadSHanhan Wang #include "mlir/Pass/PassManager.h"
110f297cadSHanhan Wang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
120f297cadSHanhan Wang 
130f297cadSHanhan Wang using namespace mlir;
140f297cadSHanhan Wang 
150f297cadSHanhan Wang namespace {
160f297cadSHanhan Wang struct TestDataLayoutPropagationPass
170f297cadSHanhan Wang     : public PassWrapper<TestDataLayoutPropagationPass, OperationPass<>> {
180f297cadSHanhan Wang   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass)
190f297cadSHanhan Wang 
200f297cadSHanhan Wang   void getDependentDialects(DialectRegistry &registry) const override {
214c48f016SMatthias Springer     registry.insert<affine::AffineDialect, linalg::LinalgDialect,
224c48f016SMatthias Springer                     tensor::TensorDialect>();
230f297cadSHanhan Wang   }
240f297cadSHanhan Wang 
250f297cadSHanhan Wang   StringRef getArgument() const final {
260f297cadSHanhan Wang     return "test-linalg-data-layout-propagation";
270f297cadSHanhan Wang   }
280f297cadSHanhan Wang   StringRef getDescription() const final {
290f297cadSHanhan Wang     return "Test data layout propagation";
300f297cadSHanhan Wang   }
310f297cadSHanhan Wang 
320f297cadSHanhan Wang   void runOnOperation() override {
330f297cadSHanhan Wang     MLIRContext *context = &getContext();
340f297cadSHanhan Wang     RewritePatternSet patterns(context);
35b4563ee1SQuinn Dawkins     linalg::populateDataLayoutPropagationPatterns(
3604fc471fSHan-Chung Wang         patterns, [](OpOperand *opOperand) { return true; });
37*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
380f297cadSHanhan Wang       return signalPassFailure();
390f297cadSHanhan Wang   }
400f297cadSHanhan Wang };
410f297cadSHanhan Wang } // namespace
420f297cadSHanhan Wang 
430f297cadSHanhan Wang namespace mlir {
440f297cadSHanhan Wang namespace test {
450f297cadSHanhan Wang void registerTestDataLayoutPropagation() {
460f297cadSHanhan Wang   PassRegistration<TestDataLayoutPropagationPass>();
470f297cadSHanhan Wang }
480f297cadSHanhan Wang } // namespace test
490f297cadSHanhan Wang } // namespace mlir
50