xref: /llvm-project/mlir/unittests/IR/InterfaceAttachmentTest.cpp (revision f8479d9de59d0f2f3997319b0ec189eb086aa85a)
1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This implements the tests for attaching interfaces to attributes, types and
// operations without having to specify them on the attribute, type or
// operation class directly.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/BuiltinDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "gtest/gtest.h"
19 
20 #include "../../test/lib/Dialect/Test/TestAttributes.h"
21 #include "../../test/lib/Dialect/Test/TestDialect.h"
22 #include "../../test/lib/Dialect/Test/TestTypes.h"
23 
24 using namespace mlir;
25 using namespace mlir::test;
26 
27 namespace {
28 
29 /// External interface model for the integer type. Only provides non-default
30 /// methods.
31 struct Model
32     : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
33   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
34     return type.getIntOrFloatBitWidth() + arg;
35   }
36 
37   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
38 };
39 
40 /// External interface model for the float type. Provides non-deafult and
41 /// overrides default methods.
42 struct OverridingModel
43     : public TestExternalTypeInterface::ExternalModel<OverridingModel,
44                                                       FloatType> {
45   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
46     return type.getIntOrFloatBitWidth() + arg;
47   }
48 
49   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
50 
51   unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
52     return 128;
53   }
54 
55   static unsigned staticGetArgument(unsigned arg) { return 420; }
56 };
57 
58 TEST(InterfaceAttachment, Type) {
59   MLIRContext context;
60 
61   // Check that the type has no interface.
62   IntegerType i8 = IntegerType::get(&context, 8);
63   ASSERT_FALSE(i8.isa<TestExternalTypeInterface>());
64 
65   // Attach an interface and check that the type now has the interface.
66   IntegerType::attachInterface<Model>(context);
67   TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>();
68   ASSERT_TRUE(iface != nullptr);
69   EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
70   EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
71   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
72   EXPECT_EQ(iface.staticGetArgument(17), 17u);
73 
74   // Same, but with the default implementation overridden.
75   FloatType flt = Float32Type::get(&context);
76   ASSERT_FALSE(flt.isa<TestExternalTypeInterface>());
77   Float32Type::attachInterface<OverridingModel>(context);
78   iface = flt.dyn_cast<TestExternalTypeInterface>();
79   ASSERT_TRUE(iface != nullptr);
80   EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
81   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
82   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
83   EXPECT_EQ(iface.staticGetArgument(17), 420u);
84 
85   // Other contexts shouldn't have the attribute attached.
86   MLIRContext other;
87   IntegerType i8other = IntegerType::get(&other, 8);
88   EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
89 }
90 
91 /// External interface model for the test type from the test dialect.
92 struct TestTypeModel
93     : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
94                                                       test::TestType> {
95   unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
96 
97   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
98 };
99 
100 TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
101   // Put the interface in the registry.
102   DialectRegistry registry;
103   registry.insert<test::TestDialect>();
104   registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
105 
106   // Check that when a context is constructed with the given registry, the type
107   // interface gets registered.
108   MLIRContext context(registry);
109   context.loadDialect<test::TestDialect>();
110   test::TestType testType = test::TestType::get(&context);
111   auto iface = testType.dyn_cast<TestExternalTypeInterface>();
112   ASSERT_TRUE(iface != nullptr);
113   EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
114   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
115 }
116 
117 TEST(InterfaceAttachment, TypeDelayedContextAppend) {
118   // Put the interface in the registry.
119   DialectRegistry registry;
120   registry.insert<test::TestDialect>();
121   registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
122 
123   // Check that when the registry gets appended to the context, the interface
124   // becomes available for objects in loaded dialects.
125   MLIRContext context;
126   context.loadDialect<test::TestDialect>();
127   test::TestType testType = test::TestType::get(&context);
128   EXPECT_FALSE(testType.isa<TestExternalTypeInterface>());
129   context.appendDialectRegistry(registry);
130   EXPECT_TRUE(testType.isa<TestExternalTypeInterface>());
131 }
132 
133 TEST(InterfaceAttachment, RepeatedRegistration) {
134   DialectRegistry registry;
135   registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
136   MLIRContext context(registry);
137 
138   // Should't fail on repeated registration through the dialect registry.
139   context.appendDialectRegistry(registry);
140 }
141 
142 TEST(InterfaceAttachment, TypeBuiltinDelayed) {
143   // Builtin dialect needs to registration or loading, but delayed interface
144   // registration must still work.
145   DialectRegistry registry;
146   registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
147 
148   MLIRContext context(registry);
149   IntegerType i16 = IntegerType::get(&context, 16);
150   EXPECT_TRUE(i16.isa<TestExternalTypeInterface>());
151 
152   MLIRContext initiallyEmpty;
153   IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
154   EXPECT_FALSE(i32.isa<TestExternalTypeInterface>());
155   initiallyEmpty.appendDialectRegistry(registry);
156   EXPECT_TRUE(i32.isa<TestExternalTypeInterface>());
157 }
158 
159 /// The interface provides a default implementation that expects
160 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
161 /// just derives from the ExternalModel.
162 struct TestExternalFallbackTypeIntegerModel
163     : public TestExternalFallbackTypeInterface::ExternalModel<
164           TestExternalFallbackTypeIntegerModel, IntegerType> {};
165 
166 /// The interface provides a default implementation that expects
167 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
168 /// FallbackModel instead to override this and make sure the code still compiles
169 /// because we never instantiate the ExternalModel class template with a
170 /// template argument that would have led to compilation failures.
171 struct TestExternalFallbackTypeVectorModel
172     : public TestExternalFallbackTypeInterface::FallbackModel<
173           TestExternalFallbackTypeVectorModel> {
174   unsigned getBitwidth(Type type) const {
175     IntegerType elementType = type.cast<VectorType>()
176                                   .getElementType()
177                                   .dyn_cast_or_null<IntegerType>();
178     return elementType ? elementType.getWidth() : 0;
179   }
180 };
181 
182 TEST(InterfaceAttachment, Fallback) {
183   MLIRContext context;
184 
185   // Just check that we can attach the interface.
186   IntegerType i8 = IntegerType::get(&context, 8);
187   ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>());
188   IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
189   ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>());
190 
191   // Call the method so it is guaranteed not to be instantiated.
192   VectorType vec = VectorType::get({42}, i8);
193   ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>());
194   VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
195   ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>());
196   EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u);
197 }
198 
199 /// External model for attribute interfaces.
200 struct TestExternalIntegerAttrModel
201     : public TestExternalAttrInterface::ExternalModel<
202           TestExternalIntegerAttrModel, IntegerAttr> {
203   const Dialect *getDialectPtr(Attribute attr) const {
204     return &attr.cast<IntegerAttr>().getDialect();
205   }
206 
207   static int getSomeNumber() { return 42; }
208 };
209 
210 TEST(InterfaceAttachment, Attribute) {
211   MLIRContext context;
212 
213   // Attribute interfaces use the exact same mechanism as types, so just check
214   // that the basics work for attributes.
215   IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
216   ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
217   IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
218   auto iface = attr.dyn_cast<TestExternalAttrInterface>();
219   ASSERT_TRUE(iface != nullptr);
220   EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
221   EXPECT_EQ(iface.getSomeNumber(), 42);
222 }
223 
224 /// External model for an interface attachable to a non-builtin attribute.
225 struct TestExternalSimpleAAttrModel
226     : public TestExternalAttrInterface::ExternalModel<
227           TestExternalSimpleAAttrModel, test::SimpleAAttr> {
228   const Dialect *getDialectPtr(Attribute attr) const {
229     return &attr.getDialect();
230   }
231 
232   static int getSomeNumber() { return 21; }
233 };
234 
235 TEST(InterfaceAttachmentTest, AttributeDelayed) {
236   // Attribute interfaces use the exact same mechanism as types, so just check
237   // that the delayed registration work for attributes.
238   DialectRegistry registry;
239   registry.insert<test::TestDialect>();
240   registry.addAttrInterface<test::TestDialect, test::SimpleAAttr,
241                             TestExternalSimpleAAttrModel>();
242 
243   MLIRContext context(registry);
244   context.loadDialect<test::TestDialect>();
245   auto attr = test::SimpleAAttr::get(&context);
246   EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
247 
248   MLIRContext initiallyEmpty;
249   initiallyEmpty.loadDialect<test::TestDialect>();
250   attr = test::SimpleAAttr::get(&initiallyEmpty);
251   EXPECT_FALSE(attr.isa<TestExternalAttrInterface>());
252   initiallyEmpty.appendDialectRegistry(registry);
253   EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
254 }
255 
256 /// External interface model for the module operation. Only provides non-default
257 /// methods.
258 struct TestExternalOpModel
259     : public TestExternalOpInterface::ExternalModel<TestExternalOpModel,
260                                                     ModuleOp> {
261   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
262     return op->getName().getStringRef().size() + arg;
263   }
264 
265   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
266     return ModuleOp::getOperationName().size() + 2 * arg;
267   }
268 };
269 
270 /// External interface model for the func operation. Provides non-deafult and
271 /// overrides default methods.
272 struct TestExternalOpOverridingModel
273     : public TestExternalOpInterface::FallbackModel<
274           TestExternalOpOverridingModel> {
275   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
276     return op->getName().getStringRef().size() + arg;
277   }
278 
279   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
280     return FuncOp::getOperationName().size() + 2 * arg;
281   }
282 
283   unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const {
284     return 42;
285   }
286 
287   static unsigned getNameLengthMinusArg(unsigned arg) { return 21; }
288 };
289 
290 TEST(InterfaceAttachment, Operation) {
291   MLIRContext context;
292 
293   // Initially, the operation doesn't have the interface.
294   auto moduleOp = ModuleOp::create(UnknownLoc::get(&context));
295   ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp.getOperation()));
296 
297   // We can attach an external interface and now the operaiton has it.
298   ModuleOp::attachInterface<TestExternalOpModel>(context);
299   auto iface = dyn_cast<TestExternalOpInterface>(moduleOp.getOperation());
300   ASSERT_TRUE(iface != nullptr);
301   EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u);
302   EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u);
303   EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u);
304   EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u);
305 
306   // Default implementation can be overridden.
307   auto funcOp = FuncOp::create(UnknownLoc::get(&context), "function",
308                                FunctionType::get(&context, {}, {}));
309   ASSERT_FALSE(isa<TestExternalOpInterface>(funcOp.getOperation()));
310   FuncOp::attachInterface<TestExternalOpOverridingModel>(context);
311   iface = dyn_cast<TestExternalOpInterface>(funcOp.getOperation());
312   ASSERT_TRUE(iface != nullptr);
313   EXPECT_EQ(iface.getNameLengthPlusArg(10), 22u);
314   EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u);
315   EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 28u);
316   EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u);
317 
318   // Another context doesn't have the interfaces registered.
319   MLIRContext other;
320   auto otherModuleOp = ModuleOp::create(UnknownLoc::get(&other));
321   ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp.getOperation()));
322 }
323 
324 struct TestExternalTestOpModel
325     : public TestExternalOpInterface::ExternalModel<TestExternalTestOpModel,
326                                                     test::OpJ> {
327   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
328     return op->getName().getStringRef().size() + arg;
329   }
330 
331   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
332     return test::OpJ::getOperationName().size() + 2 * arg;
333   }
334 };
335 
336 TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
337   DialectRegistry registry;
338   registry.insert<test::TestDialect>();
339   registry.addOpInterface<ModuleOp, TestExternalOpModel>();
340   registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
341 
342   // Construct the context directly from a registry. The interfaces are expected
343   // to be readily available on operations.
344   MLIRContext context(registry);
345   context.loadDialect<test::TestDialect>();
346   ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
347   OpBuilder builder(module);
348   auto op =
349       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
350   EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
351   EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
352 }
353 
354 TEST(InterfaceAttachment, OperationDelayedContextAppend) {
355   DialectRegistry registry;
356   registry.insert<test::TestDialect>();
357   registry.addOpInterface<ModuleOp, TestExternalOpModel>();
358   registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
359 
360   // Construct the context, create ops, and only then append the registry. The
361   // interfaces are expected to be available after appending the registry.
362   MLIRContext context;
363   context.loadDialect<test::TestDialect>();
364   ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
365   OpBuilder builder(module);
366   auto op =
367       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
368   EXPECT_FALSE(isa<TestExternalOpInterface>(module.getOperation()));
369   EXPECT_FALSE(isa<TestExternalOpInterface>(op.getOperation()));
370   context.appendDialectRegistry(registry);
371   EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
372   EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
373 }
374 
375 } // end namespace
376