xref: /llvm-project/mlir/unittests/TableGen/PassGenTest.cpp (revision fef3566a25ff0e34fb87339ba5e13eca17cec00f)
1 //===- PassGenTest.cpp - TableGen PassGen Tests ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/Pass.h"
10 #include "llvm/ADT/STLExtras.h"
11 
12 #include "gmock/gmock.h"
13 
14 std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v = 0);
15 
16 #define GEN_PASS_DECL
17 #define GEN_PASS_REGISTRATION
18 #include "PassGenTest.h.inc"
19 
20 #define GEN_PASS_DEF_TESTPASS
21 #define GEN_PASS_DEF_TESTPASSWITHOPTIONS
22 #define GEN_PASS_DEF_TESTPASSWITHCUSTOMCONSTRUCTOR
23 #include "PassGenTest.h.inc"
24 
25 struct TestPass : public impl::TestPassBase<TestPass> {
26   using TestPassBase::TestPassBase;
27 
28   void runOnOperation() override {}
29 
30   std::unique_ptr<mlir::Pass> clone() const {
31     return TestPassBase<TestPass>::clone();
32   }
33 };
34 
35 TEST(PassGenTest, defaultGeneratedConstructor) {
36   std::unique_ptr<mlir::Pass> pass = createTestPass();
37   EXPECT_TRUE(pass.get() != nullptr);
38 }
39 
40 TEST(PassGenTest, PassClone) {
41   mlir::MLIRContext context;
42 
43   const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
44     return static_cast<const TestPass *>(pass.get());
45   };
46 
47   const auto origPass = createTestPass();
48   const auto clonePass = unwrap(origPass)->clone();
49 
50   EXPECT_TRUE(clonePass.get() != nullptr);
51   EXPECT_TRUE(origPass.get() != clonePass.get());
52 }
53 
54 struct TestPassWithOptions
55     : public impl::TestPassWithOptionsBase<TestPassWithOptions> {
56   using TestPassWithOptionsBase::TestPassWithOptionsBase;
57 
58   void runOnOperation() override {}
59 
60   std::unique_ptr<mlir::Pass> clone() const {
61     return TestPassWithOptionsBase<TestPassWithOptions>::clone();
62   }
63 
64   int getTestOption() const { return testOption; }
65 
66   llvm::ArrayRef<int64_t> getTestListOption() const { return testListOption; }
67 };
68 
69 TEST(PassGenTest, PassOptions) {
70   mlir::MLIRContext context;
71 
72   TestPassWithOptionsOptions options;
73   options.testOption = 57;
74 
75   options.testListOption = {1, 2};
76 
77   const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
78     return static_cast<const TestPassWithOptions *>(pass.get());
79   };
80 
81   const auto pass = createTestPassWithOptions(options);
82 
83   EXPECT_EQ(unwrap(pass)->getTestOption(), 57);
84   EXPECT_EQ(unwrap(pass)->getTestListOption()[0], 1);
85   EXPECT_EQ(unwrap(pass)->getTestListOption()[1], 2);
86 }
87 
88 struct TestPassWithCustomConstructor
89     : public impl::TestPassWithCustomConstructorBase<
90           TestPassWithCustomConstructor> {
91   explicit TestPassWithCustomConstructor(int v) : extraVal(v) {}
92 
93   void runOnOperation() override {}
94 
95   std::unique_ptr<mlir::Pass> clone() const {
96     return TestPassWithCustomConstructorBase<
97         TestPassWithCustomConstructor>::clone();
98   }
99 
100   unsigned int extraVal = 23;
101 };
102 
103 std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v) {
104   return std::make_unique<TestPassWithCustomConstructor>(v);
105 }
106 
107 TEST(PassGenTest, PassCloneWithCustomConstructor) {
108   mlir::MLIRContext context;
109 
110   const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
111     return static_cast<const TestPassWithCustomConstructor *>(pass.get());
112   };
113 
114   const auto origPass = createTestPassWithCustomConstructor(10);
115   const auto clonePass = unwrap(origPass)->clone();
116 
117   EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal);
118 }
119