1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/DialectInterface.h" 11 #include "mlir/Support/TypeID.h" 12 #include "gtest/gtest.h" 13 14 using namespace mlir; 15 using namespace mlir::detail; 16 17 namespace { 18 struct TestDialect : public Dialect { 19 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect) 20 21 static StringRef getDialectNamespace() { return "test"; }; 22 TestDialect(MLIRContext *context) 23 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {} 24 }; 25 struct AnotherTestDialect : public Dialect { 26 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect) 27 28 static StringRef getDialectNamespace() { return "test"; }; 29 AnotherTestDialect(MLIRContext *context) 30 : Dialect(getDialectNamespace(), context, 31 TypeID::get<AnotherTestDialect>()) {} 32 }; 33 34 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { 35 MLIRContext context; 36 37 // Registering a dialect with the same namespace twice should result in a 38 // failure. 39 context.loadDialect<TestDialect>(); 40 ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), ""); 41 } 42 43 struct SecondTestDialect : public Dialect { 44 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect) 45 46 static StringRef getDialectNamespace() { return "test2"; } 47 SecondTestDialect(MLIRContext *context) 48 : Dialect(getDialectNamespace(), context, 49 TypeID::get<SecondTestDialect>()) {} 50 }; 51 52 struct TestDialectInterfaceBase 53 : public DialectInterface::Base<TestDialectInterfaceBase> { 54 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase) 55 56 TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} 57 virtual int function() const { return 42; } 58 }; 59 60 struct TestDialectInterface : public TestDialectInterfaceBase { 61 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface) 62 63 using TestDialectInterfaceBase::TestDialectInterfaceBase; 64 int function() const final { return 56; } 65 }; 66 67 struct SecondTestDialectInterface : public TestDialectInterfaceBase { 68 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface) 69 70 using TestDialectInterfaceBase::TestDialectInterfaceBase; 71 int function() const final { return 78; } 72 }; 73 74 TEST(Dialect, DelayedInterfaceRegistration) { 75 DialectRegistry registry; 76 registry.insert<TestDialect, SecondTestDialect>(); 77 78 // Delayed registration of an interface for TestDialect. 79 registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 80 dialect->addInterfaces<TestDialectInterface>(); 81 }); 82 83 MLIRContext context(registry); 84 85 // Load the TestDialect and check that the interface got registered for it. 86 Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 87 ASSERT_TRUE(testDialect != nullptr); 88 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 89 EXPECT_TRUE(testDialectInterface != nullptr); 90 91 // Load the SecondTestDialect and check that the interface is not registered 92 // for it. 93 Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>(); 94 ASSERT_TRUE(secondTestDialect != nullptr); 95 auto *secondTestDialectInterface = 96 dyn_cast<SecondTestDialectInterface>(secondTestDialect); 97 EXPECT_TRUE(secondTestDialectInterface == nullptr); 98 99 // Use the same mechanism as for delayed registration but for an already 100 // loaded dialect and check that the interface is now registered. 101 DialectRegistry secondRegistry; 102 secondRegistry.insert<SecondTestDialect>(); 103 secondRegistry.addExtension( 104 +[](MLIRContext *ctx, SecondTestDialect *dialect) { 105 dialect->addInterfaces<SecondTestDialectInterface>(); 106 }); 107 context.appendDialectRegistry(secondRegistry); 108 secondTestDialectInterface = 109 dyn_cast<SecondTestDialectInterface>(secondTestDialect); 110 EXPECT_TRUE(secondTestDialectInterface != nullptr); 111 } 112 113 TEST(Dialect, RepeatedDelayedRegistration) { 114 // Set up the delayed registration. 115 DialectRegistry registry; 116 registry.insert<TestDialect>(); 117 registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 118 dialect->addInterfaces<TestDialectInterface>(); 119 }); 120 MLIRContext context(registry); 121 122 // Load the TestDialect and check that the interface got registered for it. 123 Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 124 ASSERT_TRUE(testDialect != nullptr); 125 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 126 EXPECT_TRUE(testDialectInterface != nullptr); 127 128 // Try adding the same dialect interface again and check that we don't crash 129 // on repeated interface registration. 130 DialectRegistry secondRegistry; 131 secondRegistry.insert<TestDialect>(); 132 secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { 133 dialect->addInterfaces<TestDialectInterface>(); 134 }); 135 context.appendDialectRegistry(secondRegistry); 136 testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect); 137 EXPECT_TRUE(testDialectInterface != nullptr); 138 } 139 140 namespace { 141 /// A dummy extension that increases a counter when being applied and 142 /// recursively adds additional extensions. 143 struct DummyExtension : DialectExtension<DummyExtension, TestDialect> { 144 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension) 145 146 DummyExtension(int *counter, int numRecursive) 147 : DialectExtension(), counter(counter), numRecursive(numRecursive) {} 148 149 void apply(MLIRContext *ctx, TestDialect *dialect) const final { 150 ++(*counter); 151 DialectRegistry nestedRegistry; 152 for (int i = 0; i < numRecursive; ++i) { 153 // Create unique TypeIDs for these recursive extensions so they don't get 154 // de-duplicated. 155 auto extension = 156 std::make_unique<DummyExtension>(counter, /*numRecursive=*/0); 157 auto typeID = TypeID::getFromOpaquePointer(extension.get()); 158 nestedRegistry.addExtension(typeID, std::move(extension)); 159 } 160 // Adding additional extensions may trigger a reallocation of the 161 // `extensions` vector in the dialect registry. 162 ctx->appendDialectRegistry(nestedRegistry); 163 } 164 165 private: 166 int *counter; 167 int numRecursive; 168 }; 169 } // namespace 170 171 TEST(Dialect, NestedDialectExtension) { 172 DialectRegistry registry; 173 registry.insert<TestDialect>(); 174 175 // Add an extension that adds 100 more extensions. 176 int counter1 = 0; 177 registry.addExtension(TypeID::get<DummyExtension>(), 178 std::make_unique<DummyExtension>(&counter1, 100)); 179 // Add one more extension. This should not crash. 180 int counter2 = 0; 181 registry.addExtension(TypeID::getFromOpaquePointer(&counter2), 182 std::make_unique<DummyExtension>(&counter2, 0)); 183 184 // Load dialect and apply extensions. 185 MLIRContext context(registry); 186 Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); 187 ASSERT_TRUE(testDialect != nullptr); 188 189 // Extensions are de-duplicated by typeID. Make sure that each expected 190 // extension was applied at least once. 191 EXPECT_GE(counter1, 101); 192 EXPECT_GE(counter2, 1); 193 } 194 195 TEST(Dialect, SubsetWithExtensions) { 196 DialectRegistry registry1, registry2; 197 registry1.insert<TestDialect>(); 198 registry2.insert<TestDialect>(); 199 200 // Validate that the registries are equivalent. 201 ASSERT_TRUE(registry1.isSubsetOf(registry2)); 202 ASSERT_TRUE(registry2.isSubsetOf(registry1)); 203 204 // Add extensions to registry2. 205 int counter = 0; 206 registry2.addExtension(TypeID::get<DummyExtension>(), 207 std::make_unique<DummyExtension>(&counter, 0)); 208 209 // Expect that (1) is a subset of (2) but not the other way around. 210 ASSERT_TRUE(registry1.isSubsetOf(registry2)); 211 ASSERT_FALSE(registry2.isSubsetOf(registry1)); 212 213 // Add extensions to registry1. 214 registry1.addExtension(TypeID::get<DummyExtension>(), 215 std::make_unique<DummyExtension>(&counter, 0)); 216 217 // Expect that (1) and (2) are equivalent. 218 ASSERT_TRUE(registry1.isSubsetOf(registry2)); 219 ASSERT_TRUE(registry2.isSubsetOf(registry1)); 220 221 // Load dialect and apply extensions. 222 MLIRContext context(registry1); 223 context.getOrLoadDialect<TestDialect>(); 224 context.appendDialectRegistry(registry2); 225 // Expect that the extension as only invoked once. 226 ASSERT_EQ(counter, 1); 227 } 228 229 } // namespace 230