xref: /llvm-project/mlir/unittests/Transforms/Canonicalizer.cpp (revision 5e50dd048e3a20cde5da5d7a754dfee775ef35d6)
//===- Canonicalizer.cpp - Canonicalizer unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Parser/Parser.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13 #include "mlir/Transforms/Passes.h"
14 #include "gtest/gtest.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 
20 struct DisabledPattern : public RewritePattern {
DisabledPattern__anon224433b80111::DisabledPattern21   DisabledPattern(MLIRContext *context)
22       : RewritePattern("test.foo", /*benefit=*/0, context,
23                        /*generatedNamed=*/{}) {
24     setDebugName("DisabledPattern");
25   }
26 
matchAndRewrite__anon224433b80111::DisabledPattern27   LogicalResult matchAndRewrite(Operation *op,
28                                 PatternRewriter &rewriter) const override {
29     if (op->getNumResults() != 1)
30       return failure();
31     rewriter.eraseOp(op);
32     return success();
33   }
34 };
35 
36 struct EnabledPattern : public RewritePattern {
EnabledPattern__anon224433b80111::EnabledPattern37   EnabledPattern(MLIRContext *context)
38       : RewritePattern("test.foo", /*benefit=*/0, context,
39                        /*generatedNamed=*/{}) {
40     setDebugName("EnabledPattern");
41   }
42 
matchAndRewrite__anon224433b80111::EnabledPattern43   LogicalResult matchAndRewrite(Operation *op,
44                                 PatternRewriter &rewriter) const override {
45     if (op->getNumResults() == 1)
46       return failure();
47     rewriter.eraseOp(op);
48     return success();
49   }
50 };
51 
52 struct TestDialect : public Dialect {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon224433b80111::TestDialect53   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
54 
55   static StringRef getDialectNamespace() { return "test"; }
56 
TestDialect__anon224433b80111::TestDialect57   TestDialect(MLIRContext *context)
58       : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {
59     allowUnknownOperations();
60   }
61 
getCanonicalizationPatterns__anon224433b80111::TestDialect62   void getCanonicalizationPatterns(RewritePatternSet &results) const override {
63     results.add<DisabledPattern, EnabledPattern>(results.getContext());
64   }
65 };
66 
/// Checks that a pattern named in the canonicalizer's disabled-pattern list is
/// not applied, while other patterns from the same dialect still run.
TEST(CanonicalizerTest, TestDisablePatterns) {
  MLIRContext context;
  context.getOrLoadDialect<TestDialect>();
  PassManager mgr(&context);
  mgr.addPass(
      createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"}));

  // "A" has two results, so only EnabledPattern matches it and it is erased.
  // "B" has one result, so only the disabled DisabledPattern matches it and it
  // must survive.
  const char *const code = R"mlir(
    %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32)
    %1 = "test.foo"() {sym_name = "B"} : () -> (f32)
  )mlir";

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
  // Fail the test cleanly rather than dereferencing a null module if the IR
  // does not parse.
  ASSERT_TRUE(module);
  ASSERT_TRUE(succeeded(mgr.run(*module)));

  EXPECT_TRUE(module->lookupSymbol("B"));
  EXPECT_FALSE(module->lookupSymbol("A"));
}
85 
86 } // end anonymous namespace
87