xref: /llvm-project/mlir/unittests/IR/InterfaceAttachmentTest.cpp (revision d0e6fd99aa95ff61372ea328e9f89da2ee39c49c)
1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This implements the tests for attaching interfaces to attributes and types
10 // without having to specify them on the attribute or type class directly.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/BuiltinDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "gtest/gtest.h"
19 
20 #include "../../test/lib/Dialect/Test/TestAttributes.h"
21 #include "../../test/lib/Dialect/Test/TestDialect.h"
22 #include "../../test/lib/Dialect/Test/TestTypes.h"
23 #include "mlir/IR/OwningOpRef.h"
24 
25 using namespace mlir;
26 using namespace test;
27 
28 namespace {
29 
30 /// External interface model for the integer type. Only provides non-default
31 /// methods.
32 struct Model
33     : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
34   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
35     return type.getIntOrFloatBitWidth() + arg;
36   }
37 
38   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
39 };
40 
41 /// External interface model for the float type. Provides non-deafult and
42 /// overrides default methods.
43 struct OverridingModel
44     : public TestExternalTypeInterface::ExternalModel<OverridingModel,
45                                                       FloatType> {
46   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
47     return type.getIntOrFloatBitWidth() + arg;
48   }
49 
50   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
51 
52   unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
53     return 128;
54   }
55 
56   static unsigned staticGetArgument(unsigned arg) { return 420; }
57 };
58 
59 TEST(InterfaceAttachment, Type) {
60   MLIRContext context;
61 
62   // Check that the type has no interface.
63   IntegerType i8 = IntegerType::get(&context, 8);
64   ASSERT_FALSE(isa<TestExternalTypeInterface>(i8));
65 
66   // Attach an interface and check that the type now has the interface.
67   IntegerType::attachInterface<Model>(context);
68   TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8);
69   ASSERT_TRUE(iface != nullptr);
70   EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
71   EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
72   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
73   EXPECT_EQ(iface.staticGetArgument(17), 17u);
74 
75   // Same, but with the default implementation overridden.
76   FloatType flt = Float32Type::get(&context);
77   ASSERT_FALSE(isa<TestExternalTypeInterface>(flt));
78   Float32Type::attachInterface<OverridingModel>(context);
79   iface = dyn_cast<TestExternalTypeInterface>(flt);
80   ASSERT_TRUE(iface != nullptr);
81   EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
82   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
83   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
84   EXPECT_EQ(iface.staticGetArgument(17), 420u);
85 
86   // Other contexts shouldn't have the attribute attached.
87   MLIRContext other;
88   IntegerType i8other = IntegerType::get(&other, 8);
89   EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other));
90 }
91 
92 /// External interface model for the test type from the test dialect.
93 struct TestTypeModel
94     : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
95                                                       test::TestType> {
96   unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
97 
98   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
99 };
100 
101 TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
102   // Put the interface in the registry.
103   DialectRegistry registry;
104   registry.insert<test::TestDialect>();
105   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
106     test::TestType::attachInterface<TestTypeModel>(*ctx);
107   });
108 
109   // Check that when a context is constructed with the given registry, the type
110   // interface gets registered.
111   MLIRContext context(registry);
112   context.loadDialect<test::TestDialect>();
113   test::TestType testType = test::TestType::get(&context);
114   auto iface = dyn_cast<TestExternalTypeInterface>(testType);
115   ASSERT_TRUE(iface != nullptr);
116   EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
117   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
118 }
119 
120 TEST(InterfaceAttachment, TypeDelayedContextAppend) {
121   // Put the interface in the registry.
122   DialectRegistry registry;
123   registry.insert<test::TestDialect>();
124   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
125     test::TestType::attachInterface<TestTypeModel>(*ctx);
126   });
127 
128   // Check that when the registry gets appended to the context, the interface
129   // becomes available for objects in loaded dialects.
130   MLIRContext context;
131   context.loadDialect<test::TestDialect>();
132   test::TestType testType = test::TestType::get(&context);
133   EXPECT_FALSE(isa<TestExternalTypeInterface>(testType));
134   context.appendDialectRegistry(registry);
135   EXPECT_TRUE(isa<TestExternalTypeInterface>(testType));
136 }
137 
138 TEST(InterfaceAttachment, RepeatedRegistration) {
139   DialectRegistry registry;
140   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
141     IntegerType::attachInterface<Model>(*ctx);
142   });
143   MLIRContext context(registry);
144 
145   // Should't fail on repeated registration through the dialect registry.
146   context.appendDialectRegistry(registry);
147 }
148 
149 TEST(InterfaceAttachment, TypeBuiltinDelayed) {
150   // Builtin dialect needs to registration or loading, but delayed interface
151   // registration must still work.
152   DialectRegistry registry;
153   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
154     IntegerType::attachInterface<Model>(*ctx);
155   });
156 
157   MLIRContext context(registry);
158   IntegerType i16 = IntegerType::get(&context, 16);
159   EXPECT_TRUE(isa<TestExternalTypeInterface>(i16));
160 
161   MLIRContext initiallyEmpty;
162   IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
163   EXPECT_FALSE(isa<TestExternalTypeInterface>(i32));
164   initiallyEmpty.appendDialectRegistry(registry);
165   EXPECT_TRUE(isa<TestExternalTypeInterface>(i32));
166 }
167 
168 /// The interface provides a default implementation that expects
169 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
170 /// just derives from the ExternalModel.
171 struct TestExternalFallbackTypeIntegerModel
172     : public TestExternalFallbackTypeInterface::ExternalModel<
173           TestExternalFallbackTypeIntegerModel, IntegerType> {};
174 
175 /// The interface provides a default implementation that expects
176 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
177 /// FallbackModel instead to override this and make sure the code still compiles
178 /// because we never instantiate the ExternalModel class template with a
179 /// template argument that would have led to compilation failures.
180 struct TestExternalFallbackTypeVectorModel
181     : public TestExternalFallbackTypeInterface::FallbackModel<
182           TestExternalFallbackTypeVectorModel> {
183   unsigned getBitwidth(Type type) const {
184     IntegerType elementType =
185         dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType());
186     return elementType ? elementType.getWidth() : 0;
187   }
188 };
189 
190 TEST(InterfaceAttachment, Fallback) {
191   MLIRContext context;
192 
193   // Just check that we can attach the interface.
194   IntegerType i8 = IntegerType::get(&context, 8);
195   ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8));
196   IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
197   ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8));
198 
199   // Call the method so it is guaranteed not to be instantiated.
200   VectorType vec = VectorType::get({42}, i8);
201   ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec));
202   VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
203   ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec));
204   EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u);
205 }
206 
207 /// External model for attribute interfaces.
208 struct TestExternalIntegerAttrModel
209     : public TestExternalAttrInterface::ExternalModel<
210           TestExternalIntegerAttrModel, IntegerAttr> {
211   const Dialect *getDialectPtr(Attribute attr) const {
212     return &cast<IntegerAttr>(attr).getDialect();
213   }
214 
215   static int getSomeNumber() { return 42; }
216 };
217 
218 TEST(InterfaceAttachment, Attribute) {
219   MLIRContext context;
220 
221   // Attribute interfaces use the exact same mechanism as types, so just check
222   // that the basics work for attributes.
223   IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
224   ASSERT_FALSE(isa<TestExternalAttrInterface>(attr));
225   IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
226   auto iface = dyn_cast<TestExternalAttrInterface>(attr);
227   ASSERT_TRUE(iface != nullptr);
228   EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
229   EXPECT_EQ(iface.getSomeNumber(), 42);
230 }
231 
232 /// External model for an interface attachable to a non-builtin attribute.
233 struct TestExternalSimpleAAttrModel
234     : public TestExternalAttrInterface::ExternalModel<
235           TestExternalSimpleAAttrModel, test::SimpleAAttr> {
236   const Dialect *getDialectPtr(Attribute attr) const {
237     return &attr.getDialect();
238   }
239 
240   static int getSomeNumber() { return 21; }
241 };
242 
243 TEST(InterfaceAttachmentTest, AttributeDelayed) {
244   // Attribute interfaces use the exact same mechanism as types, so just check
245   // that the delayed registration work for attributes.
246   DialectRegistry registry;
247   registry.insert<test::TestDialect>();
248   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
249     test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx);
250   });
251 
252   MLIRContext context(registry);
253   context.loadDialect<test::TestDialect>();
254   auto attr = test::SimpleAAttr::get(&context);
255   EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
256 
257   MLIRContext initiallyEmpty;
258   initiallyEmpty.loadDialect<test::TestDialect>();
259   attr = test::SimpleAAttr::get(&initiallyEmpty);
260   EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
261   initiallyEmpty.appendDialectRegistry(registry);
262   EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
263 }
264 
265 /// External interface model for the module operation. Only provides non-default
266 /// methods.
267 struct TestExternalOpModel
268     : public TestExternalOpInterface::ExternalModel<TestExternalOpModel,
269                                                     ModuleOp> {
270   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
271     return op->getName().getStringRef().size() + arg;
272   }
273 
274   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
275     return ModuleOp::getOperationName().size() + 2 * arg;
276   }
277 };
278 
279 /// External interface model for the func operation. Provides non-deafult and
280 /// overrides default methods.
281 struct TestExternalOpOverridingModel
282     : public TestExternalOpInterface::FallbackModel<
283           TestExternalOpOverridingModel> {
284   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
285     return op->getName().getStringRef().size() + arg;
286   }
287 
288   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
289     return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg;
290   }
291 
292   unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const {
293     return 42;
294   }
295 
296   static unsigned getNameLengthMinusArg(unsigned arg) { return 21; }
297 };
298 
299 TEST(InterfaceAttachment, Operation) {
300   MLIRContext context;
301   OpBuilder builder(&context);
302 
303   // Initially, the operation doesn't have the interface.
304   OwningOpRef<ModuleOp> moduleOp =
305       builder.create<ModuleOp>(UnknownLoc::get(&context));
306   ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation()));
307 
308   // We can attach an external interface and now the operaiton has it.
309   ModuleOp::attachInterface<TestExternalOpModel>(context);
310   auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation());
311   ASSERT_TRUE(iface != nullptr);
312   EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u);
313   EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u);
314   EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u);
315   EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u);
316 
317   // Default implementation can be overridden.
318   OwningOpRef<UnrealizedConversionCastOp> castOp =
319       builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context),
320                                                  TypeRange(), ValueRange());
321   ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation()));
322   UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>(
323       context);
324   iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation());
325   ASSERT_TRUE(iface != nullptr);
326   EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u);
327   EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u);
328   EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u);
329   EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u);
330 
331   // Another context doesn't have the interfaces registered.
332   MLIRContext other;
333   OwningOpRef<ModuleOp> otherModuleOp =
334       ModuleOp::create(UnknownLoc::get(&other));
335   ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation()));
336 }
337 
338 template <class ConcreteOp>
339 struct TestExternalTestOpModel
340     : public TestExternalOpInterface::ExternalModel<
341           TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
342   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
343     return op->getName().getStringRef().size() + arg;
344   }
345 
346   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
347     return ConcreteOp::getOperationName().size() + 2 * arg;
348   }
349 };
350 
351 TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
352   DialectRegistry registry;
353   registry.insert<test::TestDialect>();
354   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
355     ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
356   });
357   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
358     test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
359     test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
360   });
361 
362   // Construct the context directly from a registry. The interfaces are
363   // expected to be readily available on operations.
364   MLIRContext context(registry);
365   context.loadDialect<test::TestDialect>();
366 
367   OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
368   OpBuilder builder(module->getBody(), module->getBody()->begin());
369   auto opJ =
370       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
371   auto opH =
372       builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
373   auto opI =
374       builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
375 
376   EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
377   EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
378   EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
379   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
380 }
381 
382 TEST(InterfaceAttachment, OperationDelayedContextAppend) {
383   DialectRegistry registry;
384   registry.insert<test::TestDialect>();
385   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
386     ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
387   });
388   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
389     test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
390     test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
391   });
392 
393   // Construct the context, create ops, and only then append the registry. The
394   // interfaces are expected to be available after appending the registry.
395   MLIRContext context;
396   context.loadDialect<test::TestDialect>();
397 
398   OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
399   OpBuilder builder(module->getBody(), module->getBody()->begin());
400   auto opJ =
401       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
402   auto opH =
403       builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
404   auto opI =
405       builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
406 
407   EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation()));
408   EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation()));
409   EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation()));
410   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
411 
412   context.appendDialectRegistry(registry);
413 
414   EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
415   EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
416   EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
417   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
418 }
419 
420 TEST(InterfaceAttachmentTest, PromisedInterfaces) {
421   // Attribute interfaces use the exact same mechanism as types, so just check
422   // that the promise mechanism works for attributes.
423   MLIRContext context;
424   auto testDialect = context.getOrLoadDialect<test::TestDialect>();
425   auto attr = test::SimpleAAttr::get(&context);
426 
427   // `SimpleAAttr` doesn't implement nor promises the
428   // `TestExternalAttrInterface` interface.
429   EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
430   EXPECT_FALSE(
431       attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
432 
433   // Add a promise `TestExternalAttrInterface`.
434   testDialect->declarePromisedInterface<test::SimpleAAttr,
435                                         TestExternalAttrInterface>();
436   EXPECT_TRUE(
437       attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
438 
439   // Attach the interface.
440   test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context);
441   EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
442   EXPECT_TRUE(
443       attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
444 }
445 
446 } // namespace
447