xref: /llvm-project/mlir/unittests/IR/InterfaceAttachmentTest.cpp (revision 9b2a1bcf6fbe79cfa3725cfd654f45cce0375d3b)
1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This implements the tests for attaching interfaces to attributes and types
10 // without having to specify them on the attribute or type class directly.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "gtest/gtest.h"
17 
18 #include "../../test/lib/Dialect/Test/TestAttributes.h"
19 #include "../../test/lib/Dialect/Test/TestTypes.h"
20 
21 using namespace mlir;
22 using namespace mlir::test;
23 
24 namespace {
25 
26 /// External interface model for the integer type. Only provides non-default
27 /// methods.
28 struct Model
29     : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
30   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
31     return type.getIntOrFloatBitWidth() + arg;
32   }
33 
34   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
35 };
36 
37 /// External interface model for the float type. Provides non-deafult and
38 /// overrides default methods.
39 struct OverridingModel
40     : public TestExternalTypeInterface::ExternalModel<OverridingModel,
41                                                       FloatType> {
42   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
43     return type.getIntOrFloatBitWidth() + arg;
44   }
45 
46   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
47 
48   unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
49     return 128;
50   }
51 
52   static unsigned staticGetArgument(unsigned arg) { return 420; }
53 };
54 
55 TEST(InterfaceAttachment, Type) {
56   MLIRContext context;
57 
58   // Check that the type has no interface.
59   IntegerType i8 = IntegerType::get(&context, 8);
60   ASSERT_FALSE(i8.isa<TestExternalTypeInterface>());
61 
62   // Attach an interface and check that the type now has the interface.
63   IntegerType::attachInterface<Model>(context);
64   TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>();
65   ASSERT_TRUE(iface != nullptr);
66   EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
67   EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
68   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
69   EXPECT_EQ(iface.staticGetArgument(17), 17u);
70 
71   // Same, but with the default implementation overridden.
72   FloatType flt = Float32Type::get(&context);
73   ASSERT_FALSE(flt.isa<TestExternalTypeInterface>());
74   Float32Type::attachInterface<OverridingModel>(context);
75   iface = flt.dyn_cast<TestExternalTypeInterface>();
76   ASSERT_TRUE(iface != nullptr);
77   EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
78   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
79   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
80   EXPECT_EQ(iface.staticGetArgument(17), 420u);
81 
82   // Other contexts shouldn't have the attribute attached.
83   MLIRContext other;
84   IntegerType i8other = IntegerType::get(&other, 8);
85   EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
86 }
87 
88 /// The interface provides a default implementation that expects
89 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
90 /// just derives from the ExternalModel.
91 struct TestExternalFallbackTypeIntegerModel
92     : public TestExternalFallbackTypeInterface::ExternalModel<
93           TestExternalFallbackTypeIntegerModel, IntegerType> {};
94 
95 /// The interface provides a default implementation that expects
96 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
97 /// FallbackModel instead to override this and make sure the code still compiles
98 /// because we never instantiate the ExternalModel class template with a
99 /// template argument that would have led to compilation failures.
100 struct TestExternalFallbackTypeVectorModel
101     : public TestExternalFallbackTypeInterface::FallbackModel<
102           TestExternalFallbackTypeVectorModel> {
103   unsigned getBitwidth(Type type) const {
104     IntegerType elementType = type.cast<VectorType>()
105                                   .getElementType()
106                                   .dyn_cast_or_null<IntegerType>();
107     return elementType ? elementType.getWidth() : 0;
108   }
109 };
110 
111 TEST(InterfaceAttachment, Fallback) {
112   MLIRContext context;
113 
114   // Just check that we can attach the interface.
115   IntegerType i8 = IntegerType::get(&context, 8);
116   ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>());
117   IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
118   ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>());
119 
120   // Call the method so it is guaranteed not to be instantiated.
121   VectorType vec = VectorType::get({42}, i8);
122   ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>());
123   VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
124   ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>());
125   EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u);
126 }
127 
128 /// External model for attribute interfaces.
129 struct TextExternalIntegerAttrModel
130     : public TestExternalAttrInterface::ExternalModel<
131           TextExternalIntegerAttrModel, IntegerAttr> {
132   const Dialect *getDialectPtr(Attribute attr) const {
133     return &attr.cast<IntegerAttr>().getDialect();
134   }
135 
136   static int getSomeNumber() { return 42; }
137 };
138 
139 TEST(InterfaceAttachment, Attribute) {
140   MLIRContext context;
141 
142   // Attribute interfaces use the exact same mechanism as types, so just check
143   // that the basics work for attributes.
144   IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
145   ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
146   IntegerAttr::attachInterface<TextExternalIntegerAttrModel>(context);
147   auto iface = attr.dyn_cast<TestExternalAttrInterface>();
148   ASSERT_TRUE(iface != nullptr);
149   EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
150   EXPECT_EQ(iface.getSomeNumber(), 42);
151 }
152 
153 } // end namespace
154