xref: /llvm-project/mlir/unittests/IR/DialectTest.cpp (revision 58e7bf78a3ef724b70304912fb3bb66af8c4a10c)
1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/DialectInterface.h"
11 #include "gtest/gtest.h"
12 
13 using namespace mlir;
14 using namespace mlir::detail;
15 
16 namespace {
17 struct TestDialect : public Dialect {
18   static StringRef getDialectNamespace() { return "test"; };
19   TestDialect(MLIRContext *context)
20       : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
21 };
22 struct AnotherTestDialect : public Dialect {
23   static StringRef getDialectNamespace() { return "test"; };
24   AnotherTestDialect(MLIRContext *context)
25       : Dialect(getDialectNamespace(), context,
26                 TypeID::get<AnotherTestDialect>()) {}
27 };
28 
29 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
30   MLIRContext context;
31 
32   // Registering a dialect with the same namespace twice should result in a
33   // failure.
34   context.loadDialect<TestDialect>();
35   ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
36 }
37 
38 struct SecondTestDialect : public Dialect {
39   static StringRef getDialectNamespace() { return "test2"; }
40   SecondTestDialect(MLIRContext *context)
41       : Dialect(getDialectNamespace(), context,
42                 TypeID::get<SecondTestDialect>()) {}
43 };
44 
45 struct TestDialectInterfaceBase
46     : public DialectInterface::Base<TestDialectInterfaceBase> {
47   TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
48   virtual int function() const { return 42; }
49 };
50 
51 struct TestDialectInterface : public TestDialectInterfaceBase {
52   using TestDialectInterfaceBase::TestDialectInterfaceBase;
53   int function() const final { return 56; }
54 };
55 
56 struct SecondTestDialectInterface : public TestDialectInterfaceBase {
57   using TestDialectInterfaceBase::TestDialectInterfaceBase;
58   int function() const final { return 78; }
59 };
60 
61 TEST(Dialect, DelayedInterfaceRegistration) {
62   DialectRegistry registry;
63   registry.insert<TestDialect, SecondTestDialect>();
64 
65   // Delayed registration of an interface for TestDialect.
66   registry.addDialectInterface<TestDialect, TestDialectInterface>();
67 
68   MLIRContext context(registry);
69 
70   // Load the TestDialect and check that the interface got registered for it.
71   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
72   ASSERT_TRUE(testDialect != nullptr);
73   auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
74   EXPECT_TRUE(testDialectInterface != nullptr);
75 
76   // Load the SecondTestDialect and check that the interface is not registered
77   // for it.
78   Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
79   ASSERT_TRUE(secondTestDialect != nullptr);
80   auto *secondTestDialectInterface =
81       dyn_cast<SecondTestDialectInterface>(secondTestDialect);
82   EXPECT_TRUE(secondTestDialectInterface == nullptr);
83 
84   // Use the same mechanism as for delayed registration but for an already
85   // loaded dialect and check that the interface is now registered.
86   DialectRegistry secondRegistry;
87   secondRegistry.insert<SecondTestDialect>();
88   secondRegistry
89       .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
90   context.appendDialectRegistry(secondRegistry);
91   secondTestDialectInterface =
92       dyn_cast<SecondTestDialectInterface>(secondTestDialect);
93   EXPECT_TRUE(secondTestDialectInterface != nullptr);
94 }
95 
96 TEST(Dialect, RepeatedDelayedRegistration) {
97   // Set up the delayed registration.
98   DialectRegistry registry;
99   registry.insert<TestDialect>();
100   registry.addDialectInterface<TestDialect, TestDialectInterface>();
101   MLIRContext context(registry);
102 
103   // Load the TestDialect and check that the interface got registered for it.
104   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
105   ASSERT_TRUE(testDialect != nullptr);
106   auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
107   EXPECT_TRUE(testDialectInterface != nullptr);
108 
109   // Try adding the same dialect interface again and check that we don't crash
110   // on repeated interface registration.
111   DialectRegistry secondRegistry;
112   secondRegistry.insert<TestDialect>();
113   secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
114   context.appendDialectRegistry(secondRegistry);
115   testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
116   EXPECT_TRUE(testDialectInterface != nullptr);
117 }
118 
119 // A dialect that registers two interfaces with the same InterfaceID, triggering
120 // an assertion failure.
121 struct RepeatedRegistrationDialect : public Dialect {
122   static StringRef getDialectNamespace() { return "repeatedreg"; }
123   RepeatedRegistrationDialect(MLIRContext *context)
124       : Dialect(getDialectNamespace(), context,
125                 TypeID::get<RepeatedRegistrationDialect>()) {
126     addInterfaces<TestDialectInterface>();
127     addInterfaces<SecondTestDialectInterface>();
128   }
129 };
130 
131 TEST(Dialect, RepeatedInterfaceRegistrationDeath) {
132   MLIRContext context;
133   (void)context;
134 
135   // This triggers an assertion in debug mode.
136 #ifndef NDEBUG
137   ASSERT_DEATH(context.loadDialect<RepeatedRegistrationDialect>(),
138                "interface kind has already been registered");
139 #endif
140 }
141 
142 } // namespace
143