1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AsmState.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "gtest/gtest.h" 14 #include <optional> 15 16 using namespace mlir; 17 using namespace mlir::detail; 18 19 //===----------------------------------------------------------------------===// 20 // DenseElementsAttr 21 //===----------------------------------------------------------------------===// 22 23 template <typename EltTy> 24 static void testSplat(Type eltType, const EltTy &splatElt) { 25 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); 26 27 // Check that the generated splat is the same for 1 element and N elements. 28 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 29 EXPECT_TRUE(splat.isSplat()); 30 31 auto detectedSplat = 32 DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt})); 33 EXPECT_EQ(detectedSplat, splat); 34 35 for (auto newValue : detectedSplat.template getValues<EltTy>()) 36 EXPECT_TRUE(newValue == splatElt); 37 } 38 39 namespace { 40 TEST(DenseSplatTest, BoolSplat) { 41 MLIRContext context; 42 IntegerType boolTy = IntegerType::get(&context, 1); 43 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 44 45 // Check that splat is automatically detected for boolean values. 46 /// True. 47 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 48 EXPECT_TRUE(trueSplat.isSplat()); 49 /// False. 50 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 51 EXPECT_TRUE(falseSplat.isSplat()); 52 EXPECT_NE(falseSplat, trueSplat); 53 54 /// Detect and handle splat within 8 elements (bool values are bit-packed). 55 /// True. 56 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 57 EXPECT_EQ(detectedSplat, trueSplat); 58 /// False. 59 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 60 EXPECT_EQ(detectedSplat, falseSplat); 61 } 62 TEST(DenseSplatTest, BoolSplatRawRoundtrip) { 63 MLIRContext context; 64 IntegerType boolTy = IntegerType::get(&context, 1); 65 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 66 67 // Check that splat booleans properly round trip via the raw API. 68 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 69 EXPECT_TRUE(trueSplat.isSplat()); 70 DenseElementsAttr trueSplatFromRaw = 71 DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData()); 72 EXPECT_TRUE(trueSplatFromRaw.isSplat()); 73 74 EXPECT_EQ(trueSplat, trueSplatFromRaw); 75 } 76 77 TEST(DenseSplatTest, LargeBoolSplat) { 78 constexpr int64_t boolCount = 56; 79 80 MLIRContext context; 81 IntegerType boolTy = IntegerType::get(&context, 1); 82 RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); 83 84 // Check that splat is automatically detected for boolean values. 85 /// True. 86 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 87 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 88 EXPECT_TRUE(trueSplat.isSplat()); 89 EXPECT_TRUE(falseSplat.isSplat()); 90 91 /// Detect that the large boolean arrays are properly splatted. 92 /// True. 93 SmallVector<bool, 64> trueValues(boolCount, true); 94 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 95 EXPECT_EQ(detectedSplat, trueSplat); 96 /// False. 97 SmallVector<bool, 64> falseValues(boolCount, false); 98 detectedSplat = DenseElementsAttr::get(shape, falseValues); 99 EXPECT_EQ(detectedSplat, falseSplat); 100 } 101 102 TEST(DenseSplatTest, BoolNonSplat) { 103 MLIRContext context; 104 IntegerType boolTy = IntegerType::get(&context, 1); 105 RankedTensorType shape = RankedTensorType::get({6}, boolTy); 106 107 // Check that we properly handle non-splat values. 108 DenseElementsAttr nonSplat = 109 DenseElementsAttr::get(shape, {false, false, true, false, false, true}); 110 EXPECT_FALSE(nonSplat.isSplat()); 111 } 112 113 TEST(DenseSplatTest, OddIntSplat) { 114 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 115 MLIRContext context; 116 constexpr size_t intWidth = 19; 117 IntegerType intTy = IntegerType::get(&context, intWidth); 118 APInt value(intWidth, 10); 119 120 testSplat(intTy, value); 121 } 122 123 TEST(DenseSplatTest, Int32Splat) { 124 MLIRContext context; 125 IntegerType intTy = IntegerType::get(&context, 32); 126 int value = 64; 127 128 testSplat(intTy, value); 129 } 130 131 TEST(DenseSplatTest, IntAttrSplat) { 132 MLIRContext context; 133 IntegerType intTy = IntegerType::get(&context, 85); 134 Attribute value = IntegerAttr::get(intTy, 109); 135 136 testSplat(intTy, value); 137 } 138 139 TEST(DenseSplatTest, F32Splat) { 140 MLIRContext context; 141 FloatType floatTy = FloatType::getF32(&context); 142 float value = 10.0; 143 144 testSplat(floatTy, value); 145 } 146 147 TEST(DenseSplatTest, F64Splat) { 148 MLIRContext context; 149 FloatType floatTy = FloatType::getF64(&context); 150 double value = 10.0; 151 152 testSplat(floatTy, APFloat(value)); 153 } 154 155 TEST(DenseSplatTest, FloatAttrSplat) { 156 MLIRContext context; 157 FloatType floatTy = FloatType::getF32(&context); 158 Attribute value = FloatAttr::get(floatTy, 10.0); 159 160 testSplat(floatTy, value); 161 } 162 163 TEST(DenseSplatTest, BF16Splat) { 164 MLIRContext context; 165 FloatType floatTy = FloatType::getBF16(&context); 166 Attribute value = FloatAttr::get(floatTy, 10.0); 167 168 testSplat(floatTy, value); 169 } 170 171 TEST(DenseSplatTest, StringSplat) { 172 MLIRContext context; 173 context.allowUnregisteredDialects(); 174 Type stringType = 175 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 176 StringRef value = "test-string"; 177 testSplat(stringType, value); 178 } 179 180 TEST(DenseSplatTest, StringAttrSplat) { 181 MLIRContext context; 182 context.allowUnregisteredDialects(); 183 Type stringType = 184 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 185 Attribute stringAttr = StringAttr::get("test-string", stringType); 186 testSplat(stringType, stringAttr); 187 } 188 189 TEST(DenseComplexTest, ComplexFloatSplat) { 190 MLIRContext context; 191 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 192 std::complex<float> value(10.0, 15.0); 193 testSplat(complexType, value); 194 } 195 196 TEST(DenseComplexTest, ComplexIntSplat) { 197 MLIRContext context; 198 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 199 std::complex<int64_t> value(10, 15); 200 testSplat(complexType, value); 201 } 202 203 TEST(DenseComplexTest, ComplexAPFloatSplat) { 204 MLIRContext context; 205 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 206 std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); 207 testSplat(complexType, value); 208 } 209 210 TEST(DenseComplexTest, ComplexAPIntSplat) { 211 MLIRContext context; 212 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 213 std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); 214 testSplat(complexType, value); 215 } 216 217 TEST(DenseScalarTest, ExtractZeroRankElement) { 218 MLIRContext context; 219 const int elementValue = 12; 220 IntegerType intTy = IntegerType::get(&context, 32); 221 Attribute value = IntegerAttr::get(intTy, elementValue); 222 RankedTensorType shape = RankedTensorType::get({}, intTy); 223 224 auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})); 225 EXPECT_TRUE(attr.getValues<Attribute>()[0] == value); 226 } 227 228 TEST(DenseSplatMapValuesTest, I32ToTrue) { 229 MLIRContext context; 230 const int elementValue = 12; 231 IntegerType boolTy = IntegerType::get(&context, 1); 232 IntegerType intTy = IntegerType::get(&context, 32); 233 RankedTensorType shape = RankedTensorType::get({4}, intTy); 234 235 auto attr = 236 DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) 237 .mapValues(boolTy, [](const APInt &x) { 238 return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 239 }); 240 EXPECT_EQ(attr.getNumElements(), 4); 241 EXPECT_TRUE(attr.isSplat()); 242 EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue()); 243 } 244 245 TEST(DenseSplatMapValuesTest, I32ToFalse) { 246 MLIRContext context; 247 const int elementValue = 0; 248 IntegerType boolTy = IntegerType::get(&context, 1); 249 IntegerType intTy = IntegerType::get(&context, 32); 250 RankedTensorType shape = RankedTensorType::get({4}, intTy); 251 252 auto attr = 253 DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) 254 .mapValues(boolTy, [](const APInt &x) { 255 return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 256 }); 257 EXPECT_EQ(attr.getNumElements(), 4); 258 EXPECT_TRUE(attr.isSplat()); 259 EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue()); 260 } 261 } // namespace 262 263 //===----------------------------------------------------------------------===// 264 // DenseResourceElementsAttr 265 //===----------------------------------------------------------------------===// 266 267 template <typename AttrT, typename T> 268 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data, 269 Type elementType) { 270 auto type = RankedTensorType::get(data.size(), elementType); 271 auto attr = AttrT::get(type, "resource", 272 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 273 274 // Check that we can access and iterate the data properly. 275 Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef(); 276 EXPECT_TRUE(attrData.has_value()); 277 EXPECT_EQ(*attrData, data); 278 279 // Check that we cast to this attribute when possible. 280 Attribute genericAttr = attr; 281 EXPECT_TRUE(genericAttr.template isa<AttrT>()); 282 } 283 template <typename AttrT, typename T> 284 static void checkNativeIntAccess(Builder &builder, size_t intWidth) { 285 T data[] = {0, 1, 2}; 286 checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data), 287 builder.getIntegerType(intWidth)); 288 } 289 290 namespace { 291 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) { 292 MLIRContext context; 293 Builder builder(&context); 294 295 // Bool 296 bool boolData[] = {true, false, true}; 297 checkNativeAccess<DenseBoolResourceElementsAttr>( 298 &context, llvm::ArrayRef(boolData), builder.getI1Type()); 299 300 // Unsigned integers 301 checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8); 302 checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16); 303 checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32); 304 checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64); 305 306 // Signed integers 307 checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8); 308 checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16); 309 checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32); 310 checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64); 311 312 // Float 313 float floatData[] = {0, 1, 2}; 314 checkNativeAccess<DenseF32ResourceElementsAttr>( 315 &context, llvm::ArrayRef(floatData), builder.getF32Type()); 316 317 // Double 318 double doubleData[] = {0, 1, 2}; 319 checkNativeAccess<DenseF64ResourceElementsAttr>( 320 &context, llvm::ArrayRef(doubleData), builder.getF64Type()); 321 } 322 323 TEST(DenseResourceElementsAttrTest, CheckNoCast) { 324 MLIRContext context; 325 Builder builder(&context); 326 327 // Create a i32 attribute. 328 ArrayRef<uint32_t> data; 329 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 330 Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( 331 type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); 332 333 EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>()); 334 EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>()); 335 EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>()); 336 } 337 338 TEST(DenseResourceElementsAttrTest, CheckInvalidData) { 339 MLIRContext context; 340 Builder builder(&context); 341 342 // Create a bool attribute with data of the incorrect type. 343 ArrayRef<uint32_t> data; 344 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 345 EXPECT_DEBUG_DEATH( 346 { 347 DenseBoolResourceElementsAttr::get( 348 type, "resource", 349 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 350 }, 351 "alignment mismatch between expected alignment and blob alignment"); 352 } 353 354 TEST(DenseResourceElementsAttrTest, CheckInvalidType) { 355 MLIRContext context; 356 Builder builder(&context); 357 358 // Create a bool attribute with incorrect type. 359 ArrayRef<bool> data; 360 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 361 EXPECT_DEBUG_DEATH( 362 { 363 DenseBoolResourceElementsAttr::get( 364 type, "resource", 365 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 366 }, 367 "invalid shape element type for provided type `T`"); 368 } 369 } // namespace 370 371 //===----------------------------------------------------------------------===// 372 // SparseElementsAttr 373 //===----------------------------------------------------------------------===// 374 375 namespace { 376 TEST(SparseElementsAttrTest, GetZero) { 377 MLIRContext context; 378 context.allowUnregisteredDialects(); 379 380 IntegerType intTy = IntegerType::get(&context, 32); 381 FloatType floatTy = FloatType::getF32(&context); 382 Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); 383 384 ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); 385 ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); 386 ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); 387 388 auto indicesType = 389 RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); 390 auto indices = 391 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 392 393 RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); 394 auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); 395 396 RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); 397 auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); 398 399 RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); 400 auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); 401 402 auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); 403 auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); 404 auto sparseString = 405 SparseElementsAttr::get(tensorString, indices, stringValue); 406 407 // Only index (0, 0) contains an element, others are supposed to return 408 // the zero/empty value. 409 auto zeroIntValue = 410 sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>(); 411 EXPECT_EQ(zeroIntValue.getInt(), 0); 412 EXPECT_TRUE(zeroIntValue.getType() == intTy); 413 414 auto zeroFloatValue = 415 sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>(); 416 EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); 417 EXPECT_TRUE(zeroFloatValue.getType() == floatTy); 418 419 auto zeroStringValue = 420 sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>(); 421 EXPECT_TRUE(zeroStringValue.getValue().empty()); 422 EXPECT_TRUE(zeroStringValue.getType() == stringTy); 423 } 424 425 } // namespace 426