xref: /llvm-project/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp (revision 4dbaef6d5ea71fb183114a82da4028960906c42b)
1 //===- TestLinalgDropUnitDims.cpp - Test Linalg drop unit dims -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass for testing the transformation to drop unit
10 // extent dimensions from `linalg.generic` operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 
24 LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
25                                     linalg::GenericOp genericOp) {
26   linalg::ControlDropUnitDims options;
27   options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
28   FailureOr<linalg::DropUnitDimsResult> result =
29       linalg::dropUnitDims(rewriter, genericOp, options);
30   if (failed(result)) {
31     return failure();
32   }
33   rewriter.replaceOp(genericOp, result->replacements);
34   return success();
35 }
36 
37 struct TestLinalgDropUnitDims
38     : public PassWrapper<TestLinalgDropUnitDims, OperationPass<func::FuncOp>> {
39 
40   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims)
41 
42   TestLinalgDropUnitDims() = default;
43   TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass) = default;
44 
45   void getDependentDialects(DialectRegistry &registry) const override {
46     registry.insert<linalg::LinalgDialect>();
47   }
48 
49   StringRef getArgument() const final { return "test-linalg-drop-unit-dims"; }
50 
51   StringRef getDescriptions() const {
52     return "Test transformation to drop unit-extent dims from Linalg "
53            "operations";
54   }
55 
56   void runOnOperation() override {
57     MLIRContext *context = &this->getContext();
58     func::FuncOp funcOp = this->getOperation();
59     IRRewriter rewriter(context);
60     SmallVector<linalg::GenericOp> genericOps;
61     funcOp.walk(
62         [&](linalg::GenericOp genericOp) { genericOps.push_back(genericOp); });
63 
64     for (auto genericOp : genericOps) {
65       rewriter.setInsertionPoint(genericOp);
66       (void)dropOutermostUnitDims(rewriter, genericOp);
67     }
68   }
69 };
70 } // namespace
71 
72 namespace mlir {
73 namespace test {
74 void registerTestLinalgDropUnitDims() {
75   PassRegistration<TestLinalgDropUnitDims>();
76 }
77 } // namespace test
78 } // namespace mlir
79