1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AsmState.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "gtest/gtest.h" 14 #include <optional> 15 16 #include "../../test/lib/Dialect/Test/TestDialect.h" 17 18 using namespace mlir; 19 using namespace mlir::detail; 20 21 //===----------------------------------------------------------------------===// 22 // DenseElementsAttr 23 //===----------------------------------------------------------------------===// 24 25 template <typename EltTy> 26 static void testSplat(Type eltType, const EltTy &splatElt) { 27 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); 28 29 // Check that the generated splat is the same for 1 element and N elements. 30 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 31 EXPECT_TRUE(splat.isSplat()); 32 33 auto detectedSplat = 34 DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt})); 35 EXPECT_EQ(detectedSplat, splat); 36 37 for (auto newValue : detectedSplat.template getValues<EltTy>()) 38 EXPECT_TRUE(newValue == splatElt); 39 } 40 41 namespace { 42 TEST(DenseSplatTest, BoolSplat) { 43 MLIRContext context; 44 IntegerType boolTy = IntegerType::get(&context, 1); 45 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 46 47 // Check that splat is automatically detected for boolean values. 48 /// True. 49 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 50 EXPECT_TRUE(trueSplat.isSplat()); 51 /// False. 52 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 53 EXPECT_TRUE(falseSplat.isSplat()); 54 EXPECT_NE(falseSplat, trueSplat); 55 56 /// Detect and handle splat within 8 elements (bool values are bit-packed). 57 /// True. 58 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 59 EXPECT_EQ(detectedSplat, trueSplat); 60 /// False. 61 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 62 EXPECT_EQ(detectedSplat, falseSplat); 63 } 64 TEST(DenseSplatTest, BoolSplatRawRoundtrip) { 65 MLIRContext context; 66 IntegerType boolTy = IntegerType::get(&context, 1); 67 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 68 69 // Check that splat booleans properly round trip via the raw API. 70 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 71 EXPECT_TRUE(trueSplat.isSplat()); 72 DenseElementsAttr trueSplatFromRaw = 73 DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData()); 74 EXPECT_TRUE(trueSplatFromRaw.isSplat()); 75 76 EXPECT_EQ(trueSplat, trueSplatFromRaw); 77 } 78 79 TEST(DenseSplatTest, BoolSplatSmall) { 80 MLIRContext context; 81 Builder builder(&context); 82 83 // Check that splats that don't fill entire byte are handled properly. 84 auto tensorType = RankedTensorType::get({4}, builder.getI1Type()); 85 std::vector<char> data{0b00001111}; 86 auto trueSplatFromRaw = 87 DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data); 88 EXPECT_TRUE(trueSplatFromRaw.isSplat()); 89 DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true); 90 EXPECT_EQ(trueSplat, trueSplatFromRaw); 91 } 92 93 TEST(DenseSplatTest, LargeBoolSplat) { 94 constexpr int64_t boolCount = 56; 95 96 MLIRContext context; 97 IntegerType boolTy = IntegerType::get(&context, 1); 98 RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); 99 100 // Check that splat is automatically detected for boolean values. 101 /// True. 102 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 103 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 104 EXPECT_TRUE(trueSplat.isSplat()); 105 EXPECT_TRUE(falseSplat.isSplat()); 106 107 /// Detect that the large boolean arrays are properly splatted. 108 /// True. 109 SmallVector<bool, 64> trueValues(boolCount, true); 110 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 111 EXPECT_EQ(detectedSplat, trueSplat); 112 /// False. 113 SmallVector<bool, 64> falseValues(boolCount, false); 114 detectedSplat = DenseElementsAttr::get(shape, falseValues); 115 EXPECT_EQ(detectedSplat, falseSplat); 116 } 117 118 TEST(DenseSplatTest, BoolNonSplat) { 119 MLIRContext context; 120 IntegerType boolTy = IntegerType::get(&context, 1); 121 RankedTensorType shape = RankedTensorType::get({6}, boolTy); 122 123 // Check that we properly handle non-splat values. 124 DenseElementsAttr nonSplat = 125 DenseElementsAttr::get(shape, {false, false, true, false, false, true}); 126 EXPECT_FALSE(nonSplat.isSplat()); 127 } 128 129 TEST(DenseSplatTest, OddIntSplat) { 130 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 131 MLIRContext context; 132 constexpr size_t intWidth = 19; 133 IntegerType intTy = IntegerType::get(&context, intWidth); 134 APInt value(intWidth, 10); 135 136 testSplat(intTy, value); 137 } 138 139 TEST(DenseSplatTest, Int32Splat) { 140 MLIRContext context; 141 IntegerType intTy = IntegerType::get(&context, 32); 142 int value = 64; 143 144 testSplat(intTy, value); 145 } 146 147 TEST(DenseSplatTest, IntAttrSplat) { 148 MLIRContext context; 149 IntegerType intTy = IntegerType::get(&context, 85); 150 Attribute value = IntegerAttr::get(intTy, 109); 151 152 testSplat(intTy, value); 153 } 154 155 TEST(DenseSplatTest, F32Splat) { 156 MLIRContext context; 157 FloatType floatTy = Float32Type::get(&context); 158 float value = 10.0; 159 160 testSplat(floatTy, value); 161 } 162 163 TEST(DenseSplatTest, F64Splat) { 164 MLIRContext context; 165 FloatType floatTy = Float64Type::get(&context); 166 double value = 10.0; 167 168 testSplat(floatTy, APFloat(value)); 169 } 170 171 TEST(DenseSplatTest, FloatAttrSplat) { 172 MLIRContext context; 173 FloatType floatTy = Float32Type::get(&context); 174 Attribute value = FloatAttr::get(floatTy, 10.0); 175 176 testSplat(floatTy, value); 177 } 178 179 TEST(DenseSplatTest, BF16Splat) { 180 MLIRContext context; 181 FloatType floatTy = BFloat16Type::get(&context); 182 Attribute value = FloatAttr::get(floatTy, 10.0); 183 184 testSplat(floatTy, value); 185 } 186 187 TEST(DenseSplatTest, StringSplat) { 188 MLIRContext context; 189 context.allowUnregisteredDialects(); 190 Type stringType = 191 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 192 StringRef value = "test-string"; 193 testSplat(stringType, value); 194 } 195 196 TEST(DenseSplatTest, StringAttrSplat) { 197 MLIRContext context; 198 context.allowUnregisteredDialects(); 199 Type stringType = 200 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 201 Attribute stringAttr = StringAttr::get("test-string", stringType); 202 testSplat(stringType, stringAttr); 203 } 204 205 TEST(DenseComplexTest, ComplexFloatSplat) { 206 MLIRContext context; 207 ComplexType complexType = ComplexType::get(Float32Type::get(&context)); 208 std::complex<float> value(10.0, 15.0); 209 testSplat(complexType, value); 210 } 211 212 TEST(DenseComplexTest, ComplexIntSplat) { 213 MLIRContext context; 214 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 215 std::complex<int64_t> value(10, 15); 216 testSplat(complexType, value); 217 } 218 219 TEST(DenseComplexTest, ComplexAPFloatSplat) { 220 MLIRContext context; 221 ComplexType complexType = ComplexType::get(Float32Type::get(&context)); 222 std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); 223 testSplat(complexType, value); 224 } 225 226 TEST(DenseComplexTest, ComplexAPIntSplat) { 227 MLIRContext context; 228 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 229 std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); 230 testSplat(complexType, value); 231 } 232 233 TEST(DenseScalarTest, ExtractZeroRankElement) { 234 MLIRContext context; 235 const int elementValue = 12; 236 IntegerType intTy = IntegerType::get(&context, 32); 237 Attribute value = IntegerAttr::get(intTy, elementValue); 238 RankedTensorType shape = RankedTensorType::get({}, intTy); 239 240 auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})); 241 EXPECT_TRUE(attr.getValues<Attribute>()[0] == value); 242 } 243 244 TEST(DenseSplatMapValuesTest, I32ToTrue) { 245 MLIRContext context; 246 const int elementValue = 12; 247 IntegerType boolTy = IntegerType::get(&context, 1); 248 IntegerType intTy = IntegerType::get(&context, 32); 249 RankedTensorType shape = RankedTensorType::get({4}, intTy); 250 251 auto attr = 252 DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) 253 .mapValues(boolTy, [](const APInt &x) { 254 return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 255 }); 256 EXPECT_EQ(attr.getNumElements(), 4); 257 EXPECT_TRUE(attr.isSplat()); 258 EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue()); 259 } 260 261 TEST(DenseSplatMapValuesTest, I32ToFalse) { 262 MLIRContext context; 263 const int elementValue = 0; 264 IntegerType boolTy = IntegerType::get(&context, 1); 265 IntegerType intTy = IntegerType::get(&context, 32); 266 RankedTensorType shape = RankedTensorType::get({4}, intTy); 267 268 auto attr = 269 DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) 270 .mapValues(boolTy, [](const APInt &x) { 271 return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 272 }); 273 EXPECT_EQ(attr.getNumElements(), 4); 274 EXPECT_TRUE(attr.isSplat()); 275 EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue()); 276 } 277 } // namespace 278 279 //===----------------------------------------------------------------------===// 280 // DenseResourceElementsAttr 281 //===----------------------------------------------------------------------===// 282 283 template <typename AttrT, typename T> 284 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data, 285 Type elementType) { 286 auto type = RankedTensorType::get(data.size(), elementType); 287 auto attr = AttrT::get(type, "resource", 288 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 289 290 // Check that we can access and iterate the data properly. 291 std::optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef(); 292 EXPECT_TRUE(attrData.has_value()); 293 EXPECT_EQ(*attrData, data); 294 295 // Check that we cast to this attribute when possible. 296 Attribute genericAttr = attr; 297 EXPECT_TRUE(isa<AttrT>(genericAttr)); 298 } 299 template <typename AttrT, typename T> 300 static void checkNativeIntAccess(Builder &builder, size_t intWidth) { 301 T data[] = {0, 1, 2}; 302 checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data), 303 builder.getIntegerType(intWidth)); 304 } 305 306 namespace { 307 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) { 308 MLIRContext context; 309 Builder builder(&context); 310 311 // Bool 312 bool boolData[] = {true, false, true}; 313 checkNativeAccess<DenseBoolResourceElementsAttr>( 314 &context, llvm::ArrayRef(boolData), builder.getI1Type()); 315 316 // Unsigned integers 317 checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8); 318 checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16); 319 checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32); 320 checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64); 321 322 // Signed integers 323 checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8); 324 checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16); 325 checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32); 326 checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64); 327 328 // Float 329 float floatData[] = {0, 1, 2}; 330 checkNativeAccess<DenseF32ResourceElementsAttr>( 331 &context, llvm::ArrayRef(floatData), builder.getF32Type()); 332 333 // Double 334 double doubleData[] = {0, 1, 2}; 335 checkNativeAccess<DenseF64ResourceElementsAttr>( 336 &context, llvm::ArrayRef(doubleData), builder.getF64Type()); 337 } 338 339 TEST(DenseResourceElementsAttrTest, CheckNoCast) { 340 MLIRContext context; 341 Builder builder(&context); 342 343 // Create a i32 attribute. 344 ArrayRef<uint32_t> data; 345 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 346 Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( 347 type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); 348 349 EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr)); 350 EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr)); 351 EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr)); 352 } 353 354 TEST(DenseResourceElementsAttrTest, CheckNotMutableAllocateAndCopy) { 355 MLIRContext context; 356 Builder builder(&context); 357 358 // Create a i32 attribute. 359 std::vector<int32_t> data = {10, 20, 30}; 360 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 361 Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( 362 type, "resource", 363 HeapAsmResourceBlob::allocateAndCopyInferAlign<int32_t>( 364 data, /*is_mutable=*/false)); 365 366 EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr)); 367 } 368 369 TEST(DenseResourceElementsAttrTest, CheckInvalidData) { 370 MLIRContext context; 371 Builder builder(&context); 372 373 // Create a bool attribute with data of the incorrect type. 374 ArrayRef<uint32_t> data; 375 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 376 EXPECT_DEBUG_DEATH( 377 { 378 DenseBoolResourceElementsAttr::get( 379 type, "resource", 380 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 381 }, 382 "alignment mismatch between expected alignment and blob alignment"); 383 } 384 385 TEST(DenseResourceElementsAttrTest, CheckInvalidType) { 386 MLIRContext context; 387 Builder builder(&context); 388 389 // Create a bool attribute with incorrect type. 390 ArrayRef<bool> data; 391 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 392 EXPECT_DEBUG_DEATH( 393 { 394 DenseBoolResourceElementsAttr::get( 395 type, "resource", 396 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 397 }, 398 "invalid shape element type for provided type `T`"); 399 } 400 } // namespace 401 402 //===----------------------------------------------------------------------===// 403 // SparseElementsAttr 404 //===----------------------------------------------------------------------===// 405 406 namespace { 407 TEST(SparseElementsAttrTest, GetZero) { 408 MLIRContext context; 409 context.allowUnregisteredDialects(); 410 411 IntegerType intTy = IntegerType::get(&context, 32); 412 FloatType floatTy = Float32Type::get(&context); 413 Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); 414 415 ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); 416 ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); 417 ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); 418 419 auto indicesType = 420 RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); 421 auto indices = 422 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 423 424 RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); 425 auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); 426 427 RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); 428 auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); 429 430 RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); 431 auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); 432 433 auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); 434 auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); 435 auto sparseString = 436 SparseElementsAttr::get(tensorString, indices, stringValue); 437 438 // Only index (0, 0) contains an element, others are supposed to return 439 // the zero/empty value. 440 auto zeroIntValue = 441 cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]); 442 EXPECT_EQ(zeroIntValue.getInt(), 0); 443 EXPECT_TRUE(zeroIntValue.getType() == intTy); 444 445 auto zeroFloatValue = 446 cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]); 447 EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); 448 EXPECT_TRUE(zeroFloatValue.getType() == floatTy); 449 450 auto zeroStringValue = 451 cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]); 452 EXPECT_TRUE(zeroStringValue.empty()); 453 EXPECT_TRUE(zeroStringValue.getType() == stringTy); 454 } 455 456 //===----------------------------------------------------------------------===// 457 // SubElements 458 //===----------------------------------------------------------------------===// 459 460 TEST(SubElementTest, Nested) { 461 MLIRContext context; 462 Builder builder(&context); 463 464 BoolAttr trueAttr = builder.getBoolAttr(true); 465 BoolAttr falseAttr = builder.getBoolAttr(false); 466 ArrayAttr boolArrayAttr = 467 builder.getArrayAttr({trueAttr, falseAttr, trueAttr}); 468 StringAttr strAttr = builder.getStringAttr("array"); 469 DictionaryAttr dictAttr = 470 builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr)); 471 472 SmallVector<Attribute> subAttrs; 473 dictAttr.walk([&](Attribute attr) { subAttrs.push_back(attr); }); 474 // Note that trueAttr appears only once, identical subattributes are skipped. 475 EXPECT_EQ(llvm::ArrayRef(subAttrs), 476 ArrayRef<Attribute>( 477 {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr})); 478 } 479 480 // Test how many times we call copy-ctor when building an attribute. 481 TEST(CopyCountAttr, CopyCount) { 482 MLIRContext context; 483 context.loadDialect<test::TestDialect>(); 484 485 test::CopyCount::counter = 0; 486 test::CopyCount copyCount("hello"); 487 test::TestCopyCountAttr::get(&context, std::move(copyCount)); 488 int counter1 = test::CopyCount::counter; 489 test::CopyCount::counter = 0; 490 test::TestCopyCountAttr::get(&context, std::move(copyCount)); 491 #ifndef NDEBUG 492 // One verification enabled only in assert-mode requires a copy. 493 EXPECT_EQ(counter1, 1); 494 EXPECT_EQ(test::CopyCount::counter, 1); 495 #else 496 EXPECT_EQ(counter1, 0); 497 EXPECT_EQ(test::CopyCount::counter, 0); 498 #endif 499 } 500 501 // Test stripped printing using test dialect attribute. 502 TEST(CopyCountAttr, PrintStripped) { 503 MLIRContext context; 504 context.loadDialect<test::TestDialect>(); 505 // Doesn't matter which dialect attribute is used, just chose TestCopyCount 506 // given proximity. 507 test::CopyCount::counter = 0; 508 test::CopyCount copyCount("hello"); 509 Attribute res = test::TestCopyCountAttr::get(&context, std::move(copyCount)); 510 511 std::string str; 512 llvm::raw_string_ostream os(str); 513 os << "|" << res << "|"; 514 res.printStripped(os << "["); 515 os << "]"; 516 EXPECT_EQ(str, "|#test.copy_count<hello>|[copy_count<hello>]"); 517 } 518 519 } // namespace 520