1 //===- PassGenTest.cpp - TableGen PassGen Tests ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Pass/Pass.h" 10 #include "llvm/ADT/STLExtras.h" 11 12 #include "gmock/gmock.h" 13 14 std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v = 0); 15 16 #define GEN_PASS_DECL 17 #define GEN_PASS_REGISTRATION 18 #include "PassGenTest.h.inc" 19 20 #define GEN_PASS_DEF_TESTPASS 21 #define GEN_PASS_DEF_TESTPASSWITHOPTIONS 22 #define GEN_PASS_DEF_TESTPASSWITHCUSTOMCONSTRUCTOR 23 #include "PassGenTest.h.inc" 24 25 struct TestPass : public impl::TestPassBase<TestPass> { 26 using TestPassBase::TestPassBase; 27 28 void runOnOperation() override {} 29 30 std::unique_ptr<mlir::Pass> clone() const { 31 return TestPassBase<TestPass>::clone(); 32 } 33 }; 34 35 TEST(PassGenTest, defaultGeneratedConstructor) { 36 std::unique_ptr<mlir::Pass> pass = createTestPass(); 37 EXPECT_TRUE(pass.get() != nullptr); 38 } 39 40 TEST(PassGenTest, PassClone) { 41 mlir::MLIRContext context; 42 43 const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) { 44 return static_cast<const TestPass *>(pass.get()); 45 }; 46 47 const auto origPass = createTestPass(); 48 const auto clonePass = unwrap(origPass)->clone(); 49 50 EXPECT_TRUE(clonePass.get() != nullptr); 51 EXPECT_TRUE(origPass.get() != clonePass.get()); 52 } 53 54 struct TestPassWithOptions 55 : public impl::TestPassWithOptionsBase<TestPassWithOptions> { 56 using TestPassWithOptionsBase::TestPassWithOptionsBase; 57 58 void runOnOperation() override {} 59 60 std::unique_ptr<mlir::Pass> clone() const { 61 return TestPassWithOptionsBase<TestPassWithOptions>::clone(); 62 } 63 64 int getTestOption() const { return testOption; } 65 66 llvm::ArrayRef<int64_t> getTestListOption() const { return testListOption; } 67 }; 68 69 TEST(PassGenTest, PassOptions) { 70 mlir::MLIRContext context; 71 72 TestPassWithOptionsOptions options; 73 options.testOption = 57; 74 75 options.testListOption = {1, 2}; 76 77 const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) { 78 return static_cast<const TestPassWithOptions *>(pass.get()); 79 }; 80 81 const auto pass = createTestPassWithOptions(options); 82 83 EXPECT_EQ(unwrap(pass)->getTestOption(), 57); 84 EXPECT_EQ(unwrap(pass)->getTestListOption()[0], 1); 85 EXPECT_EQ(unwrap(pass)->getTestListOption()[1], 2); 86 } 87 88 struct TestPassWithCustomConstructor 89 : public impl::TestPassWithCustomConstructorBase< 90 TestPassWithCustomConstructor> { 91 explicit TestPassWithCustomConstructor(int v) : extraVal(v) {} 92 93 void runOnOperation() override {} 94 95 std::unique_ptr<mlir::Pass> clone() const { 96 return TestPassWithCustomConstructorBase< 97 TestPassWithCustomConstructor>::clone(); 98 } 99 100 unsigned int extraVal = 23; 101 }; 102 103 std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v) { 104 return std::make_unique<TestPassWithCustomConstructor>(v); 105 } 106 107 TEST(PassGenTest, PassCloneWithCustomConstructor) { 108 mlir::MLIRContext context; 109 110 const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) { 111 return static_cast<const TestPassWithCustomConstructor *>(pass.get()); 112 }; 113 114 const auto origPass = createTestPassWithCustomConstructor(10); 115 const auto clonePass = unwrap(origPass)->clone(); 116 117 EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal); 118 } 119