xref: /llvm-project/mlir/unittests/IR/DialectTest.cpp (revision 77eee5795e2cf753e4400fb089d01018417c4ee0)
1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/DialectInterface.h"
11 #include "gtest/gtest.h"
12 
13 using namespace mlir;
14 using namespace mlir::detail;
15 
16 namespace {
17 struct TestDialect : public Dialect {
18   static StringRef getDialectNamespace() { return "test"; };
19   TestDialect(MLIRContext *context)
20       : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
21 };
22 struct AnotherTestDialect : public Dialect {
23   static StringRef getDialectNamespace() { return "test"; };
24   AnotherTestDialect(MLIRContext *context)
25       : Dialect(getDialectNamespace(), context,
26                 TypeID::get<AnotherTestDialect>()) {}
27 };
28 
29 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
30   MLIRContext context;
31 
32   // Registering a dialect with the same namespace twice should result in a
33   // failure.
34   context.loadDialect<TestDialect>();
35   ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
36 }
37 
38 struct SecondTestDialect : public Dialect {
39   static StringRef getDialectNamespace() { return "test2"; }
40   SecondTestDialect(MLIRContext *context)
41       : Dialect(getDialectNamespace(), context,
42                 TypeID::get<SecondTestDialect>()) {}
43 };
44 
45 struct TestDialectInterfaceBase
46     : public DialectInterface::Base<TestDialectInterfaceBase> {
47   TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
48   virtual int function() const { return 42; }
49 };
50 
51 struct TestDialectInterface : public TestDialectInterfaceBase {
52   using TestDialectInterfaceBase::TestDialectInterfaceBase;
53   int function() const final { return 56; }
54 };
55 
56 struct SecondTestDialectInterface : public TestDialectInterfaceBase {
57   using TestDialectInterfaceBase::TestDialectInterfaceBase;
58   int function() const final { return 78; }
59 };
60 
61 TEST(Dialect, DelayedInterfaceRegistration) {
62   DialectRegistry registry;
63   registry.insert<TestDialect, SecondTestDialect>();
64 
65   // Delayed registration of an interface for TestDialect.
66   registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
67     dialect->addInterfaces<TestDialectInterface>();
68   });
69 
70   MLIRContext context(registry);
71 
72   // Load the TestDialect and check that the interface got registered for it.
73   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
74   ASSERT_TRUE(testDialect != nullptr);
75   auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
76   EXPECT_TRUE(testDialectInterface != nullptr);
77 
78   // Load the SecondTestDialect and check that the interface is not registered
79   // for it.
80   Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
81   ASSERT_TRUE(secondTestDialect != nullptr);
82   auto *secondTestDialectInterface =
83       dyn_cast<SecondTestDialectInterface>(secondTestDialect);
84   EXPECT_TRUE(secondTestDialectInterface == nullptr);
85 
86   // Use the same mechanism as for delayed registration but for an already
87   // loaded dialect and check that the interface is now registered.
88   DialectRegistry secondRegistry;
89   secondRegistry.insert<SecondTestDialect>();
90   secondRegistry.addExtension(
91       +[](MLIRContext *ctx, SecondTestDialect *dialect) {
92         dialect->addInterfaces<SecondTestDialectInterface>();
93       });
94   context.appendDialectRegistry(secondRegistry);
95   secondTestDialectInterface =
96       dyn_cast<SecondTestDialectInterface>(secondTestDialect);
97   EXPECT_TRUE(secondTestDialectInterface != nullptr);
98 }
99 
100 TEST(Dialect, RepeatedDelayedRegistration) {
101   // Set up the delayed registration.
102   DialectRegistry registry;
103   registry.insert<TestDialect>();
104   registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
105     dialect->addInterfaces<TestDialectInterface>();
106   });
107   MLIRContext context(registry);
108 
109   // Load the TestDialect and check that the interface got registered for it.
110   Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
111   ASSERT_TRUE(testDialect != nullptr);
112   auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
113   EXPECT_TRUE(testDialectInterface != nullptr);
114 
115   // Try adding the same dialect interface again and check that we don't crash
116   // on repeated interface registration.
117   DialectRegistry secondRegistry;
118   secondRegistry.insert<TestDialect>();
119   secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
120     dialect->addInterfaces<TestDialectInterface>();
121   });
122   context.appendDialectRegistry(secondRegistry);
123   testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
124   EXPECT_TRUE(testDialectInterface != nullptr);
125 }
126 
127 } // namespace
128