xref: /llvm-project/mlir/unittests/IR/DialectTest.cpp (revision 5e50dd048e3a20cde5da5d7a754dfee775ef35d6)
1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/DialectInterface.h"
11 #include "gtest/gtest.h"
12 
13 using namespace mlir;
14 using namespace mlir::detail;
15 
16 namespace {
17 struct TestDialect : public Dialect {
18   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
19 
20   static StringRef getDialectNamespace() { return "test"; };
21   TestDialect(MLIRContext *context)
22       : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
23 };
24 struct AnotherTestDialect : public Dialect {
25   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect)
26 
27   static StringRef getDialectNamespace() { return "test"; };
28   AnotherTestDialect(MLIRContext *context)
29       : Dialect(getDialectNamespace(), context,
30                 TypeID::get<AnotherTestDialect>()) {}
31 };
32 
33 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
34   MLIRContext context;
35 
36   // Registering a dialect with the same namespace twice should result in a
37   // failure.
38   context.loadDialect<TestDialect>();
39   ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
40 }
41 
42 struct SecondTestDialect : public Dialect {
43   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect)
44 
45   static StringRef getDialectNamespace() { return "test2"; }
46   SecondTestDialect(MLIRContext *context)
47       : Dialect(getDialectNamespace(), context,
48                 TypeID::get<SecondTestDialect>()) {}
49 };
50 
51 struct TestDialectInterfaceBase
52     : public DialectInterface::Base<TestDialectInterfaceBase> {
53   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase)
54 
55   TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
56   virtual int function() const { return 42; }
57 };
58 
59 struct TestDialectInterface : public TestDialectInterfaceBase {
60   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface)
61 
62   using TestDialectInterfaceBase::TestDialectInterfaceBase;
63   int function() const final { return 56; }
64 };
65 
66 struct SecondTestDialectInterface : public TestDialectInterfaceBase {
67   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface)
68 
69   using TestDialectInterfaceBase::TestDialectInterfaceBase;
70   int function() const final { return 78; }
71 };
72 
73 TEST(Dialect, DelayedInterfaceRegistration) {
74   DialectRegistry registry;
75   registry.insert<TestDialect, SecondTestDialect>();
76 
77   // Delayed registration of an interface for TestDialect.
78   registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
79     dialect->addInterfaces<TestDialectInterface>();
80   });
81 
82   MLIRContext context(registry);
83 
84   // Load the TestDialect and check that the interface got registered for it.
85   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
86   ASSERT_TRUE(testDialect != nullptr);
87   auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
88   EXPECT_TRUE(testDialectInterface != nullptr);
89 
90   // Load the SecondTestDialect and check that the interface is not registered
91   // for it.
92   Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
93   ASSERT_TRUE(secondTestDialect != nullptr);
94   auto *secondTestDialectInterface =
95       dyn_cast<SecondTestDialectInterface>(secondTestDialect);
96   EXPECT_TRUE(secondTestDialectInterface == nullptr);
97 
98   // Use the same mechanism as for delayed registration but for an already
99   // loaded dialect and check that the interface is now registered.
100   DialectRegistry secondRegistry;
101   secondRegistry.insert<SecondTestDialect>();
102   secondRegistry.addExtension(
103       +[](MLIRContext *ctx, SecondTestDialect *dialect) {
104         dialect->addInterfaces<SecondTestDialectInterface>();
105       });
106   context.appendDialectRegistry(secondRegistry);
107   secondTestDialectInterface =
108       dyn_cast<SecondTestDialectInterface>(secondTestDialect);
109   EXPECT_TRUE(secondTestDialectInterface != nullptr);
110 }
111 
112 TEST(Dialect, RepeatedDelayedRegistration) {
113   // Set up the delayed registration.
114   DialectRegistry registry;
115   registry.insert<TestDialect>();
116   registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
117     dialect->addInterfaces<TestDialectInterface>();
118   });
119   MLIRContext context(registry);
120 
121   // Load the TestDialect and check that the interface got registered for it.
122   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
123   ASSERT_TRUE(testDialect != nullptr);
124   auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
125   EXPECT_TRUE(testDialectInterface != nullptr);
126 
127   // Try adding the same dialect interface again and check that we don't crash
128   // on repeated interface registration.
129   DialectRegistry secondRegistry;
130   secondRegistry.insert<TestDialect>();
131   secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
132     dialect->addInterfaces<TestDialectInterface>();
133   });
134   context.appendDialectRegistry(secondRegistry);
135   testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
136   EXPECT_TRUE(testDialectInterface != nullptr);
137 }
138 
139 } // namespace
140