1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This implements the tests for attaching interfaces to attributes and types 10 // without having to specify them on the attribute or type class directly. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/BuiltinAttributes.h" 15 #include "mlir/IR/BuiltinDialect.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "gtest/gtest.h" 19 20 #include "../../test/lib/Dialect/Test/TestAttributes.h" 21 #include "../../test/lib/Dialect/Test/TestDialect.h" 22 #include "../../test/lib/Dialect/Test/TestTypes.h" 23 #include "mlir/IR/OwningOpRef.h" 24 25 using namespace mlir; 26 using namespace test; 27 28 namespace { 29 30 /// External interface model for the integer type. Only provides non-default 31 /// methods. 32 struct Model 33 : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> { 34 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 35 return type.getIntOrFloatBitWidth() + arg; 36 } 37 38 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 39 }; 40 41 /// External interface model for the float type. Provides non-deafult and 42 /// overrides default methods. 43 struct OverridingModel 44 : public TestExternalTypeInterface::ExternalModel<OverridingModel, 45 FloatType> { 46 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { 47 return type.getIntOrFloatBitWidth() + arg; 48 } 49 50 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } 51 52 unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const { 53 return 128; 54 } 55 56 static unsigned staticGetArgument(unsigned arg) { return 420; } 57 }; 58 59 TEST(InterfaceAttachment, Type) { 60 MLIRContext context; 61 62 // Check that the type has no interface. 63 IntegerType i8 = IntegerType::get(&context, 8); 64 ASSERT_FALSE(i8.isa<TestExternalTypeInterface>()); 65 66 // Attach an interface and check that the type now has the interface. 67 IntegerType::attachInterface<Model>(context); 68 TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>(); 69 ASSERT_TRUE(iface != nullptr); 70 EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u); 71 EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u); 72 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u); 73 EXPECT_EQ(iface.staticGetArgument(17), 17u); 74 75 // Same, but with the default implementation overridden. 76 FloatType flt = Float32Type::get(&context); 77 ASSERT_FALSE(flt.isa<TestExternalTypeInterface>()); 78 Float32Type::attachInterface<OverridingModel>(context); 79 iface = flt.dyn_cast<TestExternalTypeInterface>(); 80 ASSERT_TRUE(iface != nullptr); 81 EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u); 82 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u); 83 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u); 84 EXPECT_EQ(iface.staticGetArgument(17), 420u); 85 86 // Other contexts shouldn't have the attribute attached. 87 MLIRContext other; 88 IntegerType i8other = IntegerType::get(&other, 8); 89 EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>()); 90 } 91 92 /// External interface model for the test type from the test dialect. 93 struct TestTypeModel 94 : public TestExternalTypeInterface::ExternalModel<TestTypeModel, 95 test::TestType> { 96 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; } 97 98 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; } 99 }; 100 101 TEST(InterfaceAttachment, TypeDelayedContextConstruct) { 102 // Put the interface in the registry. 103 DialectRegistry registry; 104 registry.insert<test::TestDialect>(); 105 registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>(); 106 107 // Check that when a context is constructed with the given registry, the type 108 // interface gets registered. 109 MLIRContext context(registry); 110 context.loadDialect<test::TestDialect>(); 111 test::TestType testType = test::TestType::get(&context); 112 auto iface = testType.dyn_cast<TestExternalTypeInterface>(); 113 ASSERT_TRUE(iface != nullptr); 114 EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u); 115 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u); 116 } 117 118 TEST(InterfaceAttachment, TypeDelayedContextAppend) { 119 // Put the interface in the registry. 120 DialectRegistry registry; 121 registry.insert<test::TestDialect>(); 122 registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>(); 123 124 // Check that when the registry gets appended to the context, the interface 125 // becomes available for objects in loaded dialects. 126 MLIRContext context; 127 context.loadDialect<test::TestDialect>(); 128 test::TestType testType = test::TestType::get(&context); 129 EXPECT_FALSE(testType.isa<TestExternalTypeInterface>()); 130 context.appendDialectRegistry(registry); 131 EXPECT_TRUE(testType.isa<TestExternalTypeInterface>()); 132 } 133 134 TEST(InterfaceAttachment, RepeatedRegistration) { 135 DialectRegistry registry; 136 registry.addTypeInterface<BuiltinDialect, IntegerType, Model>(); 137 MLIRContext context(registry); 138 139 // Should't fail on repeated registration through the dialect registry. 140 context.appendDialectRegistry(registry); 141 } 142 143 TEST(InterfaceAttachment, TypeBuiltinDelayed) { 144 // Builtin dialect needs to registration or loading, but delayed interface 145 // registration must still work. 146 DialectRegistry registry; 147 registry.addTypeInterface<BuiltinDialect, IntegerType, Model>(); 148 149 MLIRContext context(registry); 150 IntegerType i16 = IntegerType::get(&context, 16); 151 EXPECT_TRUE(i16.isa<TestExternalTypeInterface>()); 152 153 MLIRContext initiallyEmpty; 154 IntegerType i32 = IntegerType::get(&initiallyEmpty, 32); 155 EXPECT_FALSE(i32.isa<TestExternalTypeInterface>()); 156 initiallyEmpty.appendDialectRegistry(registry); 157 EXPECT_TRUE(i32.isa<TestExternalTypeInterface>()); 158 } 159 160 /// The interface provides a default implementation that expects 161 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this 162 /// just derives from the ExternalModel. 163 struct TestExternalFallbackTypeIntegerModel 164 : public TestExternalFallbackTypeInterface::ExternalModel< 165 TestExternalFallbackTypeIntegerModel, IntegerType> {}; 166 167 /// The interface provides a default implementation that expects 168 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use 169 /// FallbackModel instead to override this and make sure the code still compiles 170 /// because we never instantiate the ExternalModel class template with a 171 /// template argument that would have led to compilation failures. 172 struct TestExternalFallbackTypeVectorModel 173 : public TestExternalFallbackTypeInterface::FallbackModel< 174 TestExternalFallbackTypeVectorModel> { 175 unsigned getBitwidth(Type type) const { 176 IntegerType elementType = type.cast<VectorType>() 177 .getElementType() 178 .dyn_cast_or_null<IntegerType>(); 179 return elementType ? elementType.getWidth() : 0; 180 } 181 }; 182 183 TEST(InterfaceAttachment, Fallback) { 184 MLIRContext context; 185 186 // Just check that we can attach the interface. 187 IntegerType i8 = IntegerType::get(&context, 8); 188 ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>()); 189 IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context); 190 ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>()); 191 192 // Call the method so it is guaranteed not to be instantiated. 193 VectorType vec = VectorType::get({42}, i8); 194 ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>()); 195 VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context); 196 ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>()); 197 EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u); 198 } 199 200 /// External model for attribute interfaces. 201 struct TestExternalIntegerAttrModel 202 : public TestExternalAttrInterface::ExternalModel< 203 TestExternalIntegerAttrModel, IntegerAttr> { 204 const Dialect *getDialectPtr(Attribute attr) const { 205 return &attr.cast<IntegerAttr>().getDialect(); 206 } 207 208 static int getSomeNumber() { return 42; } 209 }; 210 211 TEST(InterfaceAttachment, Attribute) { 212 MLIRContext context; 213 214 // Attribute interfaces use the exact same mechanism as types, so just check 215 // that the basics work for attributes. 216 IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42); 217 ASSERT_FALSE(attr.isa<TestExternalAttrInterface>()); 218 IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context); 219 auto iface = attr.dyn_cast<TestExternalAttrInterface>(); 220 ASSERT_TRUE(iface != nullptr); 221 EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect()); 222 EXPECT_EQ(iface.getSomeNumber(), 42); 223 } 224 225 /// External model for an interface attachable to a non-builtin attribute. 226 struct TestExternalSimpleAAttrModel 227 : public TestExternalAttrInterface::ExternalModel< 228 TestExternalSimpleAAttrModel, test::SimpleAAttr> { 229 const Dialect *getDialectPtr(Attribute attr) const { 230 return &attr.getDialect(); 231 } 232 233 static int getSomeNumber() { return 21; } 234 }; 235 236 TEST(InterfaceAttachmentTest, AttributeDelayed) { 237 // Attribute interfaces use the exact same mechanism as types, so just check 238 // that the delayed registration work for attributes. 239 DialectRegistry registry; 240 registry.insert<test::TestDialect>(); 241 registry.addAttrInterface<test::TestDialect, test::SimpleAAttr, 242 TestExternalSimpleAAttrModel>(); 243 244 MLIRContext context(registry); 245 context.loadDialect<test::TestDialect>(); 246 auto attr = test::SimpleAAttr::get(&context); 247 EXPECT_TRUE(attr.isa<TestExternalAttrInterface>()); 248 249 MLIRContext initiallyEmpty; 250 initiallyEmpty.loadDialect<test::TestDialect>(); 251 attr = test::SimpleAAttr::get(&initiallyEmpty); 252 EXPECT_FALSE(attr.isa<TestExternalAttrInterface>()); 253 initiallyEmpty.appendDialectRegistry(registry); 254 EXPECT_TRUE(attr.isa<TestExternalAttrInterface>()); 255 } 256 257 /// External interface model for the module operation. Only provides non-default 258 /// methods. 259 struct TestExternalOpModel 260 : public TestExternalOpInterface::ExternalModel<TestExternalOpModel, 261 ModuleOp> { 262 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 263 return op->getName().getStringRef().size() + arg; 264 } 265 266 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 267 return ModuleOp::getOperationName().size() + 2 * arg; 268 } 269 }; 270 271 /// External interface model for the func operation. Provides non-deafult and 272 /// overrides default methods. 273 struct TestExternalOpOverridingModel 274 : public TestExternalOpInterface::FallbackModel< 275 TestExternalOpOverridingModel> { 276 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 277 return op->getName().getStringRef().size() + arg; 278 } 279 280 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 281 return FuncOp::getOperationName().size() + 2 * arg; 282 } 283 284 unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const { 285 return 42; 286 } 287 288 static unsigned getNameLengthMinusArg(unsigned arg) { return 21; } 289 }; 290 291 TEST(InterfaceAttachment, Operation) { 292 MLIRContext context; 293 294 // Initially, the operation doesn't have the interface. 295 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(UnknownLoc::get(&context)); 296 ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation())); 297 298 // We can attach an external interface and now the operaiton has it. 299 ModuleOp::attachInterface<TestExternalOpModel>(context); 300 auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation()); 301 ASSERT_TRUE(iface != nullptr); 302 EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u); 303 EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u); 304 EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u); 305 EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u); 306 307 // Default implementation can be overridden. 308 OwningOpRef<FuncOp> funcOp = 309 FuncOp::create(UnknownLoc::get(&context), "function", 310 FunctionType::get(&context, {}, {})); 311 ASSERT_FALSE(isa<TestExternalOpInterface>(funcOp->getOperation())); 312 FuncOp::attachInterface<TestExternalOpOverridingModel>(context); 313 iface = dyn_cast<TestExternalOpInterface>(funcOp->getOperation()); 314 ASSERT_TRUE(iface != nullptr); 315 EXPECT_EQ(iface.getNameLengthPlusArg(10), 22u); 316 EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u); 317 EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 28u); 318 EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u); 319 320 // Another context doesn't have the interfaces registered. 321 MLIRContext other; 322 OwningOpRef<ModuleOp> otherModuleOp = 323 ModuleOp::create(UnknownLoc::get(&other)); 324 ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation())); 325 } 326 327 template <class ConcreteOp> 328 struct TestExternalTestOpModel 329 : public TestExternalOpInterface::ExternalModel< 330 TestExternalTestOpModel<ConcreteOp>, ConcreteOp> { 331 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { 332 return op->getName().getStringRef().size() + arg; 333 } 334 335 static unsigned getNameLengthPlusArgTwice(unsigned arg) { 336 return ConcreteOp::getOperationName().size() + 2 * arg; 337 } 338 }; 339 340 TEST(InterfaceAttachment, OperationDelayedContextConstruct) { 341 DialectRegistry registry; 342 registry.insert<test::TestDialect>(); 343 registry.addOpInterface<ModuleOp, TestExternalOpModel>(); 344 registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>(); 345 registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>(); 346 347 // Construct the context directly from a registry. The interfaces are expected 348 // to be readily available on operations. 349 MLIRContext context(registry); 350 context.loadDialect<test::TestDialect>(); 351 352 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); 353 OpBuilder builder(module->getBody(), module->getBody()->begin()); 354 auto opJ = 355 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); 356 auto opH = 357 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); 358 auto opI = 359 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); 360 361 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); 362 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); 363 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); 364 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 365 } 366 367 TEST(InterfaceAttachment, OperationDelayedContextAppend) { 368 DialectRegistry registry; 369 registry.insert<test::TestDialect>(); 370 registry.addOpInterface<ModuleOp, TestExternalOpModel>(); 371 registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>(); 372 registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>(); 373 374 // Construct the context, create ops, and only then append the registry. The 375 // interfaces are expected to be available after appending the registry. 376 MLIRContext context; 377 context.loadDialect<test::TestDialect>(); 378 379 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); 380 OpBuilder builder(module->getBody(), module->getBody()->begin()); 381 auto opJ = 382 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); 383 auto opH = 384 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); 385 auto opI = 386 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); 387 388 EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation())); 389 EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation())); 390 EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation())); 391 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 392 393 context.appendDialectRegistry(registry); 394 395 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); 396 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); 397 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); 398 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); 399 } 400 401 } // end namespace 402