xref: /llvm-project/mlir/unittests/IR/PatternMatchTest.cpp (revision e95e94adc6bb748de015ac3053e7f0786b65f351)
1 //===- PatternMatchTest.cpp - PatternMatch unit tests ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "gtest/gtest.h"
11 
12 #include "../../test/lib/Dialect/Test/TestDialect.h"
13 #include "../../test/lib/Dialect/Test/TestOps.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 struct AnOpRewritePattern : OpRewritePattern<test::OpA> {
AnOpRewritePattern__anon4ae64e8d0111::AnOpRewritePattern19   AnOpRewritePattern(MLIRContext *context)
20       : OpRewritePattern(context, /*benefit=*/1,
21                          /*generatedNames=*/{test::OpB::getOperationName()}) {}
22 };
TEST(OpRewritePatternTest,GetGeneratedNames)23 TEST(OpRewritePatternTest, GetGeneratedNames) {
24   MLIRContext context;
25   AnOpRewritePattern pattern(&context);
26   ArrayRef<OperationName> ops = pattern.getGeneratedOps();
27 
28   ASSERT_EQ(ops.size(), 1u);
29   ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName());
30 }
31 } // end anonymous namespace
32 
33 namespace {
anOpRewritePatternFunc(test::OpA op,PatternRewriter & rewriter)34 LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) {
35   return failure();
36 }
TEST(AnOpRewritePatternTest,PatternFuncAttributes)37 TEST(AnOpRewritePatternTest, PatternFuncAttributes) {
38   MLIRContext context;
39   RewritePatternSet patterns(&context);
40 
41   patterns.add(anOpRewritePatternFunc, /*benefit=*/3,
42                /*generatedNames=*/{test::OpB::getOperationName()});
43   ASSERT_EQ(patterns.getNativePatterns().size(), 1U);
44   auto &pattern = patterns.getNativePatterns().front();
45   ASSERT_EQ(pattern->getBenefit(), 3);
46   ASSERT_EQ(pattern->getGeneratedOps().size(), 1U);
47   ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(),
48             test::OpB::getOperationName());
49 }
50 } // end anonymous namespace
51