xref: /llvm-project/mlir/unittests/IR/InterfaceAttachmentTest.cpp (revision c24ce324d56328e4b91c8797ea4935545084303e)
1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This implements the tests for attaching interfaces to attributes and types
10 // without having to specify them on the attribute or type class directly.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/BuiltinDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "gtest/gtest.h"
19 
20 #include "../../test/lib/Dialect/Test/TestAttributes.h"
21 #include "../../test/lib/Dialect/Test/TestDialect.h"
22 #include "../../test/lib/Dialect/Test/TestOps.h"
23 #include "../../test/lib/Dialect/Test/TestTypes.h"
24 #include "mlir/IR/OwningOpRef.h"
25 
26 using namespace mlir;
27 using namespace test;
28 
29 namespace {
30 
31 /// External interface model for the integer type. Only provides non-default
32 /// methods.
33 struct Model
34     : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
35   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
36     return type.getIntOrFloatBitWidth() + arg;
37   }
38 
39   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
40 };
41 
42 /// External interface model for the float type. Provides non-deafult and
43 /// overrides default methods.
44 struct OverridingModel
45     : public TestExternalTypeInterface::ExternalModel<OverridingModel,
46                                                       Float32Type> {
47   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
48     return type.getIntOrFloatBitWidth() + arg;
49   }
50 
51   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
52 
53   unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
54     return 128;
55   }
56 
57   static unsigned staticGetArgument(unsigned arg) { return 420; }
58 };
59 
60 TEST(InterfaceAttachment, Type) {
61   MLIRContext context;
62 
63   // Check that the type has no interface.
64   IntegerType i8 = IntegerType::get(&context, 8);
65   ASSERT_FALSE(isa<TestExternalTypeInterface>(i8));
66 
67   // Attach an interface and check that the type now has the interface.
68   IntegerType::attachInterface<Model>(context);
69   TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8);
70   ASSERT_TRUE(iface != nullptr);
71   EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
72   EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
73   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
74   EXPECT_EQ(iface.staticGetArgument(17), 17u);
75 
76   // Same, but with the default implementation overridden.
77   FloatType flt = Float32Type::get(&context);
78   ASSERT_FALSE(isa<TestExternalTypeInterface>(flt));
79   Float32Type::attachInterface<OverridingModel>(context);
80   iface = dyn_cast<TestExternalTypeInterface>(flt);
81   ASSERT_TRUE(iface != nullptr);
82   EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
83   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
84   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
85   EXPECT_EQ(iface.staticGetArgument(17), 420u);
86 
87   // Other contexts shouldn't have the attribute attached.
88   MLIRContext other;
89   IntegerType i8other = IntegerType::get(&other, 8);
90   EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other));
91 }
92 
93 /// External interface model for the test type from the test dialect.
94 struct TestTypeModel
95     : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
96                                                       test::TestType> {
97   unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
98 
99   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
100 };
101 
102 TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
103   // Put the interface in the registry.
104   DialectRegistry registry;
105   registry.insert<test::TestDialect>();
106   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
107     test::TestType::attachInterface<TestTypeModel>(*ctx);
108   });
109 
110   // Check that when a context is constructed with the given registry, the type
111   // interface gets registered.
112   MLIRContext context(registry);
113   context.loadDialect<test::TestDialect>();
114   test::TestType testType = test::TestType::get(&context);
115   auto iface = dyn_cast<TestExternalTypeInterface>(testType);
116   ASSERT_TRUE(iface != nullptr);
117   EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
118   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
119 }
120 
121 TEST(InterfaceAttachment, TypeDelayedContextAppend) {
122   // Put the interface in the registry.
123   DialectRegistry registry;
124   registry.insert<test::TestDialect>();
125   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
126     test::TestType::attachInterface<TestTypeModel>(*ctx);
127   });
128 
129   // Check that when the registry gets appended to the context, the interface
130   // becomes available for objects in loaded dialects.
131   MLIRContext context;
132   context.loadDialect<test::TestDialect>();
133   test::TestType testType = test::TestType::get(&context);
134   EXPECT_FALSE(isa<TestExternalTypeInterface>(testType));
135   context.appendDialectRegistry(registry);
136   EXPECT_TRUE(isa<TestExternalTypeInterface>(testType));
137 }
138 
139 TEST(InterfaceAttachment, RepeatedRegistration) {
140   DialectRegistry registry;
141   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
142     IntegerType::attachInterface<Model>(*ctx);
143   });
144   MLIRContext context(registry);
145 
146   // Should't fail on repeated registration through the dialect registry.
147   context.appendDialectRegistry(registry);
148 }
149 
150 TEST(InterfaceAttachment, TypeBuiltinDelayed) {
151   // Builtin dialect needs to registration or loading, but delayed interface
152   // registration must still work.
153   DialectRegistry registry;
154   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
155     IntegerType::attachInterface<Model>(*ctx);
156   });
157 
158   MLIRContext context(registry);
159   IntegerType i16 = IntegerType::get(&context, 16);
160   EXPECT_TRUE(isa<TestExternalTypeInterface>(i16));
161 
162   MLIRContext initiallyEmpty;
163   IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
164   EXPECT_FALSE(isa<TestExternalTypeInterface>(i32));
165   initiallyEmpty.appendDialectRegistry(registry);
166   EXPECT_TRUE(isa<TestExternalTypeInterface>(i32));
167 }
168 
169 /// The interface provides a default implementation that expects
170 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
171 /// just derives from the ExternalModel.
172 struct TestExternalFallbackTypeIntegerModel
173     : public TestExternalFallbackTypeInterface::ExternalModel<
174           TestExternalFallbackTypeIntegerModel, IntegerType> {};
175 
176 /// The interface provides a default implementation that expects
177 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
178 /// FallbackModel instead to override this and make sure the code still compiles
179 /// because we never instantiate the ExternalModel class template with a
180 /// template argument that would have led to compilation failures.
181 struct TestExternalFallbackTypeVectorModel
182     : public TestExternalFallbackTypeInterface::FallbackModel<
183           TestExternalFallbackTypeVectorModel> {
184   unsigned getBitwidth(Type type) const {
185     IntegerType elementType =
186         dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType());
187     return elementType ? elementType.getWidth() : 0;
188   }
189 };
190 
191 TEST(InterfaceAttachment, Fallback) {
192   MLIRContext context;
193 
194   // Just check that we can attach the interface.
195   IntegerType i8 = IntegerType::get(&context, 8);
196   ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8));
197   IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
198   ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8));
199 
200   // Call the method so it is guaranteed not to be instantiated.
201   VectorType vec = VectorType::get({42}, i8);
202   ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec));
203   VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
204   ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec));
205   EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u);
206 }
207 
208 /// External model for attribute interfaces.
209 struct TestExternalIntegerAttrModel
210     : public TestExternalAttrInterface::ExternalModel<
211           TestExternalIntegerAttrModel, IntegerAttr> {
212   const Dialect *getDialectPtr(Attribute attr) const {
213     return &cast<IntegerAttr>(attr).getDialect();
214   }
215 
216   static int getSomeNumber() { return 42; }
217 };
218 
219 TEST(InterfaceAttachment, Attribute) {
220   MLIRContext context;
221 
222   // Attribute interfaces use the exact same mechanism as types, so just check
223   // that the basics work for attributes.
224   IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
225   ASSERT_FALSE(isa<TestExternalAttrInterface>(attr));
226   IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
227   auto iface = dyn_cast<TestExternalAttrInterface>(attr);
228   ASSERT_TRUE(iface != nullptr);
229   EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
230   EXPECT_EQ(iface.getSomeNumber(), 42);
231 }
232 
233 /// External model for an interface attachable to a non-builtin attribute.
234 struct TestExternalSimpleAAttrModel
235     : public TestExternalAttrInterface::ExternalModel<
236           TestExternalSimpleAAttrModel, test::SimpleAAttr> {
237   const Dialect *getDialectPtr(Attribute attr) const {
238     return &attr.getDialect();
239   }
240 
241   static int getSomeNumber() { return 21; }
242 };
243 
244 TEST(InterfaceAttachmentTest, AttributeDelayed) {
245   // Attribute interfaces use the exact same mechanism as types, so just check
246   // that the delayed registration work for attributes.
247   DialectRegistry registry;
248   registry.insert<test::TestDialect>();
249   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
250     test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx);
251   });
252 
253   MLIRContext context(registry);
254   context.loadDialect<test::TestDialect>();
255   auto attr = test::SimpleAAttr::get(&context);
256   EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
257 
258   MLIRContext initiallyEmpty;
259   initiallyEmpty.loadDialect<test::TestDialect>();
260   attr = test::SimpleAAttr::get(&initiallyEmpty);
261   EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
262   initiallyEmpty.appendDialectRegistry(registry);
263   EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
264 }
265 
266 /// External interface model for the module operation. Only provides non-default
267 /// methods.
268 struct TestExternalOpModel
269     : public TestExternalOpInterface::ExternalModel<TestExternalOpModel,
270                                                     ModuleOp> {
271   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
272     return op->getName().getStringRef().size() + arg;
273   }
274 
275   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
276     return ModuleOp::getOperationName().size() + 2 * arg;
277   }
278 };
279 
280 /// External interface model for the func operation. Provides non-deafult and
281 /// overrides default methods.
282 struct TestExternalOpOverridingModel
283     : public TestExternalOpInterface::FallbackModel<
284           TestExternalOpOverridingModel> {
285   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
286     return op->getName().getStringRef().size() + arg;
287   }
288 
289   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
290     return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg;
291   }
292 
293   unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const {
294     return 42;
295   }
296 
297   static unsigned getNameLengthMinusArg(unsigned arg) { return 21; }
298 };
299 
300 TEST(InterfaceAttachment, Operation) {
301   MLIRContext context;
302   OpBuilder builder(&context);
303 
304   // Initially, the operation doesn't have the interface.
305   OwningOpRef<ModuleOp> moduleOp =
306       builder.create<ModuleOp>(UnknownLoc::get(&context));
307   ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation()));
308 
309   // We can attach an external interface and now the operaiton has it.
310   ModuleOp::attachInterface<TestExternalOpModel>(context);
311   auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation());
312   ASSERT_TRUE(iface != nullptr);
313   EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u);
314   EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u);
315   EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u);
316   EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u);
317 
318   // Default implementation can be overridden.
319   OwningOpRef<UnrealizedConversionCastOp> castOp =
320       builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context),
321                                                  TypeRange(), ValueRange());
322   ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation()));
323   UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>(
324       context);
325   iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation());
326   ASSERT_TRUE(iface != nullptr);
327   EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u);
328   EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u);
329   EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u);
330   EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u);
331 
332   // Another context doesn't have the interfaces registered.
333   MLIRContext other;
334   OwningOpRef<ModuleOp> otherModuleOp =
335       ModuleOp::create(UnknownLoc::get(&other));
336   ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation()));
337 }
338 
339 template <class ConcreteOp>
340 struct TestExternalTestOpModel
341     : public TestExternalOpInterface::ExternalModel<
342           TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
343   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
344     return op->getName().getStringRef().size() + arg;
345   }
346 
347   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
348     return ConcreteOp::getOperationName().size() + 2 * arg;
349   }
350 };
351 
352 TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
353   DialectRegistry registry;
354   registry.insert<test::TestDialect>();
355   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
356     ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
357   });
358   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
359     test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
360     test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
361   });
362 
363   // Construct the context directly from a registry. The interfaces are
364   // expected to be readily available on operations.
365   MLIRContext context(registry);
366   context.loadDialect<test::TestDialect>();
367 
368   OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
369   OpBuilder builder(module->getBody(), module->getBody()->begin());
370   auto opJ =
371       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
372   auto opH =
373       builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
374   auto opI =
375       builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
376 
377   EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
378   EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
379   EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
380   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
381 }
382 
383 TEST(InterfaceAttachment, OperationDelayedContextAppend) {
384   DialectRegistry registry;
385   registry.insert<test::TestDialect>();
386   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
387     ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
388   });
389   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
390     test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
391     test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
392   });
393 
394   // Construct the context, create ops, and only then append the registry. The
395   // interfaces are expected to be available after appending the registry.
396   MLIRContext context;
397   context.loadDialect<test::TestDialect>();
398 
399   OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
400   OpBuilder builder(module->getBody(), module->getBody()->begin());
401   auto opJ =
402       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
403   auto opH =
404       builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
405   auto opI =
406       builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
407 
408   EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation()));
409   EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation()));
410   EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation()));
411   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
412 
413   context.appendDialectRegistry(registry);
414 
415   EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
416   EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
417   EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
418   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
419 }
420 
421 TEST(InterfaceAttachmentTest, PromisedInterfaces) {
422   // Attribute interfaces use the exact same mechanism as types, so just check
423   // that the promise mechanism works for attributes.
424   MLIRContext context;
425   auto *testDialect = context.getOrLoadDialect<test::TestDialect>();
426   auto attr = test::SimpleAAttr::get(&context);
427 
428   // `SimpleAAttr` doesn't implement nor promises the
429   // `TestExternalAttrInterface` interface.
430   EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
431   EXPECT_FALSE(
432       attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
433 
434   // Add a promise `TestExternalAttrInterface`.
435   testDialect->declarePromisedInterface<TestExternalAttrInterface,
436                                         test::SimpleAAttr>();
437   EXPECT_TRUE(
438       attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
439 
440   // Attach the interface.
441   test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context);
442   EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
443   EXPECT_TRUE(
444       attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
445 }
446 
447 } // namespace
448