1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AsmState.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "gtest/gtest.h" 14 #include <optional> 15 16 using namespace mlir; 17 using namespace mlir::detail; 18 19 //===----------------------------------------------------------------------===// 20 // DenseElementsAttr 21 //===----------------------------------------------------------------------===// 22 23 template <typename EltTy> 24 static void testSplat(Type eltType, const EltTy &splatElt) { 25 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); 26 27 // Check that the generated splat is the same for 1 element and N elements. 28 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 29 EXPECT_TRUE(splat.isSplat()); 30 31 auto detectedSplat = 32 DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt})); 33 EXPECT_EQ(detectedSplat, splat); 34 35 for (auto newValue : detectedSplat.template getValues<EltTy>()) 36 EXPECT_TRUE(newValue == splatElt); 37 } 38 39 namespace { 40 TEST(DenseSplatTest, BoolSplat) { 41 MLIRContext context; 42 IntegerType boolTy = IntegerType::get(&context, 1); 43 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 44 45 // Check that splat is automatically detected for boolean values. 46 /// True. 47 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 48 EXPECT_TRUE(trueSplat.isSplat()); 49 /// False. 50 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 51 EXPECT_TRUE(falseSplat.isSplat()); 52 EXPECT_NE(falseSplat, trueSplat); 53 54 /// Detect and handle splat within 8 elements (bool values are bit-packed). 55 /// True. 56 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 57 EXPECT_EQ(detectedSplat, trueSplat); 58 /// False. 59 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 60 EXPECT_EQ(detectedSplat, falseSplat); 61 } 62 TEST(DenseSplatTest, BoolSplatRawRoundtrip) { 63 MLIRContext context; 64 IntegerType boolTy = IntegerType::get(&context, 1); 65 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 66 67 // Check that splat booleans properly round trip via the raw API. 68 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 69 EXPECT_TRUE(trueSplat.isSplat()); 70 DenseElementsAttr trueSplatFromRaw = 71 DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData()); 72 EXPECT_TRUE(trueSplatFromRaw.isSplat()); 73 74 EXPECT_EQ(trueSplat, trueSplatFromRaw); 75 } 76 77 TEST(DenseSplatTest, BoolSplatSmall) { 78 MLIRContext context; 79 Builder builder(&context); 80 81 // Check that splats that don't fill entire byte are handled properly. 82 auto tensorType = RankedTensorType::get({4}, builder.getI1Type()); 83 std::vector<char> data{0b00001111}; 84 auto trueSplatFromRaw = 85 DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data); 86 EXPECT_TRUE(trueSplatFromRaw.isSplat()); 87 DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true); 88 EXPECT_EQ(trueSplat, trueSplatFromRaw); 89 } 90 91 TEST(DenseSplatTest, LargeBoolSplat) { 92 constexpr int64_t boolCount = 56; 93 94 MLIRContext context; 95 IntegerType boolTy = IntegerType::get(&context, 1); 96 RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); 97 98 // Check that splat is automatically detected for boolean values. 99 /// True. 100 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 101 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 102 EXPECT_TRUE(trueSplat.isSplat()); 103 EXPECT_TRUE(falseSplat.isSplat()); 104 105 /// Detect that the large boolean arrays are properly splatted. 106 /// True. 107 SmallVector<bool, 64> trueValues(boolCount, true); 108 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 109 EXPECT_EQ(detectedSplat, trueSplat); 110 /// False. 111 SmallVector<bool, 64> falseValues(boolCount, false); 112 detectedSplat = DenseElementsAttr::get(shape, falseValues); 113 EXPECT_EQ(detectedSplat, falseSplat); 114 } 115 116 TEST(DenseSplatTest, BoolNonSplat) { 117 MLIRContext context; 118 IntegerType boolTy = IntegerType::get(&context, 1); 119 RankedTensorType shape = RankedTensorType::get({6}, boolTy); 120 121 // Check that we properly handle non-splat values. 122 DenseElementsAttr nonSplat = 123 DenseElementsAttr::get(shape, {false, false, true, false, false, true}); 124 EXPECT_FALSE(nonSplat.isSplat()); 125 } 126 127 TEST(DenseSplatTest, OddIntSplat) { 128 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 129 MLIRContext context; 130 constexpr size_t intWidth = 19; 131 IntegerType intTy = IntegerType::get(&context, intWidth); 132 APInt value(intWidth, 10); 133 134 testSplat(intTy, value); 135 } 136 137 TEST(DenseSplatTest, Int32Splat) { 138 MLIRContext context; 139 IntegerType intTy = IntegerType::get(&context, 32); 140 int value = 64; 141 142 testSplat(intTy, value); 143 } 144 145 TEST(DenseSplatTest, IntAttrSplat) { 146 MLIRContext context; 147 IntegerType intTy = IntegerType::get(&context, 85); 148 Attribute value = IntegerAttr::get(intTy, 109); 149 150 testSplat(intTy, value); 151 } 152 153 TEST(DenseSplatTest, F32Splat) { 154 MLIRContext context; 155 FloatType floatTy = FloatType::getF32(&context); 156 float value = 10.0; 157 158 testSplat(floatTy, value); 159 } 160 161 TEST(DenseSplatTest, F64Splat) { 162 MLIRContext context; 163 FloatType floatTy = FloatType::getF64(&context); 164 double value = 10.0; 165 166 testSplat(floatTy, APFloat(value)); 167 } 168 169 TEST(DenseSplatTest, FloatAttrSplat) { 170 MLIRContext context; 171 FloatType floatTy = FloatType::getF32(&context); 172 Attribute value = FloatAttr::get(floatTy, 10.0); 173 174 testSplat(floatTy, value); 175 } 176 177 TEST(DenseSplatTest, BF16Splat) { 178 MLIRContext context; 179 FloatType floatTy = FloatType::getBF16(&context); 180 Attribute value = FloatAttr::get(floatTy, 10.0); 181 182 testSplat(floatTy, value); 183 } 184 185 TEST(DenseSplatTest, StringSplat) { 186 MLIRContext context; 187 context.allowUnregisteredDialects(); 188 Type stringType = 189 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 190 StringRef value = "test-string"; 191 testSplat(stringType, value); 192 } 193 194 TEST(DenseSplatTest, StringAttrSplat) { 195 MLIRContext context; 196 context.allowUnregisteredDialects(); 197 Type stringType = 198 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 199 Attribute stringAttr = StringAttr::get("test-string", stringType); 200 testSplat(stringType, stringAttr); 201 } 202 203 TEST(DenseComplexTest, ComplexFloatSplat) { 204 MLIRContext context; 205 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 206 std::complex<float> value(10.0, 15.0); 207 testSplat(complexType, value); 208 } 209 210 TEST(DenseComplexTest, ComplexIntSplat) { 211 MLIRContext context; 212 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 213 std::complex<int64_t> value(10, 15); 214 testSplat(complexType, value); 215 } 216 217 TEST(DenseComplexTest, ComplexAPFloatSplat) { 218 MLIRContext context; 219 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 220 std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); 221 testSplat(complexType, value); 222 } 223 224 TEST(DenseComplexTest, ComplexAPIntSplat) { 225 MLIRContext context; 226 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 227 std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); 228 testSplat(complexType, value); 229 } 230 231 TEST(DenseScalarTest, ExtractZeroRankElement) { 232 MLIRContext context; 233 const int elementValue = 12; 234 IntegerType intTy = IntegerType::get(&context, 32); 235 Attribute value = IntegerAttr::get(intTy, elementValue); 236 RankedTensorType shape = RankedTensorType::get({}, intTy); 237 238 auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})); 239 EXPECT_TRUE(attr.getValues<Attribute>()[0] == value); 240 } 241 242 TEST(DenseSplatMapValuesTest, I32ToTrue) { 243 MLIRContext context; 244 const int elementValue = 12; 245 IntegerType boolTy = IntegerType::get(&context, 1); 246 IntegerType intTy = IntegerType::get(&context, 32); 247 RankedTensorType shape = RankedTensorType::get({4}, intTy); 248 249 auto attr = 250 DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) 251 .mapValues(boolTy, [](const APInt &x) { 252 return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 253 }); 254 EXPECT_EQ(attr.getNumElements(), 4); 255 EXPECT_TRUE(attr.isSplat()); 256 EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue()); 257 } 258 259 TEST(DenseSplatMapValuesTest, I32ToFalse) { 260 MLIRContext context; 261 const int elementValue = 0; 262 IntegerType boolTy = IntegerType::get(&context, 1); 263 IntegerType intTy = IntegerType::get(&context, 32); 264 RankedTensorType shape = RankedTensorType::get({4}, intTy); 265 266 auto attr = 267 DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) 268 .mapValues(boolTy, [](const APInt &x) { 269 return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 270 }); 271 EXPECT_EQ(attr.getNumElements(), 4); 272 EXPECT_TRUE(attr.isSplat()); 273 EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue()); 274 } 275 } // namespace 276 277 //===----------------------------------------------------------------------===// 278 // DenseResourceElementsAttr 279 //===----------------------------------------------------------------------===// 280 281 template <typename AttrT, typename T> 282 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data, 283 Type elementType) { 284 auto type = RankedTensorType::get(data.size(), elementType); 285 auto attr = AttrT::get(type, "resource", 286 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 287 288 // Check that we can access and iterate the data properly. 289 std::optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef(); 290 EXPECT_TRUE(attrData.has_value()); 291 EXPECT_EQ(*attrData, data); 292 293 // Check that we cast to this attribute when possible. 294 Attribute genericAttr = attr; 295 EXPECT_TRUE(isa<AttrT>(genericAttr)); 296 } 297 template <typename AttrT, typename T> 298 static void checkNativeIntAccess(Builder &builder, size_t intWidth) { 299 T data[] = {0, 1, 2}; 300 checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data), 301 builder.getIntegerType(intWidth)); 302 } 303 304 namespace { 305 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) { 306 MLIRContext context; 307 Builder builder(&context); 308 309 // Bool 310 bool boolData[] = {true, false, true}; 311 checkNativeAccess<DenseBoolResourceElementsAttr>( 312 &context, llvm::ArrayRef(boolData), builder.getI1Type()); 313 314 // Unsigned integers 315 checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8); 316 checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16); 317 checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32); 318 checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64); 319 320 // Signed integers 321 checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8); 322 checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16); 323 checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32); 324 checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64); 325 326 // Float 327 float floatData[] = {0, 1, 2}; 328 checkNativeAccess<DenseF32ResourceElementsAttr>( 329 &context, llvm::ArrayRef(floatData), builder.getF32Type()); 330 331 // Double 332 double doubleData[] = {0, 1, 2}; 333 checkNativeAccess<DenseF64ResourceElementsAttr>( 334 &context, llvm::ArrayRef(doubleData), builder.getF64Type()); 335 } 336 337 TEST(DenseResourceElementsAttrTest, CheckNoCast) { 338 MLIRContext context; 339 Builder builder(&context); 340 341 // Create a i32 attribute. 342 ArrayRef<uint32_t> data; 343 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 344 Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( 345 type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); 346 347 EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr)); 348 EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr)); 349 EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr)); 350 } 351 352 TEST(DenseResourceElementsAttrTest, CheckInvalidData) { 353 MLIRContext context; 354 Builder builder(&context); 355 356 // Create a bool attribute with data of the incorrect type. 357 ArrayRef<uint32_t> data; 358 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 359 EXPECT_DEBUG_DEATH( 360 { 361 DenseBoolResourceElementsAttr::get( 362 type, "resource", 363 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 364 }, 365 "alignment mismatch between expected alignment and blob alignment"); 366 } 367 368 TEST(DenseResourceElementsAttrTest, CheckInvalidType) { 369 MLIRContext context; 370 Builder builder(&context); 371 372 // Create a bool attribute with incorrect type. 373 ArrayRef<bool> data; 374 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 375 EXPECT_DEBUG_DEATH( 376 { 377 DenseBoolResourceElementsAttr::get( 378 type, "resource", 379 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 380 }, 381 "invalid shape element type for provided type `T`"); 382 } 383 } // namespace 384 385 //===----------------------------------------------------------------------===// 386 // SparseElementsAttr 387 //===----------------------------------------------------------------------===// 388 389 namespace { 390 TEST(SparseElementsAttrTest, GetZero) { 391 MLIRContext context; 392 context.allowUnregisteredDialects(); 393 394 IntegerType intTy = IntegerType::get(&context, 32); 395 FloatType floatTy = FloatType::getF32(&context); 396 Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); 397 398 ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); 399 ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); 400 ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); 401 402 auto indicesType = 403 RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); 404 auto indices = 405 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 406 407 RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); 408 auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); 409 410 RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); 411 auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); 412 413 RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); 414 auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); 415 416 auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); 417 auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); 418 auto sparseString = 419 SparseElementsAttr::get(tensorString, indices, stringValue); 420 421 // Only index (0, 0) contains an element, others are supposed to return 422 // the zero/empty value. 423 auto zeroIntValue = 424 cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]); 425 EXPECT_EQ(zeroIntValue.getInt(), 0); 426 EXPECT_TRUE(zeroIntValue.getType() == intTy); 427 428 auto zeroFloatValue = 429 cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]); 430 EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); 431 EXPECT_TRUE(zeroFloatValue.getType() == floatTy); 432 433 auto zeroStringValue = 434 cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]); 435 EXPECT_TRUE(zeroStringValue.getValue().empty()); 436 EXPECT_TRUE(zeroStringValue.getType() == stringTy); 437 } 438 439 //===----------------------------------------------------------------------===// 440 // SubElements 441 //===----------------------------------------------------------------------===// 442 443 TEST(SubElementTest, Nested) { 444 MLIRContext context; 445 Builder builder(&context); 446 447 BoolAttr trueAttr = builder.getBoolAttr(true); 448 BoolAttr falseAttr = builder.getBoolAttr(false); 449 ArrayAttr boolArrayAttr = 450 builder.getArrayAttr({trueAttr, falseAttr, trueAttr}); 451 StringAttr strAttr = builder.getStringAttr("array"); 452 DictionaryAttr dictAttr = 453 builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr)); 454 455 SmallVector<Attribute> subAttrs; 456 dictAttr.walk([&](Attribute attr) { subAttrs.push_back(attr); }); 457 // Note that trueAttr appears only once, identical subattributes are skipped. 458 EXPECT_EQ(llvm::ArrayRef(subAttrs), 459 ArrayRef<Attribute>( 460 {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr})); 461 } 462 } // namespace 463