xref: /llvm-project/mlir/unittests/IR/DialectTest.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/DialectInterface.h"
11 #include "mlir/Support/TypeID.h"
12 #include "gtest/gtest.h"
13 
14 using namespace mlir;
15 using namespace mlir::detail;
16 
17 namespace {
18 struct TestDialect : public Dialect {
19   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
20 
21   static StringRef getDialectNamespace() { return "test"; };
22   TestDialect(MLIRContext *context)
23       : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
24 };
25 struct AnotherTestDialect : public Dialect {
26   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect)
27 
28   static StringRef getDialectNamespace() { return "test"; };
29   AnotherTestDialect(MLIRContext *context)
30       : Dialect(getDialectNamespace(), context,
31                 TypeID::get<AnotherTestDialect>()) {}
32 };
33 
34 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
35   MLIRContext context;
36 
37   // Registering a dialect with the same namespace twice should result in a
38   // failure.
39   context.loadDialect<TestDialect>();
40   ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
41 }
42 
43 struct SecondTestDialect : public Dialect {
44   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect)
45 
46   static StringRef getDialectNamespace() { return "test2"; }
47   SecondTestDialect(MLIRContext *context)
48       : Dialect(getDialectNamespace(), context,
49                 TypeID::get<SecondTestDialect>()) {}
50 };
51 
52 struct TestDialectInterfaceBase
53     : public DialectInterface::Base<TestDialectInterfaceBase> {
54   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase)
55 
56   TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
57   virtual int function() const { return 42; }
58 };
59 
60 struct TestDialectInterface : public TestDialectInterfaceBase {
61   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface)
62 
63   using TestDialectInterfaceBase::TestDialectInterfaceBase;
64   int function() const final { return 56; }
65 };
66 
67 struct SecondTestDialectInterface : public TestDialectInterfaceBase {
68   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface)
69 
70   using TestDialectInterfaceBase::TestDialectInterfaceBase;
71   int function() const final { return 78; }
72 };
73 
74 TEST(Dialect, DelayedInterfaceRegistration) {
75   DialectRegistry registry;
76   registry.insert<TestDialect, SecondTestDialect>();
77 
78   // Delayed registration of an interface for TestDialect.
79   registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
80     dialect->addInterfaces<TestDialectInterface>();
81   });
82 
83   MLIRContext context(registry);
84 
85   // Load the TestDialect and check that the interface got registered for it.
86   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
87   ASSERT_TRUE(testDialect != nullptr);
88   auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
89   EXPECT_TRUE(testDialectInterface != nullptr);
90 
91   // Load the SecondTestDialect and check that the interface is not registered
92   // for it.
93   Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
94   ASSERT_TRUE(secondTestDialect != nullptr);
95   auto *secondTestDialectInterface =
96       dyn_cast<SecondTestDialectInterface>(secondTestDialect);
97   EXPECT_TRUE(secondTestDialectInterface == nullptr);
98 
99   // Use the same mechanism as for delayed registration but for an already
100   // loaded dialect and check that the interface is now registered.
101   DialectRegistry secondRegistry;
102   secondRegistry.insert<SecondTestDialect>();
103   secondRegistry.addExtension(
104       +[](MLIRContext *ctx, SecondTestDialect *dialect) {
105         dialect->addInterfaces<SecondTestDialectInterface>();
106       });
107   context.appendDialectRegistry(secondRegistry);
108   secondTestDialectInterface =
109       dyn_cast<SecondTestDialectInterface>(secondTestDialect);
110   EXPECT_TRUE(secondTestDialectInterface != nullptr);
111 }
112 
113 TEST(Dialect, RepeatedDelayedRegistration) {
114   // Set up the delayed registration.
115   DialectRegistry registry;
116   registry.insert<TestDialect>();
117   registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
118     dialect->addInterfaces<TestDialectInterface>();
119   });
120   MLIRContext context(registry);
121 
122   // Load the TestDialect and check that the interface got registered for it.
123   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
124   ASSERT_TRUE(testDialect != nullptr);
125   auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
126   EXPECT_TRUE(testDialectInterface != nullptr);
127 
128   // Try adding the same dialect interface again and check that we don't crash
129   // on repeated interface registration.
130   DialectRegistry secondRegistry;
131   secondRegistry.insert<TestDialect>();
132   secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
133     dialect->addInterfaces<TestDialectInterface>();
134   });
135   context.appendDialectRegistry(secondRegistry);
136   testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
137   EXPECT_TRUE(testDialectInterface != nullptr);
138 }
139 
140 namespace {
141 /// A dummy extension that increases a counter when being applied and
142 /// recursively adds additional extensions.
143 struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
144   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension)
145 
146   DummyExtension(int *counter, int numRecursive)
147       : DialectExtension(), counter(counter), numRecursive(numRecursive) {}
148 
149   void apply(MLIRContext *ctx, TestDialect *dialect) const final {
150     ++(*counter);
151     DialectRegistry nestedRegistry;
152     for (int i = 0; i < numRecursive; ++i) {
153       // Create unique TypeIDs for these recursive extensions so they don't get
154       // de-duplicated.
155       auto extension =
156           std::make_unique<DummyExtension>(counter, /*numRecursive=*/0);
157       auto typeID = TypeID::getFromOpaquePointer(extension.get());
158       nestedRegistry.addExtension(typeID, std::move(extension));
159     }
160     // Adding additional extensions may trigger a reallocation of the
161     // `extensions` vector in the dialect registry.
162     ctx->appendDialectRegistry(nestedRegistry);
163   }
164 
165 private:
166   int *counter;
167   int numRecursive;
168 };
169 } // namespace
170 
171 TEST(Dialect, NestedDialectExtension) {
172   DialectRegistry registry;
173   registry.insert<TestDialect>();
174 
175   // Add an extension that adds 100 more extensions.
176   int counter1 = 0;
177   registry.addExtension(TypeID::get<DummyExtension>(),
178                         std::make_unique<DummyExtension>(&counter1, 100));
179   // Add one more extension. This should not crash.
180   int counter2 = 0;
181   registry.addExtension(TypeID::getFromOpaquePointer(&counter2),
182                         std::make_unique<DummyExtension>(&counter2, 0));
183 
184   // Load dialect and apply extensions.
185   MLIRContext context(registry);
186   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
187   ASSERT_TRUE(testDialect != nullptr);
188 
189   // Extensions are de-duplicated by typeID. Make sure that each expected
190   // extension was applied at least once.
191   EXPECT_GE(counter1, 101);
192   EXPECT_GE(counter2, 1);
193 }
194 
195 TEST(Dialect, SubsetWithExtensions) {
196   DialectRegistry registry1, registry2;
197   registry1.insert<TestDialect>();
198   registry2.insert<TestDialect>();
199 
200   // Validate that the registries are equivalent.
201   ASSERT_TRUE(registry1.isSubsetOf(registry2));
202   ASSERT_TRUE(registry2.isSubsetOf(registry1));
203 
204   // Add extensions to registry2.
205   int counter = 0;
206   registry2.addExtension(TypeID::get<DummyExtension>(),
207                          std::make_unique<DummyExtension>(&counter, 0));
208 
209   // Expect that (1) is a subset of (2) but not the other way around.
210   ASSERT_TRUE(registry1.isSubsetOf(registry2));
211   ASSERT_FALSE(registry2.isSubsetOf(registry1));
212 
213   // Add extensions to registry1.
214   registry1.addExtension(TypeID::get<DummyExtension>(),
215                          std::make_unique<DummyExtension>(&counter, 0));
216 
217   // Expect that (1) and (2) are equivalent.
218   ASSERT_TRUE(registry1.isSubsetOf(registry2));
219   ASSERT_TRUE(registry2.isSubsetOf(registry1));
220 
221   // Load dialect and apply extensions.
222   MLIRContext context(registry1);
223   context.getOrLoadDialect<TestDialect>();
224   context.appendDialectRegistry(registry2);
225   // Expect that the extension as only invoked once.
226   ASSERT_EQ(counter, 1);
227 }
228 
229 } // namespace
230