16f24bf82SRiver Riddle //===- DialectTest.cpp - Dialect unit tests -------------------------------===// 26f24bf82SRiver Riddle // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 66f24bf82SRiver Riddle // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 86f24bf82SRiver Riddle 96f24bf82SRiver Riddle #include "mlir/IR/Dialect.h" 103da51522SAlex Zinenko #include "mlir/IR/DialectInterface.h" 11*84cc1865SNikhil Kalra #include "mlir/Support/TypeID.h" 126f24bf82SRiver Riddle #include "gtest/gtest.h" 136f24bf82SRiver Riddle 146f24bf82SRiver Riddle using namespace mlir; 156f24bf82SRiver Riddle using namespace mlir::detail; 166f24bf82SRiver Riddle 176f24bf82SRiver Riddle namespace { 186f24bf82SRiver Riddle struct TestDialect : public Dialect { 195e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect) 205e50dd04SRiver Riddle 21575b22b5SMehdi Amini static StringRef getDialectNamespace() { return "test"; }; 22575b22b5SMehdi Amini TestDialect(MLIRContext *context) 23575b22b5SMehdi Amini : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {} 24575b22b5SMehdi Amini }; 25575b22b5SMehdi Amini struct AnotherTestDialect : public Dialect { 265e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect) 275e50dd04SRiver Riddle 28575b22b5SMehdi Amini static StringRef getDialectNamespace() { return "test"; }; 29575b22b5SMehdi Amini AnotherTestDialect(MLIRContext *context) 30575b22b5SMehdi Amini : Dialect(getDialectNamespace(), context, 31575b22b5SMehdi Amini TypeID::get<AnotherTestDialect>()) {} 326f24bf82SRiver Riddle }; 336f24bf82SRiver Riddle 346f24bf82SRiver Riddle TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { 35e7021232SMehdi Amini MLIRContext context; 366f24bf82SRiver Riddle 376f24bf82SRiver Riddle // Registering a dialect with the same namespace twice should result in a 386f24bf82SRiver Riddle // failure. 39f9dc2b70SMehdi Amini context.loadDialect<TestDialect>(); 40f9dc2b70SMehdi Amini ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), ""); 416f24bf82SRiver Riddle } 426f24bf82SRiver Riddle 433da51522SAlex Zinenko struct SecondTestDialect : public Dialect { 445e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect) 455e50dd04SRiver Riddle 463da51522SAlex Zinenko static StringRef getDialectNamespace() { return "test2"; } 473da51522SAlex Zinenko SecondTestDialect(MLIRContext *context) 483da51522SAlex Zinenko : Dialect(getDialectNamespace(), context, 493da51522SAlex Zinenko TypeID::get<SecondTestDialect>()) {} 503da51522SAlex Zinenko }; 513da51522SAlex Zinenko 523da51522SAlex Zinenko struct TestDialectInterfaceBase 533da51522SAlex Zinenko : public DialectInterface::Base<TestDialectInterfaceBase> { 545e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase) 555e50dd04SRiver Riddle 563da51522SAlex Zinenko TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} 573da51522SAlex Zinenko virtual int function() const { return 42; } 583da51522SAlex Zinenko }; 593da51522SAlex Zinenko 603da51522SAlex Zinenko struct TestDialectInterface : public TestDialectInterfaceBase { 615e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface) 625e50dd04SRiver Riddle 633da51522SAlex Zinenko using TestDialectInterfaceBase::TestDialectInterfaceBase; 643da51522SAlex Zinenko int function() const final { return 56; } 653da51522SAlex Zinenko }; 663da51522SAlex Zinenko 673da51522SAlex Zinenko struct SecondTestDialectInterface : public TestDialectInterfaceBase { 685e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface) 695e50dd04SRiver Riddle 703da51522SAlex Zinenko using TestDialectInterfaceBase::TestDialectInterfaceBase; 713da51522SAlex Zinenko int function() const final { return 78; } 723da51522SAlex Zinenko }; 733da51522SAlex Zinenko 743da51522SAlex Zinenko TEST(Dialect, DelayedInterfaceRegistration) { 753da51522SAlex Zinenko DialectRegistry registry; 763da51522SAlex Zinenko registry.insert<TestDialect, SecondTestDialect>(); 773da51522SAlex Zinenko 783da51522SAlex Zinenko // Delayed registration of an interface for TestDialect. 7977eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 8077eee579SRiver Riddle dialect->addInterfaces<TestDialectInterface>(); 8177eee579SRiver Riddle }); 823da51522SAlex Zinenko 832996a8d6SAlex Zinenko MLIRContext context(registry); 843da51522SAlex Zinenko 853da51522SAlex Zinenko // Load the TestDialect and check that the interface got registered for it. 8658e7bf78SRiver Riddle Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 873da51522SAlex Zinenko ASSERT_TRUE(testDialect != nullptr); 8858e7bf78SRiver Riddle auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 893da51522SAlex Zinenko EXPECT_TRUE(testDialectInterface != nullptr); 903da51522SAlex Zinenko 913da51522SAlex Zinenko // Load the SecondTestDialect and check that the interface is not registered 923da51522SAlex Zinenko // for it. 9358e7bf78SRiver Riddle Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>(); 943da51522SAlex Zinenko ASSERT_TRUE(secondTestDialect != nullptr); 953da51522SAlex Zinenko auto *secondTestDialectInterface = 9658e7bf78SRiver Riddle dyn_cast<SecondTestDialectInterface>(secondTestDialect); 973da51522SAlex Zinenko EXPECT_TRUE(secondTestDialectInterface == nullptr); 983da51522SAlex Zinenko 993da51522SAlex Zinenko // Use the same mechanism as for delayed registration but for an already 1003da51522SAlex Zinenko // loaded dialect and check that the interface is now registered. 1012996a8d6SAlex Zinenko DialectRegistry secondRegistry; 1022996a8d6SAlex Zinenko secondRegistry.insert<SecondTestDialect>(); 10377eee579SRiver Riddle secondRegistry.addExtension( 10477eee579SRiver Riddle +[](MLIRContext *ctx, SecondTestDialect *dialect) { 10577eee579SRiver Riddle dialect->addInterfaces<SecondTestDialectInterface>(); 10677eee579SRiver Riddle }); 1072996a8d6SAlex Zinenko context.appendDialectRegistry(secondRegistry); 1083da51522SAlex Zinenko secondTestDialectInterface = 10958e7bf78SRiver Riddle dyn_cast<SecondTestDialectInterface>(secondTestDialect); 1103da51522SAlex Zinenko EXPECT_TRUE(secondTestDialectInterface != nullptr); 1113da51522SAlex Zinenko } 1123da51522SAlex Zinenko 11334ea608aSAlex Zinenko TEST(Dialect, RepeatedDelayedRegistration) { 11434ea608aSAlex Zinenko // Set up the delayed registration. 11534ea608aSAlex Zinenko DialectRegistry registry; 11634ea608aSAlex Zinenko registry.insert<TestDialect>(); 11777eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 11877eee579SRiver Riddle dialect->addInterfaces<TestDialectInterface>(); 11977eee579SRiver Riddle }); 12034ea608aSAlex Zinenko MLIRContext context(registry); 12134ea608aSAlex Zinenko 12234ea608aSAlex Zinenko // Load the TestDialect and check that the interface got registered for it. 12358e7bf78SRiver Riddle Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 12434ea608aSAlex Zinenko ASSERT_TRUE(testDialect != nullptr); 12558e7bf78SRiver Riddle auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 12634ea608aSAlex Zinenko EXPECT_TRUE(testDialectInterface != nullptr); 12734ea608aSAlex Zinenko 12834ea608aSAlex Zinenko // Try adding the same dialect interface again and check that we don't crash 12934ea608aSAlex Zinenko // on repeated interface registration. 13034ea608aSAlex Zinenko DialectRegistry secondRegistry; 13134ea608aSAlex Zinenko secondRegistry.insert<TestDialect>(); 13277eee579SRiver Riddle secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 13377eee579SRiver Riddle dialect->addInterfaces<TestDialectInterface>(); 13477eee579SRiver Riddle }); 13534ea608aSAlex Zinenko context.appendDialectRegistry(secondRegistry); 13658e7bf78SRiver Riddle testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 13734ea608aSAlex Zinenko EXPECT_TRUE(testDialectInterface != nullptr); 13834ea608aSAlex Zinenko } 13934ea608aSAlex Zinenko 140012b148dSMatthias Springer namespace { 141012b148dSMatthias Springer /// A dummy extension that increases a counter when being applied and 142012b148dSMatthias Springer /// recursively adds additional extensions. 143012b148dSMatthias Springer struct DummyExtension : DialectExtension<DummyExtension, TestDialect> { 144*84cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension) 145*84cc1865SNikhil Kalra 146012b148dSMatthias Springer DummyExtension(int *counter, int numRecursive) 147012b148dSMatthias Springer : DialectExtension(), counter(counter), numRecursive(numRecursive) {} 148012b148dSMatthias Springer 149012b148dSMatthias Springer void apply(MLIRContext *ctx, TestDialect *dialect) const final { 150012b148dSMatthias Springer ++(*counter); 151012b148dSMatthias Springer DialectRegistry nestedRegistry; 152*84cc1865SNikhil Kalra for (int i = 0; i < numRecursive; ++i) { 153*84cc1865SNikhil Kalra // Create unique TypeIDs for these recursive extensions so they don't get 154*84cc1865SNikhil Kalra // de-duplicated. 155*84cc1865SNikhil Kalra auto extension = 156*84cc1865SNikhil Kalra std::make_unique<DummyExtension>(counter, /*numRecursive=*/0); 157*84cc1865SNikhil Kalra auto typeID = TypeID::getFromOpaquePointer(extension.get()); 158*84cc1865SNikhil Kalra nestedRegistry.addExtension(typeID, std::move(extension)); 159*84cc1865SNikhil Kalra } 160012b148dSMatthias Springer // Adding additional extensions may trigger a reallocation of the 161012b148dSMatthias Springer // `extensions` vector in the dialect registry. 162012b148dSMatthias Springer ctx->appendDialectRegistry(nestedRegistry); 163012b148dSMatthias Springer } 164012b148dSMatthias Springer 165012b148dSMatthias Springer private: 166012b148dSMatthias Springer int *counter; 167012b148dSMatthias Springer int numRecursive; 168012b148dSMatthias Springer }; 169012b148dSMatthias Springer } // namespace 170012b148dSMatthias Springer 171012b148dSMatthias Springer TEST(Dialect, NestedDialectExtension) { 172012b148dSMatthias Springer DialectRegistry registry; 173012b148dSMatthias Springer registry.insert<TestDialect>(); 174012b148dSMatthias Springer 175012b148dSMatthias Springer // Add an extension that adds 100 more extensions. 176012b148dSMatthias Springer int counter1 = 0; 177*84cc1865SNikhil Kalra registry.addExtension(TypeID::get<DummyExtension>(), 178*84cc1865SNikhil Kalra std::make_unique<DummyExtension>(&counter1, 100)); 179012b148dSMatthias Springer // Add one more extension. This should not crash. 180012b148dSMatthias Springer int counter2 = 0; 181*84cc1865SNikhil Kalra registry.addExtension(TypeID::getFromOpaquePointer(&counter2), 182*84cc1865SNikhil Kalra std::make_unique<DummyExtension>(&counter2, 0)); 183012b148dSMatthias Springer 184012b148dSMatthias Springer // Load dialect and apply extensions. 185012b148dSMatthias Springer MLIRContext context(registry); 186012b148dSMatthias Springer Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 187012b148dSMatthias Springer ASSERT_TRUE(testDialect != nullptr); 188012b148dSMatthias Springer 189*84cc1865SNikhil Kalra // Extensions are de-duplicated by typeID. Make sure that each expected 190012b148dSMatthias Springer // extension was applied at least once. 191012b148dSMatthias Springer EXPECT_GE(counter1, 101); 192012b148dSMatthias Springer EXPECT_GE(counter2, 1); 193012b148dSMatthias Springer } 194012b148dSMatthias Springer 195*84cc1865SNikhil Kalra TEST(Dialect, SubsetWithExtensions) { 196*84cc1865SNikhil Kalra DialectRegistry registry1, registry2; 197*84cc1865SNikhil Kalra registry1.insert<TestDialect>(); 198*84cc1865SNikhil Kalra registry2.insert<TestDialect>(); 199*84cc1865SNikhil Kalra 200*84cc1865SNikhil Kalra // Validate that the registries are equivalent. 201*84cc1865SNikhil Kalra ASSERT_TRUE(registry1.isSubsetOf(registry2)); 202*84cc1865SNikhil Kalra ASSERT_TRUE(registry2.isSubsetOf(registry1)); 203*84cc1865SNikhil Kalra 204*84cc1865SNikhil Kalra // Add extensions to registry2. 205*84cc1865SNikhil Kalra int counter = 0; 206*84cc1865SNikhil Kalra registry2.addExtension(TypeID::get<DummyExtension>(), 207*84cc1865SNikhil Kalra std::make_unique<DummyExtension>(&counter, 0)); 208*84cc1865SNikhil Kalra 209*84cc1865SNikhil Kalra // Expect that (1) is a subset of (2) but not the other way around. 210*84cc1865SNikhil Kalra ASSERT_TRUE(registry1.isSubsetOf(registry2)); 211*84cc1865SNikhil Kalra ASSERT_FALSE(registry2.isSubsetOf(registry1)); 212*84cc1865SNikhil Kalra 213*84cc1865SNikhil Kalra // Add extensions to registry1. 214*84cc1865SNikhil Kalra registry1.addExtension(TypeID::get<DummyExtension>(), 215*84cc1865SNikhil Kalra std::make_unique<DummyExtension>(&counter, 0)); 216*84cc1865SNikhil Kalra 217*84cc1865SNikhil Kalra // Expect that (1) and (2) are equivalent. 218*84cc1865SNikhil Kalra ASSERT_TRUE(registry1.isSubsetOf(registry2)); 219*84cc1865SNikhil Kalra ASSERT_TRUE(registry2.isSubsetOf(registry1)); 220*84cc1865SNikhil Kalra 221*84cc1865SNikhil Kalra // Load dialect and apply extensions. 222*84cc1865SNikhil Kalra MLIRContext context(registry1); 223*84cc1865SNikhil Kalra context.getOrLoadDialect<TestDialect>(); 224*84cc1865SNikhil Kalra context.appendDialectRegistry(registry2); 225*84cc1865SNikhil Kalra // Expect that the extension as only invoked once. 226*84cc1865SNikhil Kalra ASSERT_EQ(counter, 1); 227*84cc1865SNikhil Kalra } 228*84cc1865SNikhil Kalra 229be0a7e9fSMehdi Amini } // namespace 230