1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This implements the tests for attaching interfaces to attributes and types 10 // without having to specify them on the attribute or type class directly. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/BuiltinAttributes.h" 15 #include "mlir/IR/BuiltinDialect.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "gtest/gtest.h" 19 20 #include "../../test/lib/Dialect/Test/TestAttributes.h" 21 #include "../../test/lib/Dialect/Test/TestDialect.h" 22 #include "../../test/lib/Dialect/Test/TestOps.h" 23 #include "../../test/lib/Dialect/Test/TestTypes.h" 24 #include "mlir/IR/OwningOpRef.h" 25 26 using namespace mlir; 27 using namespace test; 28 29 namespace { 30 31 /// External interface model for the integer type. Only provides non-default 32 /// methods. 33 struct Model 34 : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> { 35 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 36 return type.getIntOrFloatBitWidth() + arg; 37 } 38 39 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 40 }; 41 42 /// External interface model for the float type. Provides non-deafult and 43 /// overrides default methods. 44 struct OverridingModel 45 : public TestExternalTypeInterface::ExternalModel<OverridingModel, 46 Float32Type> { 47 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 48 return type.getIntOrFloatBitWidth() + arg; 49 } 50 51 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 52 53 unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const { 54 return 128; 55 } 56 57 static unsigned staticGetArgument(unsigned arg) { return 420; } 58 }; 59 60 TEST(InterfaceAttachment, Type) { 61 MLIRContext context; 62 63 // Check that the type has no interface. 64 IntegerType i8 = IntegerType::get(&context, 8); 65 ASSERT_FALSE(isa<TestExternalTypeInterface>(i8)); 66 67 // Attach an interface and check that the type now has the interface. 68 IntegerType::attachInterface<Model>(context); 69 TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8); 70 ASSERT_TRUE(iface != nullptr); 71 EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u); 72 EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u); 73 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u); 74 EXPECT_EQ(iface.staticGetArgument(17), 17u); 75 76 // Same, but with the default implementation overridden. 77 FloatType flt = Float32Type::get(&context); 78 ASSERT_FALSE(isa<TestExternalTypeInterface>(flt)); 79 Float32Type::attachInterface<OverridingModel>(context); 80 iface = dyn_cast<TestExternalTypeInterface>(flt); 81 ASSERT_TRUE(iface != nullptr); 82 EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u); 83 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u); 84 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u); 85 EXPECT_EQ(iface.staticGetArgument(17), 420u); 86 87 // Other contexts shouldn't have the attribute attached. 88 MLIRContext other; 89 IntegerType i8other = IntegerType::get(&other, 8); 90 EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other)); 91 } 92 93 /// External interface model for the test type from the test dialect. 94 struct TestTypeModel 95 : public TestExternalTypeInterface::ExternalModel<TestTypeModel, 96 test::TestType> { 97 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; } 98 99 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; } 100 }; 101 102 TEST(InterfaceAttachment, TypeDelayedContextConstruct) { 103 // Put the interface in the registry. 104 DialectRegistry registry; 105 registry.insert<test::TestDialect>(); 106 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 107 test::TestType::attachInterface<TestTypeModel>(*ctx); 108 }); 109 110 // Check that when a context is constructed with the given registry, the type 111 // interface gets registered. 112 MLIRContext context(registry); 113 context.loadDialect<test::TestDialect>(); 114 test::TestType testType = test::TestType::get(&context); 115 auto iface = dyn_cast<TestExternalTypeInterface>(testType); 116 ASSERT_TRUE(iface != nullptr); 117 EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u); 118 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u); 119 } 120 121 TEST(InterfaceAttachment, TypeDelayedContextAppend) { 122 // Put the interface in the registry. 123 DialectRegistry registry; 124 registry.insert<test::TestDialect>(); 125 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 126 test::TestType::attachInterface<TestTypeModel>(*ctx); 127 }); 128 129 // Check that when the registry gets appended to the context, the interface 130 // becomes available for objects in loaded dialects. 131 MLIRContext context; 132 context.loadDialect<test::TestDialect>(); 133 test::TestType testType = test::TestType::get(&context); 134 EXPECT_FALSE(isa<TestExternalTypeInterface>(testType)); 135 context.appendDialectRegistry(registry); 136 EXPECT_TRUE(isa<TestExternalTypeInterface>(testType)); 137 } 138 139 TEST(InterfaceAttachment, RepeatedRegistration) { 140 DialectRegistry registry; 141 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 142 IntegerType::attachInterface<Model>(*ctx); 143 }); 144 MLIRContext context(registry); 145 146 // Should't fail on repeated registration through the dialect registry. 147 context.appendDialectRegistry(registry); 148 } 149 150 TEST(InterfaceAttachment, TypeBuiltinDelayed) { 151 // Builtin dialect needs to registration or loading, but delayed interface 152 // registration must still work. 153 DialectRegistry registry; 154 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 155 IntegerType::attachInterface<Model>(*ctx); 156 }); 157 158 MLIRContext context(registry); 159 IntegerType i16 = IntegerType::get(&context, 16); 160 EXPECT_TRUE(isa<TestExternalTypeInterface>(i16)); 161 162 MLIRContext initiallyEmpty; 163 IntegerType i32 = IntegerType::get(&initiallyEmpty, 32); 164 EXPECT_FALSE(isa<TestExternalTypeInterface>(i32)); 165 initiallyEmpty.appendDialectRegistry(registry); 166 EXPECT_TRUE(isa<TestExternalTypeInterface>(i32)); 167 } 168 169 /// The interface provides a default implementation that expects 170 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this 171 /// just derives from the ExternalModel. 172 struct TestExternalFallbackTypeIntegerModel 173 : public TestExternalFallbackTypeInterface::ExternalModel< 174 TestExternalFallbackTypeIntegerModel, IntegerType> {}; 175 176 /// The interface provides a default implementation that expects 177 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use 178 /// FallbackModel instead to override this and make sure the code still compiles 179 /// because we never instantiate the ExternalModel class template with a 180 /// template argument that would have led to compilation failures. 181 struct TestExternalFallbackTypeVectorModel 182 : public TestExternalFallbackTypeInterface::FallbackModel< 183 TestExternalFallbackTypeVectorModel> { 184 unsigned getBitwidth(Type type) const { 185 IntegerType elementType = 186 dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType()); 187 return elementType ? elementType.getWidth() : 0; 188 } 189 }; 190 191 TEST(InterfaceAttachment, Fallback) { 192 MLIRContext context; 193 194 // Just check that we can attach the interface. 195 IntegerType i8 = IntegerType::get(&context, 8); 196 ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8)); 197 IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context); 198 ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8)); 199 200 // Call the method so it is guaranteed not to be instantiated. 201 VectorType vec = VectorType::get({42}, i8); 202 ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec)); 203 VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context); 204 ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec)); 205 EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u); 206 } 207 208 /// External model for attribute interfaces. 209 struct TestExternalIntegerAttrModel 210 : public TestExternalAttrInterface::ExternalModel< 211 TestExternalIntegerAttrModel, IntegerAttr> { 212 const Dialect *getDialectPtr(Attribute attr) const { 213 return &cast<IntegerAttr>(attr).getDialect(); 214 } 215 216 static int getSomeNumber() { return 42; } 217 }; 218 219 TEST(InterfaceAttachment, Attribute) { 220 MLIRContext context; 221 222 // Attribute interfaces use the exact same mechanism as types, so just check 223 // that the basics work for attributes. 224 IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42); 225 ASSERT_FALSE(isa<TestExternalAttrInterface>(attr)); 226 IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context); 227 auto iface = dyn_cast<TestExternalAttrInterface>(attr); 228 ASSERT_TRUE(iface != nullptr); 229 EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect()); 230 EXPECT_EQ(iface.getSomeNumber(), 42); 231 } 232 233 /// External model for an interface attachable to a non-builtin attribute. 234 struct TestExternalSimpleAAttrModel 235 : public TestExternalAttrInterface::ExternalModel< 236 TestExternalSimpleAAttrModel, test::SimpleAAttr> { 237 const Dialect *getDialectPtr(Attribute attr) const { 238 return &attr.getDialect(); 239 } 240 241 static int getSomeNumber() { return 21; } 242 }; 243 244 TEST(InterfaceAttachmentTest, AttributeDelayed) { 245 // Attribute interfaces use the exact same mechanism as types, so just check 246 // that the delayed registration work for attributes. 247 DialectRegistry registry; 248 registry.insert<test::TestDialect>(); 249 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 250 test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx); 251 }); 252 253 MLIRContext context(registry); 254 context.loadDialect<test::TestDialect>(); 255 auto attr = test::SimpleAAttr::get(&context); 256 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); 257 258 MLIRContext initiallyEmpty; 259 initiallyEmpty.loadDialect<test::TestDialect>(); 260 attr = test::SimpleAAttr::get(&initiallyEmpty); 261 EXPECT_FALSE(isa<TestExternalAttrInterface>(attr)); 262 initiallyEmpty.appendDialectRegistry(registry); 263 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); 264 } 265 266 /// External interface model for the module operation. Only provides non-default 267 /// methods. 268 struct TestExternalOpModel 269 : public TestExternalOpInterface::ExternalModel<TestExternalOpModel, 270 ModuleOp> { 271 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 272 return op->getName().getStringRef().size() + arg; 273 } 274 275 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 276 return ModuleOp::getOperationName().size() + 2 * arg; 277 } 278 }; 279 280 /// External interface model for the func operation. Provides non-deafult and 281 /// overrides default methods. 282 struct TestExternalOpOverridingModel 283 : public TestExternalOpInterface::FallbackModel< 284 TestExternalOpOverridingModel> { 285 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 286 return op->getName().getStringRef().size() + arg; 287 } 288 289 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 290 return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg; 291 } 292 293 unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const { 294 return 42; 295 } 296 297 static unsigned getNameLengthMinusArg(unsigned arg) { return 21; } 298 }; 299 300 TEST(InterfaceAttachment, Operation) { 301 MLIRContext context; 302 OpBuilder builder(&context); 303 304 // Initially, the operation doesn't have the interface. 305 OwningOpRef<ModuleOp> moduleOp = 306 builder.create<ModuleOp>(UnknownLoc::get(&context)); 307 ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation())); 308 309 // We can attach an external interface and now the operaiton has it. 310 ModuleOp::attachInterface<TestExternalOpModel>(context); 311 auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation()); 312 ASSERT_TRUE(iface != nullptr); 313 EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u); 314 EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u); 315 EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u); 316 EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u); 317 318 // Default implementation can be overridden. 319 OwningOpRef<UnrealizedConversionCastOp> castOp = 320 builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context), 321 TypeRange(), ValueRange()); 322 ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation())); 323 UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>( 324 context); 325 iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation()); 326 ASSERT_TRUE(iface != nullptr); 327 EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u); 328 EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u); 329 EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u); 330 EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u); 331 332 // Another context doesn't have the interfaces registered. 333 MLIRContext other; 334 OwningOpRef<ModuleOp> otherModuleOp = 335 ModuleOp::create(UnknownLoc::get(&other)); 336 ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation())); 337 } 338 339 template <class ConcreteOp> 340 struct TestExternalTestOpModel 341 : public TestExternalOpInterface::ExternalModel< 342 TestExternalTestOpModel<ConcreteOp>, ConcreteOp> { 343 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 344 return op->getName().getStringRef().size() + arg; 345 } 346 347 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 348 return ConcreteOp::getOperationName().size() + 2 * arg; 349 } 350 }; 351 352 TEST(InterfaceAttachment, OperationDelayedContextConstruct) { 353 DialectRegistry registry; 354 registry.insert<test::TestDialect>(); 355 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 356 ModuleOp::attachInterface<TestExternalOpModel>(*ctx); 357 }); 358 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 359 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx); 360 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx); 361 }); 362 363 // Construct the context directly from a registry. The interfaces are 364 // expected to be readily available on operations. 365 MLIRContext context(registry); 366 context.loadDialect<test::TestDialect>(); 367 368 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); 369 OpBuilder builder(module->getBody(), module->getBody()->begin()); 370 auto opJ = 371 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); 372 auto opH = 373 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); 374 auto opI = 375 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); 376 377 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); 378 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); 379 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); 380 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 381 } 382 383 TEST(InterfaceAttachment, OperationDelayedContextAppend) { 384 DialectRegistry registry; 385 registry.insert<test::TestDialect>(); 386 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 387 ModuleOp::attachInterface<TestExternalOpModel>(*ctx); 388 }); 389 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 390 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx); 391 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx); 392 }); 393 394 // Construct the context, create ops, and only then append the registry. The 395 // interfaces are expected to be available after appending the registry. 396 MLIRContext context; 397 context.loadDialect<test::TestDialect>(); 398 399 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); 400 OpBuilder builder(module->getBody(), module->getBody()->begin()); 401 auto opJ = 402 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); 403 auto opH = 404 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); 405 auto opI = 406 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); 407 408 EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation())); 409 EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation())); 410 EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation())); 411 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 412 413 context.appendDialectRegistry(registry); 414 415 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); 416 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); 417 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); 418 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 419 } 420 421 TEST(InterfaceAttachmentTest, PromisedInterfaces) { 422 // Attribute interfaces use the exact same mechanism as types, so just check 423 // that the promise mechanism works for attributes. 424 MLIRContext context; 425 auto *testDialect = context.getOrLoadDialect<test::TestDialect>(); 426 auto attr = test::SimpleAAttr::get(&context); 427 428 // `SimpleAAttr` doesn't implement nor promises the 429 // `TestExternalAttrInterface` interface. 430 EXPECT_FALSE(isa<TestExternalAttrInterface>(attr)); 431 EXPECT_FALSE( 432 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); 433 434 // Add a promise `TestExternalAttrInterface`. 435 testDialect->declarePromisedInterface<TestExternalAttrInterface, 436 test::SimpleAAttr>(); 437 EXPECT_TRUE( 438 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); 439 440 // Attach the interface. 441 test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context); 442 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); 443 EXPECT_TRUE( 444 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); 445 } 446 447 } // namespace 448