//===- DialectTest.cpp - Dialect unit tests -------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectInterface.h" #include "gtest/gtest.h" using namespace mlir; using namespace mlir::detail; namespace { struct TestDialect : public Dialect { static StringRef getDialectNamespace() { return "test"; }; TestDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context, TypeID::get()) {} }; struct AnotherTestDialect : public Dialect { static StringRef getDialectNamespace() { return "test"; }; AnotherTestDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context, TypeID::get()) {} }; TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { MLIRContext context; // Registering a dialect with the same namespace twice should result in a // failure. context.loadDialect(); ASSERT_DEATH(context.loadDialect(), ""); } struct SecondTestDialect : public Dialect { static StringRef getDialectNamespace() { return "test2"; } SecondTestDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context, TypeID::get()) {} }; struct TestDialectInterfaceBase : public DialectInterface::Base { TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} virtual int function() const { return 42; } }; struct TestDialectInterface : public TestDialectInterfaceBase { using TestDialectInterfaceBase::TestDialectInterfaceBase; int function() const final { return 56; } }; struct SecondTestDialectInterface : public TestDialectInterfaceBase { using TestDialectInterfaceBase::TestDialectInterfaceBase; int function() const final { return 78; } }; TEST(Dialect, DelayedInterfaceRegistration) { DialectRegistry registry; registry.insert(); // Delayed registration of an interface for TestDialect. registry.addDialectInterface(); MLIRContext context(registry); // Load the TestDialect and check that the interface got registered for it. auto *testDialect = context.getOrLoadDialect(); ASSERT_TRUE(testDialect != nullptr); auto *testDialectInterface = testDialect->getRegisteredInterface(); EXPECT_TRUE(testDialectInterface != nullptr); // Load the SecondTestDialect and check that the interface is not registered // for it. auto *secondTestDialect = context.getOrLoadDialect(); ASSERT_TRUE(secondTestDialect != nullptr); auto *secondTestDialectInterface = secondTestDialect->getRegisteredInterface(); EXPECT_TRUE(secondTestDialectInterface == nullptr); // Use the same mechanism as for delayed registration but for an already // loaded dialect and check that the interface is now registered. DialectRegistry secondRegistry; secondRegistry.insert(); secondRegistry .addDialectInterface(); context.appendDialectRegistry(secondRegistry); secondTestDialectInterface = secondTestDialect->getRegisteredInterface(); EXPECT_TRUE(secondTestDialectInterface != nullptr); } } // end namespace