1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AsmState.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "gtest/gtest.h" 14 15 using namespace mlir; 16 using namespace mlir::detail; 17 18 //===----------------------------------------------------------------------===// 19 // DenseElementsAttr 20 //===----------------------------------------------------------------------===// 21 22 template <typename EltTy> 23 static void testSplat(Type eltType, const EltTy &splatElt) { 24 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); 25 26 // Check that the generated splat is the same for 1 element and N elements. 27 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 28 EXPECT_TRUE(splat.isSplat()); 29 30 auto detectedSplat = 31 DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt})); 32 EXPECT_EQ(detectedSplat, splat); 33 34 for (auto newValue : detectedSplat.template getValues<EltTy>()) 35 EXPECT_TRUE(newValue == splatElt); 36 } 37 38 namespace { 39 TEST(DenseSplatTest, BoolSplat) { 40 MLIRContext context; 41 IntegerType boolTy = IntegerType::get(&context, 1); 42 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 43 44 // Check that splat is automatically detected for boolean values. 45 /// True. 46 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 47 EXPECT_TRUE(trueSplat.isSplat()); 48 /// False. 49 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 50 EXPECT_TRUE(falseSplat.isSplat()); 51 EXPECT_NE(falseSplat, trueSplat); 52 53 /// Detect and handle splat within 8 elements (bool values are bit-packed). 54 /// True. 55 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 56 EXPECT_EQ(detectedSplat, trueSplat); 57 /// False. 58 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 59 EXPECT_EQ(detectedSplat, falseSplat); 60 } 61 62 TEST(DenseSplatTest, LargeBoolSplat) { 63 constexpr int64_t boolCount = 56; 64 65 MLIRContext context; 66 IntegerType boolTy = IntegerType::get(&context, 1); 67 RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); 68 69 // Check that splat is automatically detected for boolean values. 70 /// True. 71 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 72 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 73 EXPECT_TRUE(trueSplat.isSplat()); 74 EXPECT_TRUE(falseSplat.isSplat()); 75 76 /// Detect that the large boolean arrays are properly splatted. 77 /// True. 78 SmallVector<bool, 64> trueValues(boolCount, true); 79 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 80 EXPECT_EQ(detectedSplat, trueSplat); 81 /// False. 82 SmallVector<bool, 64> falseValues(boolCount, false); 83 detectedSplat = DenseElementsAttr::get(shape, falseValues); 84 EXPECT_EQ(detectedSplat, falseSplat); 85 } 86 87 TEST(DenseSplatTest, BoolNonSplat) { 88 MLIRContext context; 89 IntegerType boolTy = IntegerType::get(&context, 1); 90 RankedTensorType shape = RankedTensorType::get({6}, boolTy); 91 92 // Check that we properly handle non-splat values. 93 DenseElementsAttr nonSplat = 94 DenseElementsAttr::get(shape, {false, false, true, false, false, true}); 95 EXPECT_FALSE(nonSplat.isSplat()); 96 } 97 98 TEST(DenseSplatTest, OddIntSplat) { 99 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 100 MLIRContext context; 101 constexpr size_t intWidth = 19; 102 IntegerType intTy = IntegerType::get(&context, intWidth); 103 APInt value(intWidth, 10); 104 105 testSplat(intTy, value); 106 } 107 108 TEST(DenseSplatTest, Int32Splat) { 109 MLIRContext context; 110 IntegerType intTy = IntegerType::get(&context, 32); 111 int value = 64; 112 113 testSplat(intTy, value); 114 } 115 116 TEST(DenseSplatTest, IntAttrSplat) { 117 MLIRContext context; 118 IntegerType intTy = IntegerType::get(&context, 85); 119 Attribute value = IntegerAttr::get(intTy, 109); 120 121 testSplat(intTy, value); 122 } 123 124 TEST(DenseSplatTest, F32Splat) { 125 MLIRContext context; 126 FloatType floatTy = FloatType::getF32(&context); 127 float value = 10.0; 128 129 testSplat(floatTy, value); 130 } 131 132 TEST(DenseSplatTest, F64Splat) { 133 MLIRContext context; 134 FloatType floatTy = FloatType::getF64(&context); 135 double value = 10.0; 136 137 testSplat(floatTy, APFloat(value)); 138 } 139 140 TEST(DenseSplatTest, FloatAttrSplat) { 141 MLIRContext context; 142 FloatType floatTy = FloatType::getF32(&context); 143 Attribute value = FloatAttr::get(floatTy, 10.0); 144 145 testSplat(floatTy, value); 146 } 147 148 TEST(DenseSplatTest, BF16Splat) { 149 MLIRContext context; 150 FloatType floatTy = FloatType::getBF16(&context); 151 Attribute value = FloatAttr::get(floatTy, 10.0); 152 153 testSplat(floatTy, value); 154 } 155 156 TEST(DenseSplatTest, StringSplat) { 157 MLIRContext context; 158 context.allowUnregisteredDialects(); 159 Type stringType = 160 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 161 StringRef value = "test-string"; 162 testSplat(stringType, value); 163 } 164 165 TEST(DenseSplatTest, StringAttrSplat) { 166 MLIRContext context; 167 context.allowUnregisteredDialects(); 168 Type stringType = 169 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 170 Attribute stringAttr = StringAttr::get("test-string", stringType); 171 testSplat(stringType, stringAttr); 172 } 173 174 TEST(DenseComplexTest, ComplexFloatSplat) { 175 MLIRContext context; 176 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 177 std::complex<float> value(10.0, 15.0); 178 testSplat(complexType, value); 179 } 180 181 TEST(DenseComplexTest, ComplexIntSplat) { 182 MLIRContext context; 183 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 184 std::complex<int64_t> value(10, 15); 185 testSplat(complexType, value); 186 } 187 188 TEST(DenseComplexTest, ComplexAPFloatSplat) { 189 MLIRContext context; 190 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 191 std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); 192 testSplat(complexType, value); 193 } 194 195 TEST(DenseComplexTest, ComplexAPIntSplat) { 196 MLIRContext context; 197 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 198 std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); 199 testSplat(complexType, value); 200 } 201 202 TEST(DenseScalarTest, ExtractZeroRankElement) { 203 MLIRContext context; 204 const int elementValue = 12; 205 IntegerType intTy = IntegerType::get(&context, 32); 206 Attribute value = IntegerAttr::get(intTy, elementValue); 207 RankedTensorType shape = RankedTensorType::get({}, intTy); 208 209 auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); 210 EXPECT_TRUE(attr.getValues<Attribute>()[0] == value); 211 } 212 213 TEST(DenseSplatMapValuesTest, I32ToTrue) { 214 MLIRContext context; 215 const int elementValue = 12; 216 IntegerType boolTy = IntegerType::get(&context, 1); 217 IntegerType intTy = IntegerType::get(&context, 32); 218 RankedTensorType shape = RankedTensorType::get({4}, intTy); 219 220 auto attr = 221 DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})) 222 .mapValues(boolTy, [](const APInt &x) { 223 return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 224 }); 225 EXPECT_EQ(attr.getNumElements(), 4); 226 EXPECT_TRUE(attr.isSplat()); 227 EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue()); 228 } 229 230 TEST(DenseSplatMapValuesTest, I32ToFalse) { 231 MLIRContext context; 232 const int elementValue = 0; 233 IntegerType boolTy = IntegerType::get(&context, 1); 234 IntegerType intTy = IntegerType::get(&context, 32); 235 RankedTensorType shape = RankedTensorType::get({4}, intTy); 236 237 auto attr = 238 DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})) 239 .mapValues(boolTy, [](const APInt &x) { 240 return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 241 }); 242 EXPECT_EQ(attr.getNumElements(), 4); 243 EXPECT_TRUE(attr.isSplat()); 244 EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue()); 245 } 246 } // namespace 247 248 //===----------------------------------------------------------------------===// 249 // DenseResourceElementsAttr 250 //===----------------------------------------------------------------------===// 251 252 template <typename AttrT, typename T> 253 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data, 254 Type elementType) { 255 auto type = RankedTensorType::get(data.size(), elementType); 256 auto attr = AttrT::get(type, "resource", 257 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 258 259 // Check that we can access and iterate the data properly. 260 Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef(); 261 EXPECT_TRUE(attrData.has_value()); 262 EXPECT_EQ(*attrData, data); 263 264 // Check that we cast to this attribute when possible. 265 Attribute genericAttr = attr; 266 EXPECT_TRUE(genericAttr.template isa<AttrT>()); 267 } 268 template <typename AttrT, typename T> 269 static void checkNativeIntAccess(Builder &builder, size_t intWidth) { 270 T data[] = {0, 1, 2}; 271 checkNativeAccess<AttrT, T>(builder.getContext(), llvm::makeArrayRef(data), 272 builder.getIntegerType(intWidth)); 273 } 274 275 namespace { 276 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) { 277 MLIRContext context; 278 Builder builder(&context); 279 280 // Bool 281 bool boolData[] = {true, false, true}; 282 checkNativeAccess<DenseBoolResourceElementsAttr>( 283 &context, llvm::makeArrayRef(boolData), builder.getI1Type()); 284 285 // Unsigned integers 286 checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8); 287 checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16); 288 checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32); 289 checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64); 290 291 // Signed integers 292 checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8); 293 checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16); 294 checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32); 295 checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64); 296 297 // Float 298 float floatData[] = {0, 1, 2}; 299 checkNativeAccess<DenseF32ResourceElementsAttr>( 300 &context, llvm::makeArrayRef(floatData), builder.getF32Type()); 301 302 // Double 303 double doubleData[] = {0, 1, 2}; 304 checkNativeAccess<DenseF64ResourceElementsAttr>( 305 &context, llvm::makeArrayRef(doubleData), builder.getF64Type()); 306 } 307 308 TEST(DenseResourceElementsAttrTest, CheckNoCast) { 309 MLIRContext context; 310 Builder builder(&context); 311 312 // Create a i32 attribute. 313 ArrayRef<uint32_t> data; 314 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 315 Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( 316 type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); 317 318 EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>()); 319 EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>()); 320 EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>()); 321 } 322 323 TEST(DenseResourceElementsAttrTest, CheckInvalidData) { 324 MLIRContext context; 325 Builder builder(&context); 326 327 // Create a bool attribute with data of the incorrect type. 328 ArrayRef<uint32_t> data; 329 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 330 EXPECT_DEBUG_DEATH( 331 { 332 DenseBoolResourceElementsAttr::get( 333 type, "resource", 334 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 335 }, 336 "alignment mismatch between expected alignment and blob alignment"); 337 } 338 339 TEST(DenseResourceElementsAttrTest, CheckInvalidType) { 340 MLIRContext context; 341 Builder builder(&context); 342 343 // Create a bool attribute with incorrect type. 344 ArrayRef<bool> data; 345 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 346 EXPECT_DEBUG_DEATH( 347 { 348 DenseBoolResourceElementsAttr::get( 349 type, "resource", 350 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 351 }, 352 "invalid shape element type for provided type `T`"); 353 } 354 } // namespace 355 356 //===----------------------------------------------------------------------===// 357 // SparseElementsAttr 358 //===----------------------------------------------------------------------===// 359 360 namespace { 361 TEST(SparseElementsAttrTest, GetZero) { 362 MLIRContext context; 363 context.allowUnregisteredDialects(); 364 365 IntegerType intTy = IntegerType::get(&context, 32); 366 FloatType floatTy = FloatType::getF32(&context); 367 Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); 368 369 ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); 370 ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); 371 ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); 372 373 auto indicesType = 374 RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); 375 auto indices = 376 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 377 378 RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); 379 auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); 380 381 RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); 382 auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); 383 384 RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); 385 auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); 386 387 auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); 388 auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); 389 auto sparseString = 390 SparseElementsAttr::get(tensorString, indices, stringValue); 391 392 // Only index (0, 0) contains an element, others are supposed to return 393 // the zero/empty value. 394 auto zeroIntValue = 395 sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>(); 396 EXPECT_EQ(zeroIntValue.getInt(), 0); 397 EXPECT_TRUE(zeroIntValue.getType() == intTy); 398 399 auto zeroFloatValue = 400 sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>(); 401 EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); 402 EXPECT_TRUE(zeroFloatValue.getType() == floatTy); 403 404 auto zeroStringValue = 405 sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>(); 406 EXPECT_TRUE(zeroStringValue.getValue().empty()); 407 EXPECT_TRUE(zeroStringValue.getType() == stringTy); 408 } 409 410 } // namespace 411