xref: /llvm-project/mlir/unittests/IR/DialectTest.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
//===- DialectTest.cpp - Dialect unit tests -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
86f24bf82SRiver Riddle 
96f24bf82SRiver Riddle #include "mlir/IR/Dialect.h"
103da51522SAlex Zinenko #include "mlir/IR/DialectInterface.h"
11*84cc1865SNikhil Kalra #include "mlir/Support/TypeID.h"
126f24bf82SRiver Riddle #include "gtest/gtest.h"
136f24bf82SRiver Riddle 
146f24bf82SRiver Riddle using namespace mlir;
156f24bf82SRiver Riddle using namespace mlir::detail;
166f24bf82SRiver Riddle 
176f24bf82SRiver Riddle namespace {
186f24bf82SRiver Riddle struct TestDialect : public Dialect {
195e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
205e50dd04SRiver Riddle 
21575b22b5SMehdi Amini   static StringRef getDialectNamespace() { return "test"; };
22575b22b5SMehdi Amini   TestDialect(MLIRContext *context)
23575b22b5SMehdi Amini       : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
24575b22b5SMehdi Amini };
25575b22b5SMehdi Amini struct AnotherTestDialect : public Dialect {
265e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect)
275e50dd04SRiver Riddle 
28575b22b5SMehdi Amini   static StringRef getDialectNamespace() { return "test"; };
29575b22b5SMehdi Amini   AnotherTestDialect(MLIRContext *context)
30575b22b5SMehdi Amini       : Dialect(getDialectNamespace(), context,
31575b22b5SMehdi Amini                 TypeID::get<AnotherTestDialect>()) {}
326f24bf82SRiver Riddle };
336f24bf82SRiver Riddle 
346f24bf82SRiver Riddle TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
35e7021232SMehdi Amini   MLIRContext context;
366f24bf82SRiver Riddle 
376f24bf82SRiver Riddle   // Registering a dialect with the same namespace twice should result in a
386f24bf82SRiver Riddle   // failure.
39f9dc2b70SMehdi Amini   context.loadDialect<TestDialect>();
40f9dc2b70SMehdi Amini   ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
416f24bf82SRiver Riddle }
426f24bf82SRiver Riddle 
433da51522SAlex Zinenko struct SecondTestDialect : public Dialect {
445e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect)
455e50dd04SRiver Riddle 
463da51522SAlex Zinenko   static StringRef getDialectNamespace() { return "test2"; }
473da51522SAlex Zinenko   SecondTestDialect(MLIRContext *context)
483da51522SAlex Zinenko       : Dialect(getDialectNamespace(), context,
493da51522SAlex Zinenko                 TypeID::get<SecondTestDialect>()) {}
503da51522SAlex Zinenko };
513da51522SAlex Zinenko 
523da51522SAlex Zinenko struct TestDialectInterfaceBase
533da51522SAlex Zinenko     : public DialectInterface::Base<TestDialectInterfaceBase> {
545e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase)
555e50dd04SRiver Riddle 
563da51522SAlex Zinenko   TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
573da51522SAlex Zinenko   virtual int function() const { return 42; }
583da51522SAlex Zinenko };
593da51522SAlex Zinenko 
603da51522SAlex Zinenko struct TestDialectInterface : public TestDialectInterfaceBase {
615e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface)
625e50dd04SRiver Riddle 
633da51522SAlex Zinenko   using TestDialectInterfaceBase::TestDialectInterfaceBase;
643da51522SAlex Zinenko   int function() const final { return 56; }
653da51522SAlex Zinenko };
663da51522SAlex Zinenko 
673da51522SAlex Zinenko struct SecondTestDialectInterface : public TestDialectInterfaceBase {
685e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface)
695e50dd04SRiver Riddle 
703da51522SAlex Zinenko   using TestDialectInterfaceBase::TestDialectInterfaceBase;
713da51522SAlex Zinenko   int function() const final { return 78; }
723da51522SAlex Zinenko };
733da51522SAlex Zinenko 
743da51522SAlex Zinenko TEST(Dialect, DelayedInterfaceRegistration) {
753da51522SAlex Zinenko   DialectRegistry registry;
763da51522SAlex Zinenko   registry.insert<TestDialect, SecondTestDialect>();
773da51522SAlex Zinenko 
783da51522SAlex Zinenko   // Delayed registration of an interface for TestDialect.
7977eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
8077eee579SRiver Riddle     dialect->addInterfaces<TestDialectInterface>();
8177eee579SRiver Riddle   });
823da51522SAlex Zinenko 
832996a8d6SAlex Zinenko   MLIRContext context(registry);
843da51522SAlex Zinenko 
853da51522SAlex Zinenko   // Load the TestDialect and check that the interface got registered for it.
8658e7bf78SRiver Riddle   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
873da51522SAlex Zinenko   ASSERT_TRUE(testDialect != nullptr);
8858e7bf78SRiver Riddle   auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
893da51522SAlex Zinenko   EXPECT_TRUE(testDialectInterface != nullptr);
903da51522SAlex Zinenko 
913da51522SAlex Zinenko   // Load the SecondTestDialect and check that the interface is not registered
923da51522SAlex Zinenko   // for it.
9358e7bf78SRiver Riddle   Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
943da51522SAlex Zinenko   ASSERT_TRUE(secondTestDialect != nullptr);
953da51522SAlex Zinenko   auto *secondTestDialectInterface =
9658e7bf78SRiver Riddle       dyn_cast<SecondTestDialectInterface>(secondTestDialect);
973da51522SAlex Zinenko   EXPECT_TRUE(secondTestDialectInterface == nullptr);
983da51522SAlex Zinenko 
993da51522SAlex Zinenko   // Use the same mechanism as for delayed registration but for an already
1003da51522SAlex Zinenko   // loaded dialect and check that the interface is now registered.
1012996a8d6SAlex Zinenko   DialectRegistry secondRegistry;
1022996a8d6SAlex Zinenko   secondRegistry.insert<SecondTestDialect>();
10377eee579SRiver Riddle   secondRegistry.addExtension(
10477eee579SRiver Riddle       +[](MLIRContext *ctx, SecondTestDialect *dialect) {
10577eee579SRiver Riddle         dialect->addInterfaces<SecondTestDialectInterface>();
10677eee579SRiver Riddle       });
1072996a8d6SAlex Zinenko   context.appendDialectRegistry(secondRegistry);
1083da51522SAlex Zinenko   secondTestDialectInterface =
10958e7bf78SRiver Riddle       dyn_cast<SecondTestDialectInterface>(secondTestDialect);
1103da51522SAlex Zinenko   EXPECT_TRUE(secondTestDialectInterface != nullptr);
1113da51522SAlex Zinenko }
1123da51522SAlex Zinenko 
11334ea608aSAlex Zinenko TEST(Dialect, RepeatedDelayedRegistration) {
11434ea608aSAlex Zinenko   // Set up the delayed registration.
11534ea608aSAlex Zinenko   DialectRegistry registry;
11634ea608aSAlex Zinenko   registry.insert<TestDialect>();
11777eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
11877eee579SRiver Riddle     dialect->addInterfaces<TestDialectInterface>();
11977eee579SRiver Riddle   });
12034ea608aSAlex Zinenko   MLIRContext context(registry);
12134ea608aSAlex Zinenko 
12234ea608aSAlex Zinenko   // Load the TestDialect and check that the interface got registered for it.
12358e7bf78SRiver Riddle   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
12434ea608aSAlex Zinenko   ASSERT_TRUE(testDialect != nullptr);
12558e7bf78SRiver Riddle   auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
12634ea608aSAlex Zinenko   EXPECT_TRUE(testDialectInterface != nullptr);
12734ea608aSAlex Zinenko 
12834ea608aSAlex Zinenko   // Try adding the same dialect interface again and check that we don't crash
12934ea608aSAlex Zinenko   // on repeated interface registration.
13034ea608aSAlex Zinenko   DialectRegistry secondRegistry;
13134ea608aSAlex Zinenko   secondRegistry.insert<TestDialect>();
13277eee579SRiver Riddle   secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
13377eee579SRiver Riddle     dialect->addInterfaces<TestDialectInterface>();
13477eee579SRiver Riddle   });
13534ea608aSAlex Zinenko   context.appendDialectRegistry(secondRegistry);
13658e7bf78SRiver Riddle   testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
13734ea608aSAlex Zinenko   EXPECT_TRUE(testDialectInterface != nullptr);
13834ea608aSAlex Zinenko }
13934ea608aSAlex Zinenko 
140012b148dSMatthias Springer namespace {
141012b148dSMatthias Springer /// A dummy extension that increases a counter when being applied and
142012b148dSMatthias Springer /// recursively adds additional extensions.
143012b148dSMatthias Springer struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
144*84cc1865SNikhil Kalra   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension)
145*84cc1865SNikhil Kalra 
146012b148dSMatthias Springer   DummyExtension(int *counter, int numRecursive)
147012b148dSMatthias Springer       : DialectExtension(), counter(counter), numRecursive(numRecursive) {}
148012b148dSMatthias Springer 
149012b148dSMatthias Springer   void apply(MLIRContext *ctx, TestDialect *dialect) const final {
150012b148dSMatthias Springer     ++(*counter);
151012b148dSMatthias Springer     DialectRegistry nestedRegistry;
152*84cc1865SNikhil Kalra     for (int i = 0; i < numRecursive; ++i) {
153*84cc1865SNikhil Kalra       // Create unique TypeIDs for these recursive extensions so they don't get
154*84cc1865SNikhil Kalra       // de-duplicated.
155*84cc1865SNikhil Kalra       auto extension =
156*84cc1865SNikhil Kalra           std::make_unique<DummyExtension>(counter, /*numRecursive=*/0);
157*84cc1865SNikhil Kalra       auto typeID = TypeID::getFromOpaquePointer(extension.get());
158*84cc1865SNikhil Kalra       nestedRegistry.addExtension(typeID, std::move(extension));
159*84cc1865SNikhil Kalra     }
160012b148dSMatthias Springer     // Adding additional extensions may trigger a reallocation of the
161012b148dSMatthias Springer     // `extensions` vector in the dialect registry.
162012b148dSMatthias Springer     ctx->appendDialectRegistry(nestedRegistry);
163012b148dSMatthias Springer   }
164012b148dSMatthias Springer 
165012b148dSMatthias Springer private:
166012b148dSMatthias Springer   int *counter;
167012b148dSMatthias Springer   int numRecursive;
168012b148dSMatthias Springer };
169012b148dSMatthias Springer } // namespace
170012b148dSMatthias Springer 
171012b148dSMatthias Springer TEST(Dialect, NestedDialectExtension) {
172012b148dSMatthias Springer   DialectRegistry registry;
173012b148dSMatthias Springer   registry.insert<TestDialect>();
174012b148dSMatthias Springer 
175012b148dSMatthias Springer   // Add an extension that adds 100 more extensions.
176012b148dSMatthias Springer   int counter1 = 0;
177*84cc1865SNikhil Kalra   registry.addExtension(TypeID::get<DummyExtension>(),
178*84cc1865SNikhil Kalra                         std::make_unique<DummyExtension>(&counter1, 100));
179012b148dSMatthias Springer   // Add one more extension. This should not crash.
180012b148dSMatthias Springer   int counter2 = 0;
181*84cc1865SNikhil Kalra   registry.addExtension(TypeID::getFromOpaquePointer(&counter2),
182*84cc1865SNikhil Kalra                         std::make_unique<DummyExtension>(&counter2, 0));
183012b148dSMatthias Springer 
184012b148dSMatthias Springer   // Load dialect and apply extensions.
185012b148dSMatthias Springer   MLIRContext context(registry);
186012b148dSMatthias Springer   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
187012b148dSMatthias Springer   ASSERT_TRUE(testDialect != nullptr);
188012b148dSMatthias Springer 
189*84cc1865SNikhil Kalra   // Extensions are de-duplicated by typeID. Make sure that each expected
190012b148dSMatthias Springer   // extension was applied at least once.
191012b148dSMatthias Springer   EXPECT_GE(counter1, 101);
192012b148dSMatthias Springer   EXPECT_GE(counter2, 1);
193012b148dSMatthias Springer }
194012b148dSMatthias Springer 
195*84cc1865SNikhil Kalra TEST(Dialect, SubsetWithExtensions) {
196*84cc1865SNikhil Kalra   DialectRegistry registry1, registry2;
197*84cc1865SNikhil Kalra   registry1.insert<TestDialect>();
198*84cc1865SNikhil Kalra   registry2.insert<TestDialect>();
199*84cc1865SNikhil Kalra 
200*84cc1865SNikhil Kalra   // Validate that the registries are equivalent.
201*84cc1865SNikhil Kalra   ASSERT_TRUE(registry1.isSubsetOf(registry2));
202*84cc1865SNikhil Kalra   ASSERT_TRUE(registry2.isSubsetOf(registry1));
203*84cc1865SNikhil Kalra 
204*84cc1865SNikhil Kalra   // Add extensions to registry2.
205*84cc1865SNikhil Kalra   int counter = 0;
206*84cc1865SNikhil Kalra   registry2.addExtension(TypeID::get<DummyExtension>(),
207*84cc1865SNikhil Kalra                          std::make_unique<DummyExtension>(&counter, 0));
208*84cc1865SNikhil Kalra 
209*84cc1865SNikhil Kalra   // Expect that (1) is a subset of (2) but not the other way around.
210*84cc1865SNikhil Kalra   ASSERT_TRUE(registry1.isSubsetOf(registry2));
211*84cc1865SNikhil Kalra   ASSERT_FALSE(registry2.isSubsetOf(registry1));
212*84cc1865SNikhil Kalra 
213*84cc1865SNikhil Kalra   // Add extensions to registry1.
214*84cc1865SNikhil Kalra   registry1.addExtension(TypeID::get<DummyExtension>(),
215*84cc1865SNikhil Kalra                          std::make_unique<DummyExtension>(&counter, 0));
216*84cc1865SNikhil Kalra 
217*84cc1865SNikhil Kalra   // Expect that (1) and (2) are equivalent.
218*84cc1865SNikhil Kalra   ASSERT_TRUE(registry1.isSubsetOf(registry2));
219*84cc1865SNikhil Kalra   ASSERT_TRUE(registry2.isSubsetOf(registry1));
220*84cc1865SNikhil Kalra 
221*84cc1865SNikhil Kalra   // Load dialect and apply extensions.
222*84cc1865SNikhil Kalra   MLIRContext context(registry1);
223*84cc1865SNikhil Kalra   context.getOrLoadDialect<TestDialect>();
224*84cc1865SNikhil Kalra   context.appendDialectRegistry(registry2);
225*84cc1865SNikhil Kalra   // Expect that the extension as only invoked once.
226*84cc1865SNikhil Kalra   ASSERT_EQ(counter, 1);
227*84cc1865SNikhil Kalra }
228*84cc1865SNikhil Kalra 
229be0a7e9fSMehdi Amini } // namespace
230