1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/DialectInterface.h" 11 #include "gtest/gtest.h" 12 13 using namespace mlir; 14 using namespace mlir::detail; 15 16 namespace { 17 struct TestDialect : public Dialect { 18 static StringRef getDialectNamespace() { return "test"; }; 19 TestDialect(MLIRContext *context) 20 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {} 21 }; 22 struct AnotherTestDialect : public Dialect { 23 static StringRef getDialectNamespace() { return "test"; }; 24 AnotherTestDialect(MLIRContext *context) 25 : Dialect(getDialectNamespace(), context, 26 TypeID::get<AnotherTestDialect>()) {} 27 }; 28 29 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { 30 MLIRContext context; 31 32 // Registering a dialect with the same namespace twice should result in a 33 // failure. 34 context.loadDialect<TestDialect>(); 35 ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), ""); 36 } 37 38 struct SecondTestDialect : public Dialect { 39 static StringRef getDialectNamespace() { return "test2"; } 40 SecondTestDialect(MLIRContext *context) 41 : Dialect(getDialectNamespace(), context, 42 TypeID::get<SecondTestDialect>()) {} 43 }; 44 45 struct TestDialectInterfaceBase 46 : public DialectInterface::Base<TestDialectInterfaceBase> { 47 TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} 48 virtual int function() const { return 42; } 49 }; 50 51 struct TestDialectInterface : public TestDialectInterfaceBase { 52 using TestDialectInterfaceBase::TestDialectInterfaceBase; 53 int function() const final { return 56; } 54 }; 55 56 struct SecondTestDialectInterface : public TestDialectInterfaceBase { 57 using TestDialectInterfaceBase::TestDialectInterfaceBase; 58 int function() const final { return 78; } 59 }; 60 61 TEST(Dialect, DelayedInterfaceRegistration) { 62 DialectRegistry registry; 63 registry.insert<TestDialect, SecondTestDialect>(); 64 65 // Delayed registration of an interface for TestDialect. 66 registry.addDialectInterface<TestDialect, TestDialectInterface>(); 67 68 MLIRContext context(registry); 69 70 // Load the TestDialect and check that the interface got registered for it. 71 Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 72 ASSERT_TRUE(testDialect != nullptr); 73 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 74 EXPECT_TRUE(testDialectInterface != nullptr); 75 76 // Load the SecondTestDialect and check that the interface is not registered 77 // for it. 78 Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>(); 79 ASSERT_TRUE(secondTestDialect != nullptr); 80 auto *secondTestDialectInterface = 81 dyn_cast<SecondTestDialectInterface>(secondTestDialect); 82 EXPECT_TRUE(secondTestDialectInterface == nullptr); 83 84 // Use the same mechanism as for delayed registration but for an already 85 // loaded dialect and check that the interface is now registered. 86 DialectRegistry secondRegistry; 87 secondRegistry.insert<SecondTestDialect>(); 88 secondRegistry 89 .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>(); 90 context.appendDialectRegistry(secondRegistry); 91 secondTestDialectInterface = 92 dyn_cast<SecondTestDialectInterface>(secondTestDialect); 93 EXPECT_TRUE(secondTestDialectInterface != nullptr); 94 } 95 96 TEST(Dialect, RepeatedDelayedRegistration) { 97 // Set up the delayed registration. 98 DialectRegistry registry; 99 registry.insert<TestDialect>(); 100 registry.addDialectInterface<TestDialect, TestDialectInterface>(); 101 MLIRContext context(registry); 102 103 // Load the TestDialect and check that the interface got registered for it. 104 Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 105 ASSERT_TRUE(testDialect != nullptr); 106 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 107 EXPECT_TRUE(testDialectInterface != nullptr); 108 109 // Try adding the same dialect interface again and check that we don't crash 110 // on repeated interface registration. 111 DialectRegistry secondRegistry; 112 secondRegistry.insert<TestDialect>(); 113 secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>(); 114 context.appendDialectRegistry(secondRegistry); 115 testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 116 EXPECT_TRUE(testDialectInterface != nullptr); 117 } 118 119 // A dialect that registers two interfaces with the same InterfaceID, triggering 120 // an assertion failure. 121 struct RepeatedRegistrationDialect : public Dialect { 122 static StringRef getDialectNamespace() { return "repeatedreg"; } 123 RepeatedRegistrationDialect(MLIRContext *context) 124 : Dialect(getDialectNamespace(), context, 125 TypeID::get<RepeatedRegistrationDialect>()) { 126 addInterfaces<TestDialectInterface>(); 127 addInterfaces<SecondTestDialectInterface>(); 128 } 129 }; 130 131 TEST(Dialect, RepeatedInterfaceRegistrationDeath) { 132 MLIRContext context; 133 (void)context; 134 135 // This triggers an assertion in debug mode. 136 #ifndef NDEBUG 137 ASSERT_DEATH(context.loadDialect<RepeatedRegistrationDialect>(), 138 "interface kind has already been registered"); 139 #endif 140 } 141 142 } // namespace 143