1 //===- DataLayoutInterfacesTest.cpp - Unit Tests for Data Layouts ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/DataLayoutInterfaces.h" 10 #include "mlir/Dialect/DLTI/DLTI.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/DialectImplementation.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/Parser/Parser.h" 18 19 #include <gtest/gtest.h> 20 21 using namespace mlir; 22 23 namespace { 24 constexpr static llvm::StringLiteral kAttrName = "dltest.layout"; 25 26 /// Trivial array storage for the custom data layout spec attribute, just a list 27 /// of entries. 28 class DataLayoutSpecStorage : public AttributeStorage { 29 public: 30 using KeyTy = ArrayRef<DataLayoutEntryInterface>; 31 32 DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries) 33 : entries(entries) {} 34 35 bool operator==(const KeyTy &key) const { return key == entries; } 36 37 static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator, 38 const KeyTy &key) { 39 return new (allocator.allocate<DataLayoutSpecStorage>()) 40 DataLayoutSpecStorage(allocator.copyInto(key)); 41 } 42 43 ArrayRef<DataLayoutEntryInterface> entries; 44 }; 45 46 /// Simple data layout spec containing a list of entries that always verifies 47 /// as valid. 48 struct CustomDataLayoutSpec 49 : public Attribute::AttrBase<CustomDataLayoutSpec, Attribute, 50 DataLayoutSpecStorage, 51 DataLayoutSpecInterface::Trait> { 52 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec) 53 54 using Base::Base; 55 static CustomDataLayoutSpec get(MLIRContext *ctx, 56 ArrayRef<DataLayoutEntryInterface> entries) { 57 return Base::get(ctx, entries); 58 } 59 CustomDataLayoutSpec 60 combineWith(ArrayRef<DataLayoutSpecInterface> specs) const { 61 return *this; 62 } 63 DataLayoutEntryListRef getEntries() const { return getImpl()->entries; } 64 LogicalResult verifySpec(Location loc) { return success(); } 65 }; 66 67 /// A type subject to data layout that exits the program if it is queried more 68 /// than once. Handy to check if the cache works. 69 struct SingleQueryType 70 : public Type::TypeBase<SingleQueryType, Type, TypeStorage, 71 DataLayoutTypeInterface::Trait> { 72 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SingleQueryType) 73 74 using Base::Base; 75 76 static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); } 77 78 unsigned getTypeSizeInBits(const DataLayout &layout, 79 DataLayoutEntryListRef params) const { 80 static bool executed = false; 81 if (executed) 82 llvm::report_fatal_error("repeated call"); 83 84 executed = true; 85 return 1; 86 } 87 88 unsigned getABIAlignment(const DataLayout &layout, 89 DataLayoutEntryListRef params) { 90 static bool executed = false; 91 if (executed) 92 llvm::report_fatal_error("repeated call"); 93 94 executed = true; 95 return 2; 96 } 97 98 unsigned getPreferredAlignment(const DataLayout &layout, 99 DataLayoutEntryListRef params) { 100 static bool executed = false; 101 if (executed) 102 llvm::report_fatal_error("repeated call"); 103 104 executed = true; 105 return 4; 106 } 107 }; 108 109 /// A types that is not subject to data layout. 110 struct TypeNoLayout : public Type::TypeBase<TypeNoLayout, Type, TypeStorage> { 111 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TypeNoLayout) 112 113 using Base::Base; 114 115 static TypeNoLayout get(MLIRContext *ctx) { return Base::get(ctx); } 116 }; 117 118 /// An op that serves as scope for data layout queries with the relevant 119 /// attribute attached. This can handle data layout requests for the built-in 120 /// types itself. 121 struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> { 122 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithLayout) 123 124 using Op::Op; 125 static ArrayRef<StringRef> getAttributeNames() { return {}; } 126 127 static StringRef getOperationName() { return "dltest.op_with_layout"; } 128 129 DataLayoutSpecInterface getDataLayoutSpec() { 130 return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName); 131 } 132 133 static unsigned getTypeSizeInBits(Type type, const DataLayout &dataLayout, 134 DataLayoutEntryListRef params) { 135 // Make a recursive query. 136 if (type.isa<FloatType>()) 137 return dataLayout.getTypeSizeInBits( 138 IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth())); 139 140 // Handle built-in types that are not handled by the default process. 141 if (auto iType = type.dyn_cast<IntegerType>()) { 142 for (DataLayoutEntryInterface entry : params) 143 if (entry.getKey().dyn_cast<Type>() == type) 144 return 8 * 145 entry.getValue().cast<IntegerAttr>().getValue().getZExtValue(); 146 return 8 * iType.getIntOrFloatBitWidth(); 147 } 148 149 // Use the default process for everything else. 150 return detail::getDefaultTypeSize(type, dataLayout, params); 151 } 152 153 static unsigned getTypeABIAlignment(Type type, const DataLayout &dataLayout, 154 DataLayoutEntryListRef params) { 155 return llvm::PowerOf2Ceil(getTypeSize(type, dataLayout, params)); 156 } 157 158 static unsigned getTypePreferredAlignment(Type type, 159 const DataLayout &dataLayout, 160 DataLayoutEntryListRef params) { 161 return 2 * getTypeABIAlignment(type, dataLayout, params); 162 } 163 }; 164 165 struct OpWith7BitByte 166 : public Op<OpWith7BitByte, DataLayoutOpInterface::Trait> { 167 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWith7BitByte) 168 169 using Op::Op; 170 static ArrayRef<StringRef> getAttributeNames() { return {}; } 171 172 static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; } 173 174 DataLayoutSpecInterface getDataLayoutSpec() { 175 return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName); 176 } 177 178 // Bytes are assumed to be 7-bit here. 179 static unsigned getTypeSize(Type type, const DataLayout &dataLayout, 180 DataLayoutEntryListRef params) { 181 return llvm::divideCeil(dataLayout.getTypeSizeInBits(type), 7); 182 } 183 }; 184 185 /// A dialect putting all the above together. 186 struct DLTestDialect : Dialect { 187 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTestDialect) 188 189 explicit DLTestDialect(MLIRContext *ctx) 190 : Dialect(getDialectNamespace(), ctx, TypeID::get<DLTestDialect>()) { 191 ctx->getOrLoadDialect<DLTIDialect>(); 192 addAttributes<CustomDataLayoutSpec>(); 193 addOperations<OpWithLayout, OpWith7BitByte>(); 194 addTypes<SingleQueryType, TypeNoLayout>(); 195 } 196 static StringRef getDialectNamespace() { return "dltest"; } 197 198 void printAttribute(Attribute attr, 199 DialectAsmPrinter &printer) const override { 200 printer << "spec<"; 201 llvm::interleaveComma(attr.cast<CustomDataLayoutSpec>().getEntries(), 202 printer); 203 printer << ">"; 204 } 205 206 Attribute parseAttribute(DialectAsmParser &parser, Type type) const override { 207 bool ok = 208 succeeded(parser.parseKeyword("spec")) && succeeded(parser.parseLess()); 209 (void)ok; 210 assert(ok); 211 if (succeeded(parser.parseOptionalGreater())) 212 return CustomDataLayoutSpec::get(parser.getContext(), {}); 213 214 SmallVector<DataLayoutEntryInterface> entries; 215 ok = succeeded(parser.parseCommaSeparatedList([&]() { 216 entries.emplace_back(); 217 ok = succeeded(parser.parseAttribute(entries.back())); 218 assert(ok); 219 return success(); 220 })); 221 assert(ok); 222 ok = succeeded(parser.parseGreater()); 223 assert(ok); 224 return CustomDataLayoutSpec::get(parser.getContext(), entries); 225 } 226 227 void printType(Type type, DialectAsmPrinter &printer) const override { 228 if (type.isa<SingleQueryType>()) 229 printer << "single_query"; 230 else 231 printer << "no_layout"; 232 } 233 234 Type parseType(DialectAsmParser &parser) const override { 235 bool ok = succeeded(parser.parseKeyword("single_query")); 236 (void)ok; 237 assert(ok); 238 return SingleQueryType::get(parser.getContext()); 239 } 240 }; 241 242 } // namespace 243 244 TEST(DataLayout, FallbackDefault) { 245 const char *ir = R"MLIR( 246 module {} 247 )MLIR"; 248 249 DialectRegistry registry; 250 registry.insert<DLTIDialect, DLTestDialect>(); 251 MLIRContext ctx(registry); 252 253 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 254 DataLayout layout(module.get()); 255 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u); 256 EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u); 257 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u); 258 EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u); 259 EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u); 260 EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u); 261 EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u); 262 EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u); 263 } 264 265 TEST(DataLayout, NullSpec) { 266 const char *ir = R"MLIR( 267 "dltest.op_with_layout"() : () -> () 268 )MLIR"; 269 270 DialectRegistry registry; 271 registry.insert<DLTIDialect, DLTestDialect>(); 272 MLIRContext ctx(registry); 273 274 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 275 auto op = 276 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 277 DataLayout layout(op); 278 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u); 279 EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u); 280 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u); 281 EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u); 282 EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u); 283 EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u); 284 EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u); 285 EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u); 286 } 287 288 TEST(DataLayout, EmptySpec) { 289 const char *ir = R"MLIR( 290 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> () 291 )MLIR"; 292 293 DialectRegistry registry; 294 registry.insert<DLTIDialect, DLTestDialect>(); 295 MLIRContext ctx(registry); 296 297 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 298 auto op = 299 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 300 DataLayout layout(op); 301 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u); 302 EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u); 303 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u); 304 EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u); 305 EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u); 306 EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u); 307 EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u); 308 EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u); 309 } 310 311 TEST(DataLayout, SpecWithEntries) { 312 const char *ir = R"MLIR( 313 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< 314 #dlti.dl_entry<i42, 5>, 315 #dlti.dl_entry<i16, 6> 316 > } : () -> () 317 )MLIR"; 318 319 DialectRegistry registry; 320 registry.insert<DLTIDialect, DLTestDialect>(); 321 MLIRContext ctx(registry); 322 323 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 324 auto op = 325 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 326 DataLayout layout(op); 327 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 5u); 328 EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u); 329 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 40u); 330 EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 48u); 331 EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u); 332 EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 8u); 333 EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 16u); 334 EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 16u); 335 336 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 32u); 337 EXPECT_EQ(layout.getTypeSize(Float32Type::get(&ctx)), 32u); 338 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 256u); 339 EXPECT_EQ(layout.getTypeSizeInBits(Float32Type::get(&ctx)), 256u); 340 EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 32)), 32u); 341 EXPECT_EQ(layout.getTypeABIAlignment(Float32Type::get(&ctx)), 32u); 342 EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 32)), 64u); 343 EXPECT_EQ(layout.getTypePreferredAlignment(Float32Type::get(&ctx)), 64u); 344 } 345 346 TEST(DataLayout, Caching) { 347 const char *ir = R"MLIR( 348 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> () 349 )MLIR"; 350 351 DialectRegistry registry; 352 registry.insert<DLTIDialect, DLTestDialect>(); 353 MLIRContext ctx(registry); 354 355 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 356 auto op = 357 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 358 DataLayout layout(op); 359 360 unsigned sum = 0; 361 sum += layout.getTypeSize(SingleQueryType::get(&ctx)); 362 // The second call should hit the cache. If it does not, the function in 363 // SingleQueryType will be called and will abort the process. 364 sum += layout.getTypeSize(SingleQueryType::get(&ctx)); 365 // Make sure the complier doesn't optimize away the query code. 366 EXPECT_EQ(sum, 2u); 367 368 // A fresh data layout has a new cache, so the call to it should be dispatched 369 // down to the type and abort the proces. 370 DataLayout second(op); 371 ASSERT_DEATH(second.getTypeSize(SingleQueryType::get(&ctx)), "repeated call"); 372 } 373 374 TEST(DataLayout, CacheInvalidation) { 375 const char *ir = R"MLIR( 376 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< 377 #dlti.dl_entry<i42, 5>, 378 #dlti.dl_entry<i16, 6> 379 > } : () -> () 380 )MLIR"; 381 382 DialectRegistry registry; 383 registry.insert<DLTIDialect, DLTestDialect>(); 384 MLIRContext ctx(registry); 385 386 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 387 auto op = 388 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 389 DataLayout layout(op); 390 391 // Normal query is fine. 392 EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u); 393 394 // Replace the data layout spec with a new, empty spec. 395 op->setAttr(kAttrName, CustomDataLayoutSpec::get(&ctx, {})); 396 397 // Data layout is no longer valid and should trigger assertion when queried. 398 #ifndef NDEBUG 399 ASSERT_DEATH(layout.getTypeSize(Float16Type::get(&ctx)), "no longer valid"); 400 #endif 401 } 402 403 TEST(DataLayout, UnimplementedTypeInterface) { 404 const char *ir = R"MLIR( 405 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> () 406 )MLIR"; 407 408 DialectRegistry registry; 409 registry.insert<DLTIDialect, DLTestDialect>(); 410 MLIRContext ctx(registry); 411 412 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 413 auto op = 414 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 415 DataLayout layout(op); 416 417 ASSERT_DEATH(layout.getTypeSize(TypeNoLayout::get(&ctx)), 418 "neither the scoping op nor the type class provide data layout " 419 "information"); 420 } 421 422 TEST(DataLayout, SevenBitByte) { 423 const char *ir = R"MLIR( 424 "dltest.op_with_7bit_byte"() { dltest.layout = #dltest.spec<> } : () -> () 425 )MLIR"; 426 427 DialectRegistry registry; 428 registry.insert<DLTIDialect, DLTestDialect>(); 429 MLIRContext ctx(registry); 430 431 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 432 auto op = 433 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 434 DataLayout layout(op); 435 436 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u); 437 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 32u); 438 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u); 439 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 5u); 440 } 441