//===- Canonicalizer.cpp - Canonicalizer unit tests -----------------------===//
232e8b30dSMogball //
332e8b30dSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
432e8b30dSMogball // See https://llvm.org/LICENSE.txt for license information.
532e8b30dSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
632e8b30dSMogball //
732e8b30dSMogball //===----------------------------------------------------------------------===//
832e8b30dSMogball
932e8b30dSMogball #include "mlir/IR/PatternMatch.h"
109eaff423SRiver Riddle #include "mlir/Parser/Parser.h"
1132e8b30dSMogball #include "mlir/Pass/PassManager.h"
1232e8b30dSMogball #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1332e8b30dSMogball #include "mlir/Transforms/Passes.h"
1432e8b30dSMogball #include "gtest/gtest.h"
1532e8b30dSMogball
1632e8b30dSMogball using namespace mlir;
1732e8b30dSMogball
1832e8b30dSMogball namespace {
1932e8b30dSMogball
2032e8b30dSMogball struct DisabledPattern : public RewritePattern {
DisabledPattern__anon224433b80111::DisabledPattern2132e8b30dSMogball DisabledPattern(MLIRContext *context)
2232e8b30dSMogball : RewritePattern("test.foo", /*benefit=*/0, context,
2332e8b30dSMogball /*generatedNamed=*/{}) {
2432e8b30dSMogball setDebugName("DisabledPattern");
2532e8b30dSMogball }
2632e8b30dSMogball
matchAndRewrite__anon224433b80111::DisabledPattern2732e8b30dSMogball LogicalResult matchAndRewrite(Operation *op,
2832e8b30dSMogball PatternRewriter &rewriter) const override {
2932e8b30dSMogball if (op->getNumResults() != 1)
3032e8b30dSMogball return failure();
3132e8b30dSMogball rewriter.eraseOp(op);
3232e8b30dSMogball return success();
3332e8b30dSMogball }
3432e8b30dSMogball };
3532e8b30dSMogball
3632e8b30dSMogball struct EnabledPattern : public RewritePattern {
EnabledPattern__anon224433b80111::EnabledPattern3732e8b30dSMogball EnabledPattern(MLIRContext *context)
3832e8b30dSMogball : RewritePattern("test.foo", /*benefit=*/0, context,
3932e8b30dSMogball /*generatedNamed=*/{}) {
4032e8b30dSMogball setDebugName("EnabledPattern");
4132e8b30dSMogball }
4232e8b30dSMogball
matchAndRewrite__anon224433b80111::EnabledPattern4332e8b30dSMogball LogicalResult matchAndRewrite(Operation *op,
4432e8b30dSMogball PatternRewriter &rewriter) const override {
4532e8b30dSMogball if (op->getNumResults() == 1)
4632e8b30dSMogball return failure();
4732e8b30dSMogball rewriter.eraseOp(op);
4832e8b30dSMogball return success();
4932e8b30dSMogball }
5032e8b30dSMogball };
5132e8b30dSMogball
5232e8b30dSMogball struct TestDialect : public Dialect {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon224433b80111::TestDialect53*5e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
54*5e50dd04SRiver Riddle
5532e8b30dSMogball static StringRef getDialectNamespace() { return "test"; }
5632e8b30dSMogball
TestDialect__anon224433b80111::TestDialect5732e8b30dSMogball TestDialect(MLIRContext *context)
5832e8b30dSMogball : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {
5932e8b30dSMogball allowUnknownOperations();
6032e8b30dSMogball }
6132e8b30dSMogball
getCanonicalizationPatterns__anon224433b80111::TestDialect6232e8b30dSMogball void getCanonicalizationPatterns(RewritePatternSet &results) const override {
63b4e0507cSTres Popp results.add<DisabledPattern, EnabledPattern>(results.getContext());
6432e8b30dSMogball }
6532e8b30dSMogball };
6632e8b30dSMogball
/// Checks that a pattern named in the canonicalizer's disabled-pattern list is
/// not applied, while the dialect's other canonicalization patterns still run.
TEST(CanonicalizerTest, TestDisablePatterns) {
  MLIRContext context;
  context.getOrLoadDialect<TestDialect>();
  PassManager mgr(&context);
  mgr.addPass(
      createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"}));

  // "A" has two results, so only EnabledPattern can match it.
  // "B" has a single result, so only DisabledPattern can match it.
  const char *const code = R"mlir(
    %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32)
    %1 = "test.foo"() {sym_name = "B"} : () -> (f32)
  )mlir";

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
  // Fail the test cleanly, rather than dereferencing a null module below,
  // if the input does not parse.
  ASSERT_TRUE(module);
  ASSERT_TRUE(succeeded(mgr.run(*module)));

  EXPECT_TRUE(module->lookupSymbol("B"));
  EXPECT_FALSE(module->lookupSymbol("A"));
}
8532e8b30dSMogball
8632e8b30dSMogball } // end anonymous namespace
87