1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This implements the tests for attaching interfaces to attributes and types 10 // without having to specify them on the attribute or type class directly. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/BuiltinAttributes.h" 15 #include "mlir/IR/BuiltinDialect.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "gtest/gtest.h" 19 20 #include "../../test/lib/Dialect/Test/TestAttributes.h" 21 #include "../../test/lib/Dialect/Test/TestDialect.h" 22 #include "../../test/lib/Dialect/Test/TestTypes.h" 23 24 using namespace mlir; 25 using namespace mlir::test; 26 27 namespace { 28 29 /// External interface model for the integer type. Only provides non-default 30 /// methods. 31 struct Model 32 : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> { 33 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 34 return type.getIntOrFloatBitWidth() + arg; 35 } 36 37 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 38 }; 39 40 /// External interface model for the float type. Provides non-deafult and 41 /// overrides default methods. 42 struct OverridingModel 43 : public TestExternalTypeInterface::ExternalModel<OverridingModel, 44 FloatType> { 45 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 46 return type.getIntOrFloatBitWidth() + arg; 47 } 48 49 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 50 51 unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const { 52 return 128; 53 } 54 55 static unsigned staticGetArgument(unsigned arg) { return 420; } 56 }; 57 58 TEST(InterfaceAttachment, Type) { 59 MLIRContext context; 60 61 // Check that the type has no interface. 62 IntegerType i8 = IntegerType::get(&context, 8); 63 ASSERT_FALSE(i8.isa<TestExternalTypeInterface>()); 64 65 // Attach an interface and check that the type now has the interface. 66 IntegerType::attachInterface<Model>(context); 67 TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>(); 68 ASSERT_TRUE(iface != nullptr); 69 EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u); 70 EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u); 71 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u); 72 EXPECT_EQ(iface.staticGetArgument(17), 17u); 73 74 // Same, but with the default implementation overridden. 75 FloatType flt = Float32Type::get(&context); 76 ASSERT_FALSE(flt.isa<TestExternalTypeInterface>()); 77 Float32Type::attachInterface<OverridingModel>(context); 78 iface = flt.dyn_cast<TestExternalTypeInterface>(); 79 ASSERT_TRUE(iface != nullptr); 80 EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u); 81 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u); 82 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u); 83 EXPECT_EQ(iface.staticGetArgument(17), 420u); 84 85 // Other contexts shouldn't have the attribute attached. 86 MLIRContext other; 87 IntegerType i8other = IntegerType::get(&other, 8); 88 EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>()); 89 } 90 91 /// External interface model for the test type from the test dialect. 92 struct TestTypeModel 93 : public TestExternalTypeInterface::ExternalModel<TestTypeModel, 94 test::TestType> { 95 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; } 96 97 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; } 98 }; 99 100 TEST(InterfaceAttachment, TypeDelayedContextConstruct) { 101 // Put the interface in the registry. 102 DialectRegistry registry; 103 registry.insert<test::TestDialect>(); 104 registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>(); 105 106 // Check that when a context is constructed with the given registry, the type 107 // interface gets registered. 108 MLIRContext context(registry); 109 context.loadDialect<test::TestDialect>(); 110 test::TestType testType = test::TestType::get(&context); 111 auto iface = testType.dyn_cast<TestExternalTypeInterface>(); 112 ASSERT_TRUE(iface != nullptr); 113 EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u); 114 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u); 115 } 116 117 TEST(InterfaceAttachment, TypeDelayedContextAppend) { 118 // Put the interface in the registry. 119 DialectRegistry registry; 120 registry.insert<test::TestDialect>(); 121 registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>(); 122 123 // Check that when the registry gets appended to the context, the interface 124 // becomes available for objects in loaded dialects. 125 MLIRContext context; 126 context.loadDialect<test::TestDialect>(); 127 test::TestType testType = test::TestType::get(&context); 128 EXPECT_FALSE(testType.isa<TestExternalTypeInterface>()); 129 context.appendDialectRegistry(registry); 130 EXPECT_TRUE(testType.isa<TestExternalTypeInterface>()); 131 } 132 133 TEST(InterfaceAttachment, RepeatedRegistration) { 134 DialectRegistry registry; 135 registry.addTypeInterface<BuiltinDialect, IntegerType, Model>(); 136 MLIRContext context(registry); 137 138 // Should't fail on repeated registration through the dialect registry. 139 context.appendDialectRegistry(registry); 140 } 141 142 TEST(InterfaceAttachment, TypeBuiltinDelayed) { 143 // Builtin dialect needs to registration or loading, but delayed interface 144 // registration must still work. 145 DialectRegistry registry; 146 registry.addTypeInterface<BuiltinDialect, IntegerType, Model>(); 147 148 MLIRContext context(registry); 149 IntegerType i16 = IntegerType::get(&context, 16); 150 EXPECT_TRUE(i16.isa<TestExternalTypeInterface>()); 151 152 MLIRContext initiallyEmpty; 153 IntegerType i32 = IntegerType::get(&initiallyEmpty, 32); 154 EXPECT_FALSE(i32.isa<TestExternalTypeInterface>()); 155 initiallyEmpty.appendDialectRegistry(registry); 156 EXPECT_TRUE(i32.isa<TestExternalTypeInterface>()); 157 } 158 159 /// The interface provides a default implementation that expects 160 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this 161 /// just derives from the ExternalModel. 162 struct TestExternalFallbackTypeIntegerModel 163 : public TestExternalFallbackTypeInterface::ExternalModel< 164 TestExternalFallbackTypeIntegerModel, IntegerType> {}; 165 166 /// The interface provides a default implementation that expects 167 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use 168 /// FallbackModel instead to override this and make sure the code still compiles 169 /// because we never instantiate the ExternalModel class template with a 170 /// template argument that would have led to compilation failures. 171 struct TestExternalFallbackTypeVectorModel 172 : public TestExternalFallbackTypeInterface::FallbackModel< 173 TestExternalFallbackTypeVectorModel> { 174 unsigned getBitwidth(Type type) const { 175 IntegerType elementType = type.cast<VectorType>() 176 .getElementType() 177 .dyn_cast_or_null<IntegerType>(); 178 return elementType ? elementType.getWidth() : 0; 179 } 180 }; 181 182 TEST(InterfaceAttachment, Fallback) { 183 MLIRContext context; 184 185 // Just check that we can attach the interface. 186 IntegerType i8 = IntegerType::get(&context, 8); 187 ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>()); 188 IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context); 189 ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>()); 190 191 // Call the method so it is guaranteed not to be instantiated. 192 VectorType vec = VectorType::get({42}, i8); 193 ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>()); 194 VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context); 195 ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>()); 196 EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u); 197 } 198 199 /// External model for attribute interfaces. 200 struct TestExternalIntegerAttrModel 201 : public TestExternalAttrInterface::ExternalModel< 202 TestExternalIntegerAttrModel, IntegerAttr> { 203 const Dialect *getDialectPtr(Attribute attr) const { 204 return &attr.cast<IntegerAttr>().getDialect(); 205 } 206 207 static int getSomeNumber() { return 42; } 208 }; 209 210 TEST(InterfaceAttachment, Attribute) { 211 MLIRContext context; 212 213 // Attribute interfaces use the exact same mechanism as types, so just check 214 // that the basics work for attributes. 215 IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42); 216 ASSERT_FALSE(attr.isa<TestExternalAttrInterface>()); 217 IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context); 218 auto iface = attr.dyn_cast<TestExternalAttrInterface>(); 219 ASSERT_TRUE(iface != nullptr); 220 EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect()); 221 EXPECT_EQ(iface.getSomeNumber(), 42); 222 } 223 224 /// External model for an interface attachable to a non-builtin attribute. 225 struct TestExternalSimpleAAttrModel 226 : public TestExternalAttrInterface::ExternalModel< 227 TestExternalSimpleAAttrModel, test::SimpleAAttr> { 228 const Dialect *getDialectPtr(Attribute attr) const { 229 return &attr.getDialect(); 230 } 231 232 static int getSomeNumber() { return 21; } 233 }; 234 235 TEST(InterfaceAttachmentTest, AttributeDelayed) { 236 // Attribute interfaces use the exact same mechanism as types, so just check 237 // that the delayed registration work for attributes. 238 DialectRegistry registry; 239 registry.insert<test::TestDialect>(); 240 registry.addAttrInterface<test::TestDialect, test::SimpleAAttr, 241 TestExternalSimpleAAttrModel>(); 242 243 MLIRContext context(registry); 244 context.loadDialect<test::TestDialect>(); 245 auto attr = test::SimpleAAttr::get(&context); 246 EXPECT_TRUE(attr.isa<TestExternalAttrInterface>()); 247 248 MLIRContext initiallyEmpty; 249 initiallyEmpty.loadDialect<test::TestDialect>(); 250 attr = test::SimpleAAttr::get(&initiallyEmpty); 251 EXPECT_FALSE(attr.isa<TestExternalAttrInterface>()); 252 initiallyEmpty.appendDialectRegistry(registry); 253 EXPECT_TRUE(attr.isa<TestExternalAttrInterface>()); 254 } 255 256 /// External interface model for the module operation. Only provides non-default 257 /// methods. 258 struct TestExternalOpModel 259 : public TestExternalOpInterface::ExternalModel<TestExternalOpModel, 260 ModuleOp> { 261 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 262 return op->getName().getStringRef().size() + arg; 263 } 264 265 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 266 return ModuleOp::getOperationName().size() + 2 * arg; 267 } 268 }; 269 270 /// External interface model for the func operation. Provides non-deafult and 271 /// overrides default methods. 272 struct TestExternalOpOverridingModel 273 : public TestExternalOpInterface::FallbackModel< 274 TestExternalOpOverridingModel> { 275 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 276 return op->getName().getStringRef().size() + arg; 277 } 278 279 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 280 return FuncOp::getOperationName().size() + 2 * arg; 281 } 282 283 unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const { 284 return 42; 285 } 286 287 static unsigned getNameLengthMinusArg(unsigned arg) { return 21; } 288 }; 289 290 TEST(InterfaceAttachment, Operation) { 291 MLIRContext context; 292 293 // Initially, the operation doesn't have the interface. 294 auto moduleOp = ModuleOp::create(UnknownLoc::get(&context)); 295 ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp.getOperation())); 296 297 // We can attach an external interface and now the operaiton has it. 298 ModuleOp::attachInterface<TestExternalOpModel>(context); 299 auto iface = dyn_cast<TestExternalOpInterface>(moduleOp.getOperation()); 300 ASSERT_TRUE(iface != nullptr); 301 EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u); 302 EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u); 303 EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u); 304 EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u); 305 306 // Default implementation can be overridden. 307 auto funcOp = FuncOp::create(UnknownLoc::get(&context), "function", 308 FunctionType::get(&context, {}, {})); 309 ASSERT_FALSE(isa<TestExternalOpInterface>(funcOp.getOperation())); 310 FuncOp::attachInterface<TestExternalOpOverridingModel>(context); 311 iface = dyn_cast<TestExternalOpInterface>(funcOp.getOperation()); 312 ASSERT_TRUE(iface != nullptr); 313 EXPECT_EQ(iface.getNameLengthPlusArg(10), 22u); 314 EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u); 315 EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 28u); 316 EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u); 317 318 // Another context doesn't have the interfaces registered. 319 MLIRContext other; 320 auto otherModuleOp = ModuleOp::create(UnknownLoc::get(&other)); 321 ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp.getOperation())); 322 } 323 324 struct TestExternalTestOpModel 325 : public TestExternalOpInterface::ExternalModel<TestExternalTestOpModel, 326 test::OpJ> { 327 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 328 return op->getName().getStringRef().size() + arg; 329 } 330 331 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 332 return test::OpJ::getOperationName().size() + 2 * arg; 333 } 334 }; 335 336 TEST(InterfaceAttachment, OperationDelayedContextConstruct) { 337 DialectRegistry registry; 338 registry.insert<test::TestDialect>(); 339 registry.addOpInterface<ModuleOp, TestExternalOpModel>(); 340 registry.addOpInterface<test::OpJ, TestExternalTestOpModel>(); 341 342 // Construct the context directly from a registry. The interfaces are expected 343 // to be readily available on operations. 344 MLIRContext context(registry); 345 context.loadDialect<test::TestDialect>(); 346 ModuleOp module = ModuleOp::create(UnknownLoc::get(&context)); 347 OpBuilder builder(module); 348 auto op = 349 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); 350 EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation())); 351 EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation())); 352 } 353 354 TEST(InterfaceAttachment, OperationDelayedContextAppend) { 355 DialectRegistry registry; 356 registry.insert<test::TestDialect>(); 357 registry.addOpInterface<ModuleOp, TestExternalOpModel>(); 358 registry.addOpInterface<test::OpJ, TestExternalTestOpModel>(); 359 360 // Construct the context, create ops, and only then append the registry. The 361 // interfaces are expected to be available after appending the registry. 362 MLIRContext context; 363 context.loadDialect<test::TestDialect>(); 364 ModuleOp module = ModuleOp::create(UnknownLoc::get(&context)); 365 OpBuilder builder(module); 366 auto op = 367 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); 368 EXPECT_FALSE(isa<TestExternalOpInterface>(module.getOperation())); 369 EXPECT_FALSE(isa<TestExternalOpInterface>(op.getOperation())); 370 context.appendDialectRegistry(registry); 371 EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation())); 372 EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation())); 373 } 374 375 } // end namespace 376