1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/DialectInterface.h" 11 #include "gtest/gtest.h" 12 13 using namespace mlir; 14 using namespace mlir::detail; 15 16 namespace { 17 struct TestDialect : public Dialect { 18 static StringRef getDialectNamespace() { return "test"; }; 19 TestDialect(MLIRContext *context) 20 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {} 21 }; 22 struct AnotherTestDialect : public Dialect { 23 static StringRef getDialectNamespace() { return "test"; }; 24 AnotherTestDialect(MLIRContext *context) 25 : Dialect(getDialectNamespace(), context, 26 TypeID::get<AnotherTestDialect>()) {} 27 }; 28 29 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { 30 MLIRContext context; 31 32 // Registering a dialect with the same namespace twice should result in a 33 // failure. 34 context.loadDialect<TestDialect>(); 35 ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), ""); 36 } 37 38 struct SecondTestDialect : public Dialect { 39 static StringRef getDialectNamespace() { return "test2"; } 40 SecondTestDialect(MLIRContext *context) 41 : Dialect(getDialectNamespace(), context, 42 TypeID::get<SecondTestDialect>()) {} 43 }; 44 45 struct TestDialectInterfaceBase 46 : public DialectInterface::Base<TestDialectInterfaceBase> { 47 TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} 48 virtual int function() const { return 42; } 49 }; 50 51 struct TestDialectInterface : public TestDialectInterfaceBase { 52 using TestDialectInterfaceBase::TestDialectInterfaceBase; 53 int function() const final { return 56; } 54 }; 55 56 struct SecondTestDialectInterface : public TestDialectInterfaceBase { 57 using TestDialectInterfaceBase::TestDialectInterfaceBase; 58 int function() const final { return 78; } 59 }; 60 61 TEST(Dialect, DelayedInterfaceRegistration) { 62 DialectRegistry registry; 63 registry.insert<TestDialect, SecondTestDialect>(); 64 65 // Delayed registration of an interface for TestDialect. 66 registry.addDialectInterface<TestDialect, TestDialectInterface>(); 67 68 MLIRContext context(registry); 69 70 // Load the TestDialect and check that the interface got registered for it. 71 auto *testDialect = context.getOrLoadDialect<TestDialect>(); 72 ASSERT_TRUE(testDialect != nullptr); 73 auto *testDialectInterface = 74 testDialect->getRegisteredInterface<TestDialectInterfaceBase>(); 75 EXPECT_TRUE(testDialectInterface != nullptr); 76 77 // Load the SecondTestDialect and check that the interface is not registered 78 // for it. 79 auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>(); 80 ASSERT_TRUE(secondTestDialect != nullptr); 81 auto *secondTestDialectInterface = 82 secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>(); 83 EXPECT_TRUE(secondTestDialectInterface == nullptr); 84 85 // Use the same mechanism as for delayed registration but for an already 86 // loaded dialect and check that the interface is now registered. 87 DialectRegistry secondRegistry; 88 secondRegistry.insert<SecondTestDialect>(); 89 secondRegistry 90 .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>(); 91 context.appendDialectRegistry(secondRegistry); 92 secondTestDialectInterface = 93 secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>(); 94 EXPECT_TRUE(secondTestDialectInterface != nullptr); 95 } 96 97 } // end namespace 98