1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This implements the tests for attaching interfaces to attributes and types 10 // without having to specify them on the attribute or type class directly. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/BuiltinAttributes.h" 15 #include "mlir/IR/BuiltinDialect.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "gtest/gtest.h" 19 20 #include "../../test/lib/Dialect/Test/TestAttributes.h" 21 #include "../../test/lib/Dialect/Test/TestDialect.h" 22 #include "../../test/lib/Dialect/Test/TestTypes.h" 23 #include "mlir/IR/OwningOpRef.h" 24 25 using namespace mlir; 26 using namespace test; 27 28 namespace { 29 30 /// External interface model for the integer type. Only provides non-default 31 /// methods. 32 struct Model 33 : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> { 34 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 35 return type.getIntOrFloatBitWidth() + arg; 36 } 37 38 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 39 }; 40 41 /// External interface model for the float type. Provides non-deafult and 42 /// overrides default methods. 43 struct OverridingModel 44 : public TestExternalTypeInterface::ExternalModel<OverridingModel, 45 FloatType> { 46 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 47 return type.getIntOrFloatBitWidth() + arg; 48 } 49 50 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 51 52 unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const { 53 return 128; 54 } 55 56 static unsigned staticGetArgument(unsigned arg) { return 420; } 57 }; 58 59 TEST(InterfaceAttachment, Type) { 60 MLIRContext context; 61 62 // Check that the type has no interface. 63 IntegerType i8 = IntegerType::get(&context, 8); 64 ASSERT_FALSE(isa<TestExternalTypeInterface>(i8)); 65 66 // Attach an interface and check that the type now has the interface. 67 IntegerType::attachInterface<Model>(context); 68 TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8); 69 ASSERT_TRUE(iface != nullptr); 70 EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u); 71 EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u); 72 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u); 73 EXPECT_EQ(iface.staticGetArgument(17), 17u); 74 75 // Same, but with the default implementation overridden. 76 FloatType flt = Float32Type::get(&context); 77 ASSERT_FALSE(isa<TestExternalTypeInterface>(flt)); 78 Float32Type::attachInterface<OverridingModel>(context); 79 iface = dyn_cast<TestExternalTypeInterface>(flt); 80 ASSERT_TRUE(iface != nullptr); 81 EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u); 82 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u); 83 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u); 84 EXPECT_EQ(iface.staticGetArgument(17), 420u); 85 86 // Other contexts shouldn't have the attribute attached. 87 MLIRContext other; 88 IntegerType i8other = IntegerType::get(&other, 8); 89 EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other)); 90 } 91 92 /// External interface model for the test type from the test dialect. 93 struct TestTypeModel 94 : public TestExternalTypeInterface::ExternalModel<TestTypeModel, 95 test::TestType> { 96 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; } 97 98 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; } 99 }; 100 101 TEST(InterfaceAttachment, TypeDelayedContextConstruct) { 102 // Put the interface in the registry. 103 DialectRegistry registry; 104 registry.insert<test::TestDialect>(); 105 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 106 test::TestType::attachInterface<TestTypeModel>(*ctx); 107 }); 108 109 // Check that when a context is constructed with the given registry, the type 110 // interface gets registered. 111 MLIRContext context(registry); 112 context.loadDialect<test::TestDialect>(); 113 test::TestType testType = test::TestType::get(&context); 114 auto iface = dyn_cast<TestExternalTypeInterface>(testType); 115 ASSERT_TRUE(iface != nullptr); 116 EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u); 117 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u); 118 } 119 120 TEST(InterfaceAttachment, TypeDelayedContextAppend) { 121 // Put the interface in the registry. 122 DialectRegistry registry; 123 registry.insert<test::TestDialect>(); 124 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 125 test::TestType::attachInterface<TestTypeModel>(*ctx); 126 }); 127 128 // Check that when the registry gets appended to the context, the interface 129 // becomes available for objects in loaded dialects. 130 MLIRContext context; 131 context.loadDialect<test::TestDialect>(); 132 test::TestType testType = test::TestType::get(&context); 133 EXPECT_FALSE(isa<TestExternalTypeInterface>(testType)); 134 context.appendDialectRegistry(registry); 135 EXPECT_TRUE(isa<TestExternalTypeInterface>(testType)); 136 } 137 138 TEST(InterfaceAttachment, RepeatedRegistration) { 139 DialectRegistry registry; 140 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 141 IntegerType::attachInterface<Model>(*ctx); 142 }); 143 MLIRContext context(registry); 144 145 // Should't fail on repeated registration through the dialect registry. 146 context.appendDialectRegistry(registry); 147 } 148 149 TEST(InterfaceAttachment, TypeBuiltinDelayed) { 150 // Builtin dialect needs to registration or loading, but delayed interface 151 // registration must still work. 152 DialectRegistry registry; 153 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 154 IntegerType::attachInterface<Model>(*ctx); 155 }); 156 157 MLIRContext context(registry); 158 IntegerType i16 = IntegerType::get(&context, 16); 159 EXPECT_TRUE(isa<TestExternalTypeInterface>(i16)); 160 161 MLIRContext initiallyEmpty; 162 IntegerType i32 = IntegerType::get(&initiallyEmpty, 32); 163 EXPECT_FALSE(isa<TestExternalTypeInterface>(i32)); 164 initiallyEmpty.appendDialectRegistry(registry); 165 EXPECT_TRUE(isa<TestExternalTypeInterface>(i32)); 166 } 167 168 /// The interface provides a default implementation that expects 169 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this 170 /// just derives from the ExternalModel. 171 struct TestExternalFallbackTypeIntegerModel 172 : public TestExternalFallbackTypeInterface::ExternalModel< 173 TestExternalFallbackTypeIntegerModel, IntegerType> {}; 174 175 /// The interface provides a default implementation that expects 176 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use 177 /// FallbackModel instead to override this and make sure the code still compiles 178 /// because we never instantiate the ExternalModel class template with a 179 /// template argument that would have led to compilation failures. 180 struct TestExternalFallbackTypeVectorModel 181 : public TestExternalFallbackTypeInterface::FallbackModel< 182 TestExternalFallbackTypeVectorModel> { 183 unsigned getBitwidth(Type type) const { 184 IntegerType elementType = 185 dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType()); 186 return elementType ? elementType.getWidth() : 0; 187 } 188 }; 189 190 TEST(InterfaceAttachment, Fallback) { 191 MLIRContext context; 192 193 // Just check that we can attach the interface. 194 IntegerType i8 = IntegerType::get(&context, 8); 195 ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8)); 196 IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context); 197 ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8)); 198 199 // Call the method so it is guaranteed not to be instantiated. 200 VectorType vec = VectorType::get({42}, i8); 201 ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec)); 202 VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context); 203 ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec)); 204 EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u); 205 } 206 207 /// External model for attribute interfaces. 208 struct TestExternalIntegerAttrModel 209 : public TestExternalAttrInterface::ExternalModel< 210 TestExternalIntegerAttrModel, IntegerAttr> { 211 const Dialect *getDialectPtr(Attribute attr) const { 212 return &cast<IntegerAttr>(attr).getDialect(); 213 } 214 215 static int getSomeNumber() { return 42; } 216 }; 217 218 TEST(InterfaceAttachment, Attribute) { 219 MLIRContext context; 220 221 // Attribute interfaces use the exact same mechanism as types, so just check 222 // that the basics work for attributes. 223 IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42); 224 ASSERT_FALSE(isa<TestExternalAttrInterface>(attr)); 225 IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context); 226 auto iface = dyn_cast<TestExternalAttrInterface>(attr); 227 ASSERT_TRUE(iface != nullptr); 228 EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect()); 229 EXPECT_EQ(iface.getSomeNumber(), 42); 230 } 231 232 /// External model for an interface attachable to a non-builtin attribute. 233 struct TestExternalSimpleAAttrModel 234 : public TestExternalAttrInterface::ExternalModel< 235 TestExternalSimpleAAttrModel, test::SimpleAAttr> { 236 const Dialect *getDialectPtr(Attribute attr) const { 237 return &attr.getDialect(); 238 } 239 240 static int getSomeNumber() { return 21; } 241 }; 242 243 TEST(InterfaceAttachmentTest, AttributeDelayed) { 244 // Attribute interfaces use the exact same mechanism as types, so just check 245 // that the delayed registration work for attributes. 246 DialectRegistry registry; 247 registry.insert<test::TestDialect>(); 248 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 249 test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx); 250 }); 251 252 MLIRContext context(registry); 253 context.loadDialect<test::TestDialect>(); 254 auto attr = test::SimpleAAttr::get(&context); 255 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); 256 257 MLIRContext initiallyEmpty; 258 initiallyEmpty.loadDialect<test::TestDialect>(); 259 attr = test::SimpleAAttr::get(&initiallyEmpty); 260 EXPECT_FALSE(isa<TestExternalAttrInterface>(attr)); 261 initiallyEmpty.appendDialectRegistry(registry); 262 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); 263 } 264 265 /// External interface model for the module operation. Only provides non-default 266 /// methods. 267 struct TestExternalOpModel 268 : public TestExternalOpInterface::ExternalModel<TestExternalOpModel, 269 ModuleOp> { 270 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 271 return op->getName().getStringRef().size() + arg; 272 } 273 274 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 275 return ModuleOp::getOperationName().size() + 2 * arg; 276 } 277 }; 278 279 /// External interface model for the func operation. Provides non-deafult and 280 /// overrides default methods. 281 struct TestExternalOpOverridingModel 282 : public TestExternalOpInterface::FallbackModel< 283 TestExternalOpOverridingModel> { 284 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 285 return op->getName().getStringRef().size() + arg; 286 } 287 288 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 289 return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg; 290 } 291 292 unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const { 293 return 42; 294 } 295 296 static unsigned getNameLengthMinusArg(unsigned arg) { return 21; } 297 }; 298 299 TEST(InterfaceAttachment, Operation) { 300 MLIRContext context; 301 OpBuilder builder(&context); 302 303 // Initially, the operation doesn't have the interface. 304 OwningOpRef<ModuleOp> moduleOp = 305 builder.create<ModuleOp>(UnknownLoc::get(&context)); 306 ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation())); 307 308 // We can attach an external interface and now the operaiton has it. 309 ModuleOp::attachInterface<TestExternalOpModel>(context); 310 auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation()); 311 ASSERT_TRUE(iface != nullptr); 312 EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u); 313 EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u); 314 EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u); 315 EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u); 316 317 // Default implementation can be overridden. 318 OwningOpRef<UnrealizedConversionCastOp> castOp = 319 builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context), 320 TypeRange(), ValueRange()); 321 ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation())); 322 UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>( 323 context); 324 iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation()); 325 ASSERT_TRUE(iface != nullptr); 326 EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u); 327 EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u); 328 EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u); 329 EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u); 330 331 // Another context doesn't have the interfaces registered. 332 MLIRContext other; 333 OwningOpRef<ModuleOp> otherModuleOp = 334 ModuleOp::create(UnknownLoc::get(&other)); 335 ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation())); 336 } 337 338 template <class ConcreteOp> 339 struct TestExternalTestOpModel 340 : public TestExternalOpInterface::ExternalModel< 341 TestExternalTestOpModel<ConcreteOp>, ConcreteOp> { 342 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 343 return op->getName().getStringRef().size() + arg; 344 } 345 346 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 347 return ConcreteOp::getOperationName().size() + 2 * arg; 348 } 349 }; 350 351 TEST(InterfaceAttachment, OperationDelayedContextConstruct) { 352 DialectRegistry registry; 353 registry.insert<test::TestDialect>(); 354 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 355 ModuleOp::attachInterface<TestExternalOpModel>(*ctx); 356 }); 357 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 358 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx); 359 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx); 360 }); 361 362 // Construct the context directly from a registry. The interfaces are 363 // expected to be readily available on operations. 364 MLIRContext context(registry); 365 context.loadDialect<test::TestDialect>(); 366 367 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); 368 OpBuilder builder(module->getBody(), module->getBody()->begin()); 369 auto opJ = 370 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); 371 auto opH = 372 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); 373 auto opI = 374 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); 375 376 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); 377 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); 378 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); 379 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 380 } 381 382 TEST(InterfaceAttachment, OperationDelayedContextAppend) { 383 DialectRegistry registry; 384 registry.insert<test::TestDialect>(); 385 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 386 ModuleOp::attachInterface<TestExternalOpModel>(*ctx); 387 }); 388 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 389 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx); 390 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx); 391 }); 392 393 // Construct the context, create ops, and only then append the registry. The 394 // interfaces are expected to be available after appending the registry. 395 MLIRContext context; 396 context.loadDialect<test::TestDialect>(); 397 398 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); 399 OpBuilder builder(module->getBody(), module->getBody()->begin()); 400 auto opJ = 401 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); 402 auto opH = 403 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); 404 auto opI = 405 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); 406 407 EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation())); 408 EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation())); 409 EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation())); 410 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 411 412 context.appendDialectRegistry(registry); 413 414 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); 415 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); 416 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); 417 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 418 } 419 420 TEST(InterfaceAttachmentTest, PromisedInterfaces) { 421 // Attribute interfaces use the exact same mechanism as types, so just check 422 // that the promise mechanism works for attributes. 423 MLIRContext context; 424 auto testDialect = context.getOrLoadDialect<test::TestDialect>(); 425 auto attr = test::SimpleAAttr::get(&context); 426 427 // `SimpleAAttr` doesn't implement nor promises the 428 // `TestExternalAttrInterface` interface. 429 EXPECT_FALSE(isa<TestExternalAttrInterface>(attr)); 430 EXPECT_FALSE( 431 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); 432 433 // Add a promise `TestExternalAttrInterface`. 434 testDialect->declarePromisedInterface<test::SimpleAAttr, 435 TestExternalAttrInterface>(); 436 EXPECT_TRUE( 437 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); 438 439 // Attach the interface. 440 test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context); 441 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); 442 EXPECT_TRUE( 443 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); 444 } 445 446 } // namespace 447