1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AsmState.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "gtest/gtest.h" 14 15 using namespace mlir; 16 using namespace mlir::detail; 17 18 //===----------------------------------------------------------------------===// 19 // DenseElementsAttr 20 //===----------------------------------------------------------------------===// 21 22 template <typename EltTy> 23 static void testSplat(Type eltType, const EltTy &splatElt) { 24 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); 25 26 // Check that the generated splat is the same for 1 element and N elements. 27 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 28 EXPECT_TRUE(splat.isSplat()); 29 30 auto detectedSplat = 31 DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt})); 32 EXPECT_EQ(detectedSplat, splat); 33 34 for (auto newValue : detectedSplat.template getValues<EltTy>()) 35 EXPECT_TRUE(newValue == splatElt); 36 } 37 38 namespace { 39 TEST(DenseSplatTest, BoolSplat) { 40 MLIRContext context; 41 IntegerType boolTy = IntegerType::get(&context, 1); 42 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 43 44 // Check that splat is automatically detected for boolean values. 45 /// True. 46 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 47 EXPECT_TRUE(trueSplat.isSplat()); 48 /// False. 49 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 50 EXPECT_TRUE(falseSplat.isSplat()); 51 EXPECT_NE(falseSplat, trueSplat); 52 53 /// Detect and handle splat within 8 elements (bool values are bit-packed). 54 /// True. 55 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 56 EXPECT_EQ(detectedSplat, trueSplat); 57 /// False. 58 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 59 EXPECT_EQ(detectedSplat, falseSplat); 60 } 61 TEST(DenseSplatTest, BoolSplatRawRoundtrip) { 62 MLIRContext context; 63 IntegerType boolTy = IntegerType::get(&context, 1); 64 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 65 66 // Check that splat booleans properly round trip via the raw API. 67 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 68 EXPECT_TRUE(trueSplat.isSplat()); 69 DenseElementsAttr trueSplatFromRaw = 70 DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData()); 71 EXPECT_TRUE(trueSplatFromRaw.isSplat()); 72 73 EXPECT_EQ(trueSplat, trueSplatFromRaw); 74 } 75 76 TEST(DenseSplatTest, LargeBoolSplat) { 77 constexpr int64_t boolCount = 56; 78 79 MLIRContext context; 80 IntegerType boolTy = IntegerType::get(&context, 1); 81 RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); 82 83 // Check that splat is automatically detected for boolean values. 84 /// True. 85 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 86 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 87 EXPECT_TRUE(trueSplat.isSplat()); 88 EXPECT_TRUE(falseSplat.isSplat()); 89 90 /// Detect that the large boolean arrays are properly splatted. 91 /// True. 92 SmallVector<bool, 64> trueValues(boolCount, true); 93 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 94 EXPECT_EQ(detectedSplat, trueSplat); 95 /// False. 96 SmallVector<bool, 64> falseValues(boolCount, false); 97 detectedSplat = DenseElementsAttr::get(shape, falseValues); 98 EXPECT_EQ(detectedSplat, falseSplat); 99 } 100 101 TEST(DenseSplatTest, BoolNonSplat) { 102 MLIRContext context; 103 IntegerType boolTy = IntegerType::get(&context, 1); 104 RankedTensorType shape = RankedTensorType::get({6}, boolTy); 105 106 // Check that we properly handle non-splat values. 107 DenseElementsAttr nonSplat = 108 DenseElementsAttr::get(shape, {false, false, true, false, false, true}); 109 EXPECT_FALSE(nonSplat.isSplat()); 110 } 111 112 TEST(DenseSplatTest, OddIntSplat) { 113 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 114 MLIRContext context; 115 constexpr size_t intWidth = 19; 116 IntegerType intTy = IntegerType::get(&context, intWidth); 117 APInt value(intWidth, 10); 118 119 testSplat(intTy, value); 120 } 121 122 TEST(DenseSplatTest, Int32Splat) { 123 MLIRContext context; 124 IntegerType intTy = IntegerType::get(&context, 32); 125 int value = 64; 126 127 testSplat(intTy, value); 128 } 129 130 TEST(DenseSplatTest, IntAttrSplat) { 131 MLIRContext context; 132 IntegerType intTy = IntegerType::get(&context, 85); 133 Attribute value = IntegerAttr::get(intTy, 109); 134 135 testSplat(intTy, value); 136 } 137 138 TEST(DenseSplatTest, F32Splat) { 139 MLIRContext context; 140 FloatType floatTy = FloatType::getF32(&context); 141 float value = 10.0; 142 143 testSplat(floatTy, value); 144 } 145 146 TEST(DenseSplatTest, F64Splat) { 147 MLIRContext context; 148 FloatType floatTy = FloatType::getF64(&context); 149 double value = 10.0; 150 151 testSplat(floatTy, APFloat(value)); 152 } 153 154 TEST(DenseSplatTest, FloatAttrSplat) { 155 MLIRContext context; 156 FloatType floatTy = FloatType::getF32(&context); 157 Attribute value = FloatAttr::get(floatTy, 10.0); 158 159 testSplat(floatTy, value); 160 } 161 162 TEST(DenseSplatTest, BF16Splat) { 163 MLIRContext context; 164 FloatType floatTy = FloatType::getBF16(&context); 165 Attribute value = FloatAttr::get(floatTy, 10.0); 166 167 testSplat(floatTy, value); 168 } 169 170 TEST(DenseSplatTest, StringSplat) { 171 MLIRContext context; 172 context.allowUnregisteredDialects(); 173 Type stringType = 174 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 175 StringRef value = "test-string"; 176 testSplat(stringType, value); 177 } 178 179 TEST(DenseSplatTest, StringAttrSplat) { 180 MLIRContext context; 181 context.allowUnregisteredDialects(); 182 Type stringType = 183 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 184 Attribute stringAttr = StringAttr::get("test-string", stringType); 185 testSplat(stringType, stringAttr); 186 } 187 188 TEST(DenseComplexTest, ComplexFloatSplat) { 189 MLIRContext context; 190 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 191 std::complex<float> value(10.0, 15.0); 192 testSplat(complexType, value); 193 } 194 195 TEST(DenseComplexTest, ComplexIntSplat) { 196 MLIRContext context; 197 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 198 std::complex<int64_t> value(10, 15); 199 testSplat(complexType, value); 200 } 201 202 TEST(DenseComplexTest, ComplexAPFloatSplat) { 203 MLIRContext context; 204 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 205 std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); 206 testSplat(complexType, value); 207 } 208 209 TEST(DenseComplexTest, ComplexAPIntSplat) { 210 MLIRContext context; 211 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 212 std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); 213 testSplat(complexType, value); 214 } 215 216 TEST(DenseScalarTest, ExtractZeroRankElement) { 217 MLIRContext context; 218 const int elementValue = 12; 219 IntegerType intTy = IntegerType::get(&context, 32); 220 Attribute value = IntegerAttr::get(intTy, elementValue); 221 RankedTensorType shape = RankedTensorType::get({}, intTy); 222 223 auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})); 224 EXPECT_TRUE(attr.getValues<Attribute>()[0] == value); 225 } 226 227 TEST(DenseSplatMapValuesTest, I32ToTrue) { 228 MLIRContext context; 229 const int elementValue = 12; 230 IntegerType boolTy = IntegerType::get(&context, 1); 231 IntegerType intTy = IntegerType::get(&context, 32); 232 RankedTensorType shape = RankedTensorType::get({4}, intTy); 233 234 auto attr = 235 DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) 236 .mapValues(boolTy, [](const APInt &x) { 237 return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 238 }); 239 EXPECT_EQ(attr.getNumElements(), 4); 240 EXPECT_TRUE(attr.isSplat()); 241 EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue()); 242 } 243 244 TEST(DenseSplatMapValuesTest, I32ToFalse) { 245 MLIRContext context; 246 const int elementValue = 0; 247 IntegerType boolTy = IntegerType::get(&context, 1); 248 IntegerType intTy = IntegerType::get(&context, 32); 249 RankedTensorType shape = RankedTensorType::get({4}, intTy); 250 251 auto attr = 252 DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) 253 .mapValues(boolTy, [](const APInt &x) { 254 return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 255 }); 256 EXPECT_EQ(attr.getNumElements(), 4); 257 EXPECT_TRUE(attr.isSplat()); 258 EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue()); 259 } 260 } // namespace 261 262 //===----------------------------------------------------------------------===// 263 // DenseResourceElementsAttr 264 //===----------------------------------------------------------------------===// 265 266 template <typename AttrT, typename T> 267 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data, 268 Type elementType) { 269 auto type = RankedTensorType::get(data.size(), elementType); 270 auto attr = AttrT::get(type, "resource", 271 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 272 273 // Check that we can access and iterate the data properly. 274 Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef(); 275 EXPECT_TRUE(attrData.has_value()); 276 EXPECT_EQ(*attrData, data); 277 278 // Check that we cast to this attribute when possible. 279 Attribute genericAttr = attr; 280 EXPECT_TRUE(genericAttr.template isa<AttrT>()); 281 } 282 template <typename AttrT, typename T> 283 static void checkNativeIntAccess(Builder &builder, size_t intWidth) { 284 T data[] = {0, 1, 2}; 285 checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data), 286 builder.getIntegerType(intWidth)); 287 } 288 289 namespace { 290 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) { 291 MLIRContext context; 292 Builder builder(&context); 293 294 // Bool 295 bool boolData[] = {true, false, true}; 296 checkNativeAccess<DenseBoolResourceElementsAttr>( 297 &context, llvm::ArrayRef(boolData), builder.getI1Type()); 298 299 // Unsigned integers 300 checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8); 301 checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16); 302 checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32); 303 checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64); 304 305 // Signed integers 306 checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8); 307 checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16); 308 checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32); 309 checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64); 310 311 // Float 312 float floatData[] = {0, 1, 2}; 313 checkNativeAccess<DenseF32ResourceElementsAttr>( 314 &context, llvm::ArrayRef(floatData), builder.getF32Type()); 315 316 // Double 317 double doubleData[] = {0, 1, 2}; 318 checkNativeAccess<DenseF64ResourceElementsAttr>( 319 &context, llvm::ArrayRef(doubleData), builder.getF64Type()); 320 } 321 322 TEST(DenseResourceElementsAttrTest, CheckNoCast) { 323 MLIRContext context; 324 Builder builder(&context); 325 326 // Create a i32 attribute. 327 ArrayRef<uint32_t> data; 328 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 329 Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( 330 type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); 331 332 EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>()); 333 EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>()); 334 EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>()); 335 } 336 337 TEST(DenseResourceElementsAttrTest, CheckInvalidData) { 338 MLIRContext context; 339 Builder builder(&context); 340 341 // Create a bool attribute with data of the incorrect type. 342 ArrayRef<uint32_t> data; 343 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 344 EXPECT_DEBUG_DEATH( 345 { 346 DenseBoolResourceElementsAttr::get( 347 type, "resource", 348 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 349 }, 350 "alignment mismatch between expected alignment and blob alignment"); 351 } 352 353 TEST(DenseResourceElementsAttrTest, CheckInvalidType) { 354 MLIRContext context; 355 Builder builder(&context); 356 357 // Create a bool attribute with incorrect type. 358 ArrayRef<bool> data; 359 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 360 EXPECT_DEBUG_DEATH( 361 { 362 DenseBoolResourceElementsAttr::get( 363 type, "resource", 364 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 365 }, 366 "invalid shape element type for provided type `T`"); 367 } 368 } // namespace 369 370 //===----------------------------------------------------------------------===// 371 // SparseElementsAttr 372 //===----------------------------------------------------------------------===// 373 374 namespace { 375 TEST(SparseElementsAttrTest, GetZero) { 376 MLIRContext context; 377 context.allowUnregisteredDialects(); 378 379 IntegerType intTy = IntegerType::get(&context, 32); 380 FloatType floatTy = FloatType::getF32(&context); 381 Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); 382 383 ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); 384 ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); 385 ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); 386 387 auto indicesType = 388 RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); 389 auto indices = 390 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 391 392 RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); 393 auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); 394 395 RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); 396 auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); 397 398 RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); 399 auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); 400 401 auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); 402 auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); 403 auto sparseString = 404 SparseElementsAttr::get(tensorString, indices, stringValue); 405 406 // Only index (0, 0) contains an element, others are supposed to return 407 // the zero/empty value. 408 auto zeroIntValue = 409 sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>(); 410 EXPECT_EQ(zeroIntValue.getInt(), 0); 411 EXPECT_TRUE(zeroIntValue.getType() == intTy); 412 413 auto zeroFloatValue = 414 sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>(); 415 EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); 416 EXPECT_TRUE(zeroFloatValue.getType() == floatTy); 417 418 auto zeroStringValue = 419 sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>(); 420 EXPECT_TRUE(zeroStringValue.getValue().empty()); 421 EXPECT_TRUE(zeroStringValue.getType() == stringTy); 422 } 423 424 } // namespace 425