xref: /llvm-project/mlir/unittests/IR/InterfaceAttachmentTest.cpp (revision db79f4a2e9c97039de8fd86012c06c2681cf9ad9)
1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This implements the tests for attaching interfaces to attributes and types
10 // without having to specify them on the attribute or type class directly.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/BuiltinDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "gtest/gtest.h"
19 
20 #include "../../test/lib/Dialect/Test/TestAttributes.h"
21 #include "../../test/lib/Dialect/Test/TestDialect.h"
22 #include "../../test/lib/Dialect/Test/TestTypes.h"
23 #include "mlir/IR/OwningOpRef.h"
24 
25 using namespace mlir;
26 using namespace test;
27 
28 namespace {
29 
30 /// External interface model for the integer type. Only provides non-default
31 /// methods.
32 struct Model
33     : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
34   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
35     return type.getIntOrFloatBitWidth() + arg;
36   }
37 
38   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
39 };
40 
41 /// External interface model for the float type. Provides non-deafult and
42 /// overrides default methods.
43 struct OverridingModel
44     : public TestExternalTypeInterface::ExternalModel<OverridingModel,
45                                                       FloatType> {
46   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
47     return type.getIntOrFloatBitWidth() + arg;
48   }
49 
50   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
51 
52   unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
53     return 128;
54   }
55 
56   static unsigned staticGetArgument(unsigned arg) { return 420; }
57 };
58 
59 TEST(InterfaceAttachment, Type) {
60   MLIRContext context;
61 
62   // Check that the type has no interface.
63   IntegerType i8 = IntegerType::get(&context, 8);
64   ASSERT_FALSE(i8.isa<TestExternalTypeInterface>());
65 
66   // Attach an interface and check that the type now has the interface.
67   IntegerType::attachInterface<Model>(context);
68   TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>();
69   ASSERT_TRUE(iface != nullptr);
70   EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
71   EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
72   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
73   EXPECT_EQ(iface.staticGetArgument(17), 17u);
74 
75   // Same, but with the default implementation overridden.
76   FloatType flt = Float32Type::get(&context);
77   ASSERT_FALSE(flt.isa<TestExternalTypeInterface>());
78   Float32Type::attachInterface<OverridingModel>(context);
79   iface = flt.dyn_cast<TestExternalTypeInterface>();
80   ASSERT_TRUE(iface != nullptr);
81   EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
82   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
83   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
84   EXPECT_EQ(iface.staticGetArgument(17), 420u);
85 
86   // Other contexts shouldn't have the attribute attached.
87   MLIRContext other;
88   IntegerType i8other = IntegerType::get(&other, 8);
89   EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
90 }
91 
92 /// External interface model for the test type from the test dialect.
93 struct TestTypeModel
94     : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
95                                                       test::TestType> {
96   unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
97 
98   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
99 };
100 
101 TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
102   // Put the interface in the registry.
103   DialectRegistry registry;
104   registry.insert<test::TestDialect>();
105   registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
106 
107   // Check that when a context is constructed with the given registry, the type
108   // interface gets registered.
109   MLIRContext context(registry);
110   context.loadDialect<test::TestDialect>();
111   test::TestType testType = test::TestType::get(&context);
112   auto iface = testType.dyn_cast<TestExternalTypeInterface>();
113   ASSERT_TRUE(iface != nullptr);
114   EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
115   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
116 }
117 
118 TEST(InterfaceAttachment, TypeDelayedContextAppend) {
119   // Put the interface in the registry.
120   DialectRegistry registry;
121   registry.insert<test::TestDialect>();
122   registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
123 
124   // Check that when the registry gets appended to the context, the interface
125   // becomes available for objects in loaded dialects.
126   MLIRContext context;
127   context.loadDialect<test::TestDialect>();
128   test::TestType testType = test::TestType::get(&context);
129   EXPECT_FALSE(testType.isa<TestExternalTypeInterface>());
130   context.appendDialectRegistry(registry);
131   EXPECT_TRUE(testType.isa<TestExternalTypeInterface>());
132 }
133 
134 TEST(InterfaceAttachment, RepeatedRegistration) {
135   DialectRegistry registry;
136   registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
137   MLIRContext context(registry);
138 
139   // Should't fail on repeated registration through the dialect registry.
140   context.appendDialectRegistry(registry);
141 }
142 
143 TEST(InterfaceAttachment, TypeBuiltinDelayed) {
144   // Builtin dialect needs to registration or loading, but delayed interface
145   // registration must still work.
146   DialectRegistry registry;
147   registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
148 
149   MLIRContext context(registry);
150   IntegerType i16 = IntegerType::get(&context, 16);
151   EXPECT_TRUE(i16.isa<TestExternalTypeInterface>());
152 
153   MLIRContext initiallyEmpty;
154   IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
155   EXPECT_FALSE(i32.isa<TestExternalTypeInterface>());
156   initiallyEmpty.appendDialectRegistry(registry);
157   EXPECT_TRUE(i32.isa<TestExternalTypeInterface>());
158 }
159 
160 /// The interface provides a default implementation that expects
161 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
162 /// just derives from the ExternalModel.
163 struct TestExternalFallbackTypeIntegerModel
164     : public TestExternalFallbackTypeInterface::ExternalModel<
165           TestExternalFallbackTypeIntegerModel, IntegerType> {};
166 
167 /// The interface provides a default implementation that expects
168 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
169 /// FallbackModel instead to override this and make sure the code still compiles
170 /// because we never instantiate the ExternalModel class template with a
171 /// template argument that would have led to compilation failures.
172 struct TestExternalFallbackTypeVectorModel
173     : public TestExternalFallbackTypeInterface::FallbackModel<
174           TestExternalFallbackTypeVectorModel> {
175   unsigned getBitwidth(Type type) const {
176     IntegerType elementType = type.cast<VectorType>()
177                                   .getElementType()
178                                   .dyn_cast_or_null<IntegerType>();
179     return elementType ? elementType.getWidth() : 0;
180   }
181 };
182 
183 TEST(InterfaceAttachment, Fallback) {
184   MLIRContext context;
185 
186   // Just check that we can attach the interface.
187   IntegerType i8 = IntegerType::get(&context, 8);
188   ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>());
189   IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
190   ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>());
191 
192   // Call the method so it is guaranteed not to be instantiated.
193   VectorType vec = VectorType::get({42}, i8);
194   ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>());
195   VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
196   ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>());
197   EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u);
198 }
199 
200 /// External model for attribute interfaces.
201 struct TestExternalIntegerAttrModel
202     : public TestExternalAttrInterface::ExternalModel<
203           TestExternalIntegerAttrModel, IntegerAttr> {
204   const Dialect *getDialectPtr(Attribute attr) const {
205     return &attr.cast<IntegerAttr>().getDialect();
206   }
207 
208   static int getSomeNumber() { return 42; }
209 };
210 
211 TEST(InterfaceAttachment, Attribute) {
212   MLIRContext context;
213 
214   // Attribute interfaces use the exact same mechanism as types, so just check
215   // that the basics work for attributes.
216   IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
217   ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
218   IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
219   auto iface = attr.dyn_cast<TestExternalAttrInterface>();
220   ASSERT_TRUE(iface != nullptr);
221   EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
222   EXPECT_EQ(iface.getSomeNumber(), 42);
223 }
224 
225 /// External model for an interface attachable to a non-builtin attribute.
226 struct TestExternalSimpleAAttrModel
227     : public TestExternalAttrInterface::ExternalModel<
228           TestExternalSimpleAAttrModel, test::SimpleAAttr> {
229   const Dialect *getDialectPtr(Attribute attr) const {
230     return &attr.getDialect();
231   }
232 
233   static int getSomeNumber() { return 21; }
234 };
235 
236 TEST(InterfaceAttachmentTest, AttributeDelayed) {
237   // Attribute interfaces use the exact same mechanism as types, so just check
238   // that the delayed registration work for attributes.
239   DialectRegistry registry;
240   registry.insert<test::TestDialect>();
241   registry.addAttrInterface<test::TestDialect, test::SimpleAAttr,
242                             TestExternalSimpleAAttrModel>();
243 
244   MLIRContext context(registry);
245   context.loadDialect<test::TestDialect>();
246   auto attr = test::SimpleAAttr::get(&context);
247   EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
248 
249   MLIRContext initiallyEmpty;
250   initiallyEmpty.loadDialect<test::TestDialect>();
251   attr = test::SimpleAAttr::get(&initiallyEmpty);
252   EXPECT_FALSE(attr.isa<TestExternalAttrInterface>());
253   initiallyEmpty.appendDialectRegistry(registry);
254   EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
255 }
256 
257 /// External interface model for the module operation. Only provides non-default
258 /// methods.
259 struct TestExternalOpModel
260     : public TestExternalOpInterface::ExternalModel<TestExternalOpModel,
261                                                     ModuleOp> {
262   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
263     return op->getName().getStringRef().size() + arg;
264   }
265 
266   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
267     return ModuleOp::getOperationName().size() + 2 * arg;
268   }
269 };
270 
271 /// External interface model for the func operation. Provides non-deafult and
272 /// overrides default methods.
273 struct TestExternalOpOverridingModel
274     : public TestExternalOpInterface::FallbackModel<
275           TestExternalOpOverridingModel> {
276   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
277     return op->getName().getStringRef().size() + arg;
278   }
279 
280   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
281     return FuncOp::getOperationName().size() + 2 * arg;
282   }
283 
284   unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const {
285     return 42;
286   }
287 
288   static unsigned getNameLengthMinusArg(unsigned arg) { return 21; }
289 };
290 
291 TEST(InterfaceAttachment, Operation) {
292   MLIRContext context;
293 
294   // Initially, the operation doesn't have the interface.
295   OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(UnknownLoc::get(&context));
296   ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation()));
297 
298   // We can attach an external interface and now the operaiton has it.
299   ModuleOp::attachInterface<TestExternalOpModel>(context);
300   auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation());
301   ASSERT_TRUE(iface != nullptr);
302   EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u);
303   EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u);
304   EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u);
305   EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u);
306 
307   // Default implementation can be overridden.
308   OwningOpRef<FuncOp> funcOp =
309       FuncOp::create(UnknownLoc::get(&context), "function",
310                      FunctionType::get(&context, {}, {}));
311   ASSERT_FALSE(isa<TestExternalOpInterface>(funcOp->getOperation()));
312   FuncOp::attachInterface<TestExternalOpOverridingModel>(context);
313   iface = dyn_cast<TestExternalOpInterface>(funcOp->getOperation());
314   ASSERT_TRUE(iface != nullptr);
315   EXPECT_EQ(iface.getNameLengthPlusArg(10), 22u);
316   EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u);
317   EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 28u);
318   EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u);
319 
320   // Another context doesn't have the interfaces registered.
321   MLIRContext other;
322   OwningOpRef<ModuleOp> otherModuleOp =
323       ModuleOp::create(UnknownLoc::get(&other));
324   ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation()));
325 }
326 
327 template <class ConcreteOp>
328 struct TestExternalTestOpModel
329     : public TestExternalOpInterface::ExternalModel<
330           TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
331   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
332     return op->getName().getStringRef().size() + arg;
333   }
334 
335   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
336     return ConcreteOp::getOperationName().size() + 2 * arg;
337   }
338 };
339 
340 TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
341   DialectRegistry registry;
342   registry.insert<test::TestDialect>();
343   registry.addOpInterface<ModuleOp, TestExternalOpModel>();
344   registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
345   registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();
346 
347   // Construct the context directly from a registry. The interfaces are expected
348   // to be readily available on operations.
349   MLIRContext context(registry);
350   context.loadDialect<test::TestDialect>();
351 
352   OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
353   OpBuilder builder(module->getBody(), module->getBody()->begin());
354   auto opJ =
355       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
356   auto opH =
357       builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
358   auto opI =
359       builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
360 
361   EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
362   EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
363   EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
364   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
365 }
366 
367 TEST(InterfaceAttachment, OperationDelayedContextAppend) {
368   DialectRegistry registry;
369   registry.insert<test::TestDialect>();
370   registry.addOpInterface<ModuleOp, TestExternalOpModel>();
371   registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
372   registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();
373 
374   // Construct the context, create ops, and only then append the registry. The
375   // interfaces are expected to be available after appending the registry.
376   MLIRContext context;
377   context.loadDialect<test::TestDialect>();
378 
379   OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
380   OpBuilder builder(module->getBody(), module->getBody()->begin());
381   auto opJ =
382       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
383   auto opH =
384       builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
385   auto opI =
386       builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
387 
388   EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation()));
389   EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation()));
390   EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation()));
391   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
392 
393   context.appendDialectRegistry(registry);
394 
395   EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
396   EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
397   EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
398   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
399 }
400 
401 } // end namespace
402