1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This implements the tests for attaching interfaces to attributes and types 10 // without having to specify them on the attribute or type class directly. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/BuiltinAttributes.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "gtest/gtest.h" 17 18 #include "../../test/lib/Dialect/Test/TestAttributes.h" 19 #include "../../test/lib/Dialect/Test/TestTypes.h" 20 21 using namespace mlir; 22 using namespace mlir::test; 23 24 namespace { 25 26 /// External interface model for the integer type. Only provides non-default 27 /// methods. 28 struct Model 29 : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> { 30 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 31 return type.getIntOrFloatBitWidth() + arg; 32 } 33 34 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 35 }; 36 37 /// External interface model for the float type. Provides non-deafult and 38 /// overrides default methods. 39 struct OverridingModel 40 : public TestExternalTypeInterface::ExternalModel<OverridingModel, 41 FloatType> { 42 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 43 return type.getIntOrFloatBitWidth() + arg; 44 } 45 46 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 47 48 unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const { 49 return 128; 50 } 51 52 static unsigned staticGetArgument(unsigned arg) { return 420; } 53 }; 54 55 TEST(InterfaceAttachment, Type) { 56 MLIRContext context; 57 58 // Check that the type has no interface. 59 IntegerType i8 = IntegerType::get(&context, 8); 60 ASSERT_FALSE(i8.isa<TestExternalTypeInterface>()); 61 62 // Attach an interface and check that the type now has the interface. 63 IntegerType::attachInterface<Model>(context); 64 TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>(); 65 ASSERT_TRUE(iface != nullptr); 66 EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u); 67 EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u); 68 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u); 69 EXPECT_EQ(iface.staticGetArgument(17), 17u); 70 71 // Same, but with the default implementation overridden. 72 FloatType flt = Float32Type::get(&context); 73 ASSERT_FALSE(flt.isa<TestExternalTypeInterface>()); 74 Float32Type::attachInterface<OverridingModel>(context); 75 iface = flt.dyn_cast<TestExternalTypeInterface>(); 76 ASSERT_TRUE(iface != nullptr); 77 EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u); 78 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u); 79 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u); 80 EXPECT_EQ(iface.staticGetArgument(17), 420u); 81 82 // Other contexts shouldn't have the attribute attached. 83 MLIRContext other; 84 IntegerType i8other = IntegerType::get(&other, 8); 85 EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>()); 86 } 87 88 /// The interface provides a default implementation that expects 89 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this 90 /// just derives from the ExternalModel. 91 struct TestExternalFallbackTypeIntegerModel 92 : public TestExternalFallbackTypeInterface::ExternalModel< 93 TestExternalFallbackTypeIntegerModel, IntegerType> {}; 94 95 /// The interface provides a default implementation that expects 96 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use 97 /// FallbackModel instead to override this and make sure the code still compiles 98 /// because we never instantiate the ExternalModel class template with a 99 /// template argument that would have led to compilation failures. 100 struct TestExternalFallbackTypeVectorModel 101 : public TestExternalFallbackTypeInterface::FallbackModel< 102 TestExternalFallbackTypeVectorModel> { 103 unsigned getBitwidth(Type type) const { 104 IntegerType elementType = type.cast<VectorType>() 105 .getElementType() 106 .dyn_cast_or_null<IntegerType>(); 107 return elementType ? elementType.getWidth() : 0; 108 } 109 }; 110 111 TEST(InterfaceAttachment, Fallback) { 112 MLIRContext context; 113 114 // Just check that we can attach the interface. 115 IntegerType i8 = IntegerType::get(&context, 8); 116 ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>()); 117 IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context); 118 ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>()); 119 120 // Call the method so it is guaranteed not to be instantiated. 121 VectorType vec = VectorType::get({42}, i8); 122 ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>()); 123 VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context); 124 ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>()); 125 EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u); 126 } 127 128 /// External model for attribute interfaces. 129 struct TextExternalIntegerAttrModel 130 : public TestExternalAttrInterface::ExternalModel< 131 TextExternalIntegerAttrModel, IntegerAttr> { 132 const Dialect *getDialectPtr(Attribute attr) const { 133 return &attr.cast<IntegerAttr>().getDialect(); 134 } 135 136 static int getSomeNumber() { return 42; } 137 }; 138 139 TEST(InterfaceAttachment, Attribute) { 140 MLIRContext context; 141 142 // Attribute interfaces use the exact same mechanism as types, so just check 143 // that the basics work for attributes. 144 IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42); 145 ASSERT_FALSE(attr.isa<TestExternalAttrInterface>()); 146 IntegerAttr::attachInterface<TextExternalIntegerAttrModel>(context); 147 auto iface = attr.dyn_cast<TestExternalAttrInterface>(); 148 ASSERT_TRUE(iface != nullptr); 149 EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect()); 150 EXPECT_EQ(iface.getSomeNumber(), 42); 151 } 152 153 } // end namespace 154