1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/DialectInterface.h" 11 #include "gtest/gtest.h" 12 13 using namespace mlir; 14 using namespace mlir::detail; 15 16 namespace { 17 struct TestDialect : public Dialect { 18 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect) 19 20 static StringRef getDialectNamespace() { return "test"; }; 21 TestDialect(MLIRContext *context) 22 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {} 23 }; 24 struct AnotherTestDialect : public Dialect { 25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect) 26 27 static StringRef getDialectNamespace() { return "test"; }; 28 AnotherTestDialect(MLIRContext *context) 29 : Dialect(getDialectNamespace(), context, 30 TypeID::get<AnotherTestDialect>()) {} 31 }; 32 33 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { 34 MLIRContext context; 35 36 // Registering a dialect with the same namespace twice should result in a 37 // failure. 38 context.loadDialect<TestDialect>(); 39 ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), ""); 40 } 41 42 struct SecondTestDialect : public Dialect { 43 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect) 44 45 static StringRef getDialectNamespace() { return "test2"; } 46 SecondTestDialect(MLIRContext *context) 47 : Dialect(getDialectNamespace(), context, 48 TypeID::get<SecondTestDialect>()) {} 49 }; 50 51 struct TestDialectInterfaceBase 52 : public DialectInterface::Base<TestDialectInterfaceBase> { 53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase) 54 55 TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} 56 virtual int function() const { return 42; } 57 }; 58 59 struct TestDialectInterface : public TestDialectInterfaceBase { 60 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface) 61 62 using TestDialectInterfaceBase::TestDialectInterfaceBase; 63 int function() const final { return 56; } 64 }; 65 66 struct SecondTestDialectInterface : public TestDialectInterfaceBase { 67 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface) 68 69 using TestDialectInterfaceBase::TestDialectInterfaceBase; 70 int function() const final { return 78; } 71 }; 72 73 TEST(Dialect, DelayedInterfaceRegistration) { 74 DialectRegistry registry; 75 registry.insert<TestDialect, SecondTestDialect>(); 76 77 // Delayed registration of an interface for TestDialect. 78 registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 79 dialect->addInterfaces<TestDialectInterface>(); 80 }); 81 82 MLIRContext context(registry); 83 84 // Load the TestDialect and check that the interface got registered for it. 85 Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 86 ASSERT_TRUE(testDialect != nullptr); 87 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 88 EXPECT_TRUE(testDialectInterface != nullptr); 89 90 // Load the SecondTestDialect and check that the interface is not registered 91 // for it. 92 Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>(); 93 ASSERT_TRUE(secondTestDialect != nullptr); 94 auto *secondTestDialectInterface = 95 dyn_cast<SecondTestDialectInterface>(secondTestDialect); 96 EXPECT_TRUE(secondTestDialectInterface == nullptr); 97 98 // Use the same mechanism as for delayed registration but for an already 99 // loaded dialect and check that the interface is now registered. 100 DialectRegistry secondRegistry; 101 secondRegistry.insert<SecondTestDialect>(); 102 secondRegistry.addExtension( 103 +[](MLIRContext *ctx, SecondTestDialect *dialect) { 104 dialect->addInterfaces<SecondTestDialectInterface>(); 105 }); 106 context.appendDialectRegistry(secondRegistry); 107 secondTestDialectInterface = 108 dyn_cast<SecondTestDialectInterface>(secondTestDialect); 109 EXPECT_TRUE(secondTestDialectInterface != nullptr); 110 } 111 112 TEST(Dialect, RepeatedDelayedRegistration) { 113 // Set up the delayed registration. 114 DialectRegistry registry; 115 registry.insert<TestDialect>(); 116 registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 117 dialect->addInterfaces<TestDialectInterface>(); 118 }); 119 MLIRContext context(registry); 120 121 // Load the TestDialect and check that the interface got registered for it. 122 Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 123 ASSERT_TRUE(testDialect != nullptr); 124 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 125 EXPECT_TRUE(testDialectInterface != nullptr); 126 127 // Try adding the same dialect interface again and check that we don't crash 128 // on repeated interface registration. 129 DialectRegistry secondRegistry; 130 secondRegistry.insert<TestDialect>(); 131 secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 132 dialect->addInterfaces<TestDialectInterface>(); 133 }); 134 context.appendDialectRegistry(secondRegistry); 135 testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 136 EXPECT_TRUE(testDialectInterface != nullptr); 137 } 138 139 } // namespace 140