1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This implements the tests for attaching interfaces to attributes and types 10 // without having to specify them on the attribute or type class directly. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/BuiltinAttributes.h" 15 #include "mlir/IR/BuiltinDialect.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "gtest/gtest.h" 19 20 #include "../../test/lib/Dialect/Test/TestAttributes.h" 21 #include "../../test/lib/Dialect/Test/TestDialect.h" 22 #include "../../test/lib/Dialect/Test/TestTypes.h" 23 #include "mlir/IR/OwningOpRef.h" 24 25 using namespace mlir; 26 using namespace test; 27 28 namespace { 29 30 /// External interface model for the integer type. Only provides non-default 31 /// methods. 32 struct Model 33 : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> { 34 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 35 return type.getIntOrFloatBitWidth() + arg; 36 } 37 38 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 39 }; 40 41 /// External interface model for the float type. Provides non-deafult and 42 /// overrides default methods. 43 struct OverridingModel 44 : public TestExternalTypeInterface::ExternalModel<OverridingModel, 45 FloatType> { 46 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 47 return type.getIntOrFloatBitWidth() + arg; 48 } 49 50 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 51 52 unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const { 53 return 128; 54 } 55 56 static unsigned staticGetArgument(unsigned arg) { return 420; } 57 }; 58 59 TEST(InterfaceAttachment, Type) { 60 MLIRContext context; 61 62 // Check that the type has no interface. 63 IntegerType i8 = IntegerType::get(&context, 8); 64 ASSERT_FALSE(i8.isa<TestExternalTypeInterface>()); 65 66 // Attach an interface and check that the type now has the interface. 67 IntegerType::attachInterface<Model>(context); 68 TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>(); 69 ASSERT_TRUE(iface != nullptr); 70 EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u); 71 EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u); 72 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u); 73 EXPECT_EQ(iface.staticGetArgument(17), 17u); 74 75 // Same, but with the default implementation overridden. 76 FloatType flt = Float32Type::get(&context); 77 ASSERT_FALSE(flt.isa<TestExternalTypeInterface>()); 78 Float32Type::attachInterface<OverridingModel>(context); 79 iface = flt.dyn_cast<TestExternalTypeInterface>(); 80 ASSERT_TRUE(iface != nullptr); 81 EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u); 82 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u); 83 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u); 84 EXPECT_EQ(iface.staticGetArgument(17), 420u); 85 86 // Other contexts shouldn't have the attribute attached. 87 MLIRContext other; 88 IntegerType i8other = IntegerType::get(&other, 8); 89 EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>()); 90 } 91 92 /// External interface model for the test type from the test dialect. 93 struct TestTypeModel 94 : public TestExternalTypeInterface::ExternalModel<TestTypeModel, 95 test::TestType> { 96 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; } 97 98 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; } 99 }; 100 101 TEST(InterfaceAttachment, TypeDelayedContextConstruct) { 102 // Put the interface in the registry. 103 DialectRegistry registry; 104 registry.insert<test::TestDialect>(); 105 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 106 test::TestType::attachInterface<TestTypeModel>(*ctx); 107 }); 108 109 // Check that when a context is constructed with the given registry, the type 110 // interface gets registered. 111 MLIRContext context(registry); 112 context.loadDialect<test::TestDialect>(); 113 test::TestType testType = test::TestType::get(&context); 114 auto iface = testType.dyn_cast<TestExternalTypeInterface>(); 115 ASSERT_TRUE(iface != nullptr); 116 EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u); 117 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u); 118 } 119 120 TEST(InterfaceAttachment, TypeDelayedContextAppend) { 121 // Put the interface in the registry. 122 DialectRegistry registry; 123 registry.insert<test::TestDialect>(); 124 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 125 test::TestType::attachInterface<TestTypeModel>(*ctx); 126 }); 127 128 // Check that when the registry gets appended to the context, the interface 129 // becomes available for objects in loaded dialects. 130 MLIRContext context; 131 context.loadDialect<test::TestDialect>(); 132 test::TestType testType = test::TestType::get(&context); 133 EXPECT_FALSE(testType.isa<TestExternalTypeInterface>()); 134 context.appendDialectRegistry(registry); 135 EXPECT_TRUE(testType.isa<TestExternalTypeInterface>()); 136 } 137 138 TEST(InterfaceAttachment, RepeatedRegistration) { 139 DialectRegistry registry; 140 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 141 IntegerType::attachInterface<Model>(*ctx); 142 }); 143 MLIRContext context(registry); 144 145 // Should't fail on repeated registration through the dialect registry. 146 context.appendDialectRegistry(registry); 147 } 148 149 TEST(InterfaceAttachment, TypeBuiltinDelayed) { 150 // Builtin dialect needs to registration or loading, but delayed interface 151 // registration must still work. 152 DialectRegistry registry; 153 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 154 IntegerType::attachInterface<Model>(*ctx); 155 }); 156 157 MLIRContext context(registry); 158 IntegerType i16 = IntegerType::get(&context, 16); 159 EXPECT_TRUE(i16.isa<TestExternalTypeInterface>()); 160 161 MLIRContext initiallyEmpty; 162 IntegerType i32 = IntegerType::get(&initiallyEmpty, 32); 163 EXPECT_FALSE(i32.isa<TestExternalTypeInterface>()); 164 initiallyEmpty.appendDialectRegistry(registry); 165 EXPECT_TRUE(i32.isa<TestExternalTypeInterface>()); 166 } 167 168 /// The interface provides a default implementation that expects 169 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this 170 /// just derives from the ExternalModel. 171 struct TestExternalFallbackTypeIntegerModel 172 : public TestExternalFallbackTypeInterface::ExternalModel< 173 TestExternalFallbackTypeIntegerModel, IntegerType> {}; 174 175 /// The interface provides a default implementation that expects 176 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use 177 /// FallbackModel instead to override this and make sure the code still compiles 178 /// because we never instantiate the ExternalModel class template with a 179 /// template argument that would have led to compilation failures. 180 struct TestExternalFallbackTypeVectorModel 181 : public TestExternalFallbackTypeInterface::FallbackModel< 182 TestExternalFallbackTypeVectorModel> { 183 unsigned getBitwidth(Type type) const { 184 IntegerType elementType = type.cast<VectorType>() 185 .getElementType() 186 .dyn_cast_or_null<IntegerType>(); 187 return elementType ? elementType.getWidth() : 0; 188 } 189 }; 190 191 TEST(InterfaceAttachment, Fallback) { 192 MLIRContext context; 193 194 // Just check that we can attach the interface. 195 IntegerType i8 = IntegerType::get(&context, 8); 196 ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>()); 197 IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context); 198 ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>()); 199 200 // Call the method so it is guaranteed not to be instantiated. 201 VectorType vec = VectorType::get({42}, i8); 202 ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>()); 203 VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context); 204 ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>()); 205 EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u); 206 } 207 208 /// External model for attribute interfaces. 209 struct TestExternalIntegerAttrModel 210 : public TestExternalAttrInterface::ExternalModel< 211 TestExternalIntegerAttrModel, IntegerAttr> { 212 const Dialect *getDialectPtr(Attribute attr) const { 213 return &attr.cast<IntegerAttr>().getDialect(); 214 } 215 216 static int getSomeNumber() { return 42; } 217 }; 218 219 TEST(InterfaceAttachment, Attribute) { 220 MLIRContext context; 221 222 // Attribute interfaces use the exact same mechanism as types, so just check 223 // that the basics work for attributes. 224 IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42); 225 ASSERT_FALSE(attr.isa<TestExternalAttrInterface>()); 226 IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context); 227 auto iface = attr.dyn_cast<TestExternalAttrInterface>(); 228 ASSERT_TRUE(iface != nullptr); 229 EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect()); 230 EXPECT_EQ(iface.getSomeNumber(), 42); 231 } 232 233 /// External model for an interface attachable to a non-builtin attribute. 234 struct TestExternalSimpleAAttrModel 235 : public TestExternalAttrInterface::ExternalModel< 236 TestExternalSimpleAAttrModel, test::SimpleAAttr> { 237 const Dialect *getDialectPtr(Attribute attr) const { 238 return &attr.getDialect(); 239 } 240 241 static int getSomeNumber() { return 21; } 242 }; 243 244 TEST(InterfaceAttachmentTest, AttributeDelayed) { 245 // Attribute interfaces use the exact same mechanism as types, so just check 246 // that the delayed registration work for attributes. 247 DialectRegistry registry; 248 registry.insert<test::TestDialect>(); 249 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 250 test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx); 251 }); 252 253 MLIRContext context(registry); 254 context.loadDialect<test::TestDialect>(); 255 auto attr = test::SimpleAAttr::get(&context); 256 EXPECT_TRUE(attr.isa<TestExternalAttrInterface>()); 257 258 MLIRContext initiallyEmpty; 259 initiallyEmpty.loadDialect<test::TestDialect>(); 260 attr = test::SimpleAAttr::get(&initiallyEmpty); 261 EXPECT_FALSE(attr.isa<TestExternalAttrInterface>()); 262 initiallyEmpty.appendDialectRegistry(registry); 263 EXPECT_TRUE(attr.isa<TestExternalAttrInterface>()); 264 } 265 266 /// External interface model for the module operation. Only provides non-default 267 /// methods. 268 struct TestExternalOpModel 269 : public TestExternalOpInterface::ExternalModel<TestExternalOpModel, 270 ModuleOp> { 271 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 272 return op->getName().getStringRef().size() + arg; 273 } 274 275 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 276 return ModuleOp::getOperationName().size() + 2 * arg; 277 } 278 }; 279 280 /// External interface model for the func operation. Provides non-deafult and 281 /// overrides default methods. 282 struct TestExternalOpOverridingModel 283 : public TestExternalOpInterface::FallbackModel< 284 TestExternalOpOverridingModel> { 285 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 286 return op->getName().getStringRef().size() + arg; 287 } 288 289 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 290 return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg; 291 } 292 293 unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const { 294 return 42; 295 } 296 297 static unsigned getNameLengthMinusArg(unsigned arg) { return 21; } 298 }; 299 300 TEST(InterfaceAttachment, Operation) { 301 MLIRContext context; 302 OpBuilder builder(&context); 303 304 // Initially, the operation doesn't have the interface. 305 OwningOpRef<ModuleOp> moduleOp = 306 builder.create<ModuleOp>(UnknownLoc::get(&context)); 307 ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation())); 308 309 // We can attach an external interface and now the operaiton has it. 310 ModuleOp::attachInterface<TestExternalOpModel>(context); 311 auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation()); 312 ASSERT_TRUE(iface != nullptr); 313 EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u); 314 EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u); 315 EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u); 316 EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u); 317 318 // Default implementation can be overridden. 319 OwningOpRef<UnrealizedConversionCastOp> castOp = 320 builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context), 321 TypeRange(), ValueRange()); 322 ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation())); 323 UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>( 324 context); 325 iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation()); 326 ASSERT_TRUE(iface != nullptr); 327 EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u); 328 EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u); 329 EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u); 330 EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u); 331 332 // Another context doesn't have the interfaces registered. 333 MLIRContext other; 334 OwningOpRef<ModuleOp> otherModuleOp = 335 ModuleOp::create(UnknownLoc::get(&other)); 336 ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation())); 337 } 338 339 template <class ConcreteOp> 340 struct TestExternalTestOpModel 341 : public TestExternalOpInterface::ExternalModel< 342 TestExternalTestOpModel<ConcreteOp>, ConcreteOp> { 343 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 344 return op->getName().getStringRef().size() + arg; 345 } 346 347 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 348 return ConcreteOp::getOperationName().size() + 2 * arg; 349 } 350 }; 351 352 TEST(InterfaceAttachment, OperationDelayedContextConstruct) { 353 DialectRegistry registry; 354 registry.insert<test::TestDialect>(); 355 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 356 ModuleOp::attachInterface<TestExternalOpModel>(*ctx); 357 }); 358 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 359 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx); 360 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx); 361 }); 362 363 // Construct the context directly from a registry. The interfaces are 364 // expected to be readily available on operations. 365 MLIRContext context(registry); 366 context.loadDialect<test::TestDialect>(); 367 368 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); 369 OpBuilder builder(module->getBody(), module->getBody()->begin()); 370 auto opJ = 371 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); 372 auto opH = 373 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); 374 auto opI = 375 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); 376 377 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); 378 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); 379 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); 380 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 381 } 382 383 TEST(InterfaceAttachment, OperationDelayedContextAppend) { 384 DialectRegistry registry; 385 registry.insert<test::TestDialect>(); 386 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { 387 ModuleOp::attachInterface<TestExternalOpModel>(*ctx); 388 }); 389 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { 390 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx); 391 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx); 392 }); 393 394 // Construct the context, create ops, and only then append the registry. The 395 // interfaces are expected to be available after appending the registry. 396 MLIRContext context; 397 context.loadDialect<test::TestDialect>(); 398 399 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); 400 OpBuilder builder(module->getBody(), module->getBody()->begin()); 401 auto opJ = 402 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); 403 auto opH = 404 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); 405 auto opI = 406 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); 407 408 EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation())); 409 EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation())); 410 EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation())); 411 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 412 413 context.appendDialectRegistry(registry); 414 415 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); 416 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); 417 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); 418 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 419 } 420 421 } // namespace 422