// xref: /llvm-project/mlir/unittests/IR/DialectTest.cpp (revision 2996a8d67553b9d469e01215b49bb1af17ad6d1e)
//===- DialectTest.cpp - Dialect unit tests -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/DialectInterface.h"
11 #include "gtest/gtest.h"
12 
13 using namespace mlir;
14 using namespace mlir::detail;
15 
16 namespace {
17 struct TestDialect : public Dialect {
18   static StringRef getDialectNamespace() { return "test"; };
19   TestDialect(MLIRContext *context)
20       : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
21 };
22 struct AnotherTestDialect : public Dialect {
23   static StringRef getDialectNamespace() { return "test"; };
24   AnotherTestDialect(MLIRContext *context)
25       : Dialect(getDialectNamespace(), context,
26                 TypeID::get<AnotherTestDialect>()) {}
27 };
28 
29 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
30   MLIRContext context;
31 
32   // Registering a dialect with the same namespace twice should result in a
33   // failure.
34   context.loadDialect<TestDialect>();
35   ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
36 }
37 
38 struct SecondTestDialect : public Dialect {
39   static StringRef getDialectNamespace() { return "test2"; }
40   SecondTestDialect(MLIRContext *context)
41       : Dialect(getDialectNamespace(), context,
42                 TypeID::get<SecondTestDialect>()) {}
43 };
44 
45 struct TestDialectInterfaceBase
46     : public DialectInterface::Base<TestDialectInterfaceBase> {
47   TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
48   virtual int function() const { return 42; }
49 };
50 
51 struct TestDialectInterface : public TestDialectInterfaceBase {
52   using TestDialectInterfaceBase::TestDialectInterfaceBase;
53   int function() const final { return 56; }
54 };
55 
56 struct SecondTestDialectInterface : public TestDialectInterfaceBase {
57   using TestDialectInterfaceBase::TestDialectInterfaceBase;
58   int function() const final { return 78; }
59 };
60 
61 TEST(Dialect, DelayedInterfaceRegistration) {
62   DialectRegistry registry;
63   registry.insert<TestDialect, SecondTestDialect>();
64 
65   // Delayed registration of an interface for TestDialect.
66   registry.addDialectInterface<TestDialect, TestDialectInterface>();
67 
68   MLIRContext context(registry);
69 
70   // Load the TestDialect and check that the interface got registered for it.
71   auto *testDialect = context.getOrLoadDialect<TestDialect>();
72   ASSERT_TRUE(testDialect != nullptr);
73   auto *testDialectInterface =
74       testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
75   EXPECT_TRUE(testDialectInterface != nullptr);
76 
77   // Load the SecondTestDialect and check that the interface is not registered
78   // for it.
79   auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
80   ASSERT_TRUE(secondTestDialect != nullptr);
81   auto *secondTestDialectInterface =
82       secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
83   EXPECT_TRUE(secondTestDialectInterface == nullptr);
84 
85   // Use the same mechanism as for delayed registration but for an already
86   // loaded dialect and check that the interface is now registered.
87   DialectRegistry secondRegistry;
88   secondRegistry.insert<SecondTestDialect>();
89   secondRegistry
90       .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
91   context.appendDialectRegistry(secondRegistry);
92   secondTestDialectInterface =
93       secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
94   EXPECT_TRUE(secondTestDialectInterface != nullptr);
95 }
96 
97 } // end namespace
98