1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/DialectInterface.h" 11 #include "gtest/gtest.h" 12 13 using namespace mlir; 14 using namespace mlir::detail; 15 16 namespace { 17 struct TestDialect : public Dialect { 18 static StringRef getDialectNamespace() { return "test"; }; 19 TestDialect(MLIRContext *context) 20 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {} 21 }; 22 struct AnotherTestDialect : public Dialect { 23 static StringRef getDialectNamespace() { return "test"; }; 24 AnotherTestDialect(MLIRContext *context) 25 : Dialect(getDialectNamespace(), context, 26 TypeID::get<AnotherTestDialect>()) {} 27 }; 28 29 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { 30 MLIRContext context; 31 32 // Registering a dialect with the same namespace twice should result in a 33 // failure. 34 context.loadDialect<TestDialect>(); 35 ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), ""); 36 } 37 38 struct SecondTestDialect : public Dialect { 39 static StringRef getDialectNamespace() { return "test2"; } 40 SecondTestDialect(MLIRContext *context) 41 : Dialect(getDialectNamespace(), context, 42 TypeID::get<SecondTestDialect>()) {} 43 }; 44 45 struct TestDialectInterfaceBase 46 : public DialectInterface::Base<TestDialectInterfaceBase> { 47 TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} 48 virtual int function() const { return 42; } 49 }; 50 51 struct TestDialectInterface : public TestDialectInterfaceBase { 52 using TestDialectInterfaceBase::TestDialectInterfaceBase; 53 int function() const final { return 56; } 54 }; 55 56 struct SecondTestDialectInterface : public TestDialectInterfaceBase { 57 using TestDialectInterfaceBase::TestDialectInterfaceBase; 58 int function() const final { return 78; } 59 }; 60 61 TEST(Dialect, DelayedInterfaceRegistration) { 62 DialectRegistry registry; 63 registry.insert<TestDialect, SecondTestDialect>(); 64 65 // Delayed registration of an interface for TestDialect. 66 registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 67 dialect->addInterfaces<TestDialectInterface>(); 68 }); 69 70 MLIRContext context(registry); 71 72 // Load the TestDialect and check that the interface got registered for it. 73 Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 74 ASSERT_TRUE(testDialect != nullptr); 75 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 76 EXPECT_TRUE(testDialectInterface != nullptr); 77 78 // Load the SecondTestDialect and check that the interface is not registered 79 // for it. 80 Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>(); 81 ASSERT_TRUE(secondTestDialect != nullptr); 82 auto *secondTestDialectInterface = 83 dyn_cast<SecondTestDialectInterface>(secondTestDialect); 84 EXPECT_TRUE(secondTestDialectInterface == nullptr); 85 86 // Use the same mechanism as for delayed registration but for an already 87 // loaded dialect and check that the interface is now registered. 88 DialectRegistry secondRegistry; 89 secondRegistry.insert<SecondTestDialect>(); 90 secondRegistry.addExtension( 91 +[](MLIRContext *ctx, SecondTestDialect *dialect) { 92 dialect->addInterfaces<SecondTestDialectInterface>(); 93 }); 94 context.appendDialectRegistry(secondRegistry); 95 secondTestDialectInterface = 96 dyn_cast<SecondTestDialectInterface>(secondTestDialect); 97 EXPECT_TRUE(secondTestDialectInterface != nullptr); 98 } 99 100 TEST(Dialect, RepeatedDelayedRegistration) { 101 // Set up the delayed registration. 102 DialectRegistry registry; 103 registry.insert<TestDialect>(); 104 registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 105 dialect->addInterfaces<TestDialectInterface>(); 106 }); 107 MLIRContext context(registry); 108 109 // Load the TestDialect and check that the interface got registered for it. 110 Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 111 ASSERT_TRUE(testDialect != nullptr); 112 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 113 EXPECT_TRUE(testDialectInterface != nullptr); 114 115 // Try adding the same dialect interface again and check that we don't crash 116 // on repeated interface registration. 117 DialectRegistry secondRegistry; 118 secondRegistry.insert<TestDialect>(); 119 secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 120 dialect->addInterfaces<TestDialectInterface>(); 121 }); 122 context.appendDialectRegistry(secondRegistry); 123 testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 124 EXPECT_TRUE(testDialectInterface != nullptr); 125 } 126 127 } // namespace 128