1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AsmState.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "gtest/gtest.h" 14 15 using namespace mlir; 16 using namespace mlir::detail; 17 18 //===----------------------------------------------------------------------===// 19 // DenseElementsAttr 20 //===----------------------------------------------------------------------===// 21 22 template <typename EltTy> 23 static void testSplat(Type eltType, const EltTy &splatElt) { 24 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); 25 26 // Check that the generated splat is the same for 1 element and N elements. 27 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 28 EXPECT_TRUE(splat.isSplat()); 29 30 auto detectedSplat = 31 DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt})); 32 EXPECT_EQ(detectedSplat, splat); 33 34 for (auto newValue : detectedSplat.template getValues<EltTy>()) 35 EXPECT_TRUE(newValue == splatElt); 36 } 37 38 namespace { 39 TEST(DenseSplatTest, BoolSplat) { 40 MLIRContext context; 41 IntegerType boolTy = IntegerType::get(&context, 1); 42 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 43 44 // Check that splat is automatically detected for boolean values. 45 /// True. 46 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 47 EXPECT_TRUE(trueSplat.isSplat()); 48 /// False. 49 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 50 EXPECT_TRUE(falseSplat.isSplat()); 51 EXPECT_NE(falseSplat, trueSplat); 52 53 /// Detect and handle splat within 8 elements (bool values are bit-packed). 54 /// True. 55 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 56 EXPECT_EQ(detectedSplat, trueSplat); 57 /// False. 58 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 59 EXPECT_EQ(detectedSplat, falseSplat); 60 } 61 62 TEST(DenseSplatTest, LargeBoolSplat) { 63 constexpr int64_t boolCount = 56; 64 65 MLIRContext context; 66 IntegerType boolTy = IntegerType::get(&context, 1); 67 RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); 68 69 // Check that splat is automatically detected for boolean values. 70 /// True. 71 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 72 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 73 EXPECT_TRUE(trueSplat.isSplat()); 74 EXPECT_TRUE(falseSplat.isSplat()); 75 76 /// Detect that the large boolean arrays are properly splatted. 77 /// True. 78 SmallVector<bool, 64> trueValues(boolCount, true); 79 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 80 EXPECT_EQ(detectedSplat, trueSplat); 81 /// False. 82 SmallVector<bool, 64> falseValues(boolCount, false); 83 detectedSplat = DenseElementsAttr::get(shape, falseValues); 84 EXPECT_EQ(detectedSplat, falseSplat); 85 } 86 87 TEST(DenseSplatTest, BoolNonSplat) { 88 MLIRContext context; 89 IntegerType boolTy = IntegerType::get(&context, 1); 90 RankedTensorType shape = RankedTensorType::get({6}, boolTy); 91 92 // Check that we properly handle non-splat values. 93 DenseElementsAttr nonSplat = 94 DenseElementsAttr::get(shape, {false, false, true, false, false, true}); 95 EXPECT_FALSE(nonSplat.isSplat()); 96 } 97 98 TEST(DenseSplatTest, OddIntSplat) { 99 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 100 MLIRContext context; 101 constexpr size_t intWidth = 19; 102 IntegerType intTy = IntegerType::get(&context, intWidth); 103 APInt value(intWidth, 10); 104 105 testSplat(intTy, value); 106 } 107 108 TEST(DenseSplatTest, Int32Splat) { 109 MLIRContext context; 110 IntegerType intTy = IntegerType::get(&context, 32); 111 int value = 64; 112 113 testSplat(intTy, value); 114 } 115 116 TEST(DenseSplatTest, IntAttrSplat) { 117 MLIRContext context; 118 IntegerType intTy = IntegerType::get(&context, 85); 119 Attribute value = IntegerAttr::get(intTy, 109); 120 121 testSplat(intTy, value); 122 } 123 124 TEST(DenseSplatTest, F32Splat) { 125 MLIRContext context; 126 FloatType floatTy = FloatType::getF32(&context); 127 float value = 10.0; 128 129 testSplat(floatTy, value); 130 } 131 132 TEST(DenseSplatTest, F64Splat) { 133 MLIRContext context; 134 FloatType floatTy = FloatType::getF64(&context); 135 double value = 10.0; 136 137 testSplat(floatTy, APFloat(value)); 138 } 139 140 TEST(DenseSplatTest, FloatAttrSplat) { 141 MLIRContext context; 142 FloatType floatTy = FloatType::getF32(&context); 143 Attribute value = FloatAttr::get(floatTy, 10.0); 144 145 testSplat(floatTy, value); 146 } 147 148 TEST(DenseSplatTest, BF16Splat) { 149 MLIRContext context; 150 FloatType floatTy = FloatType::getBF16(&context); 151 Attribute value = FloatAttr::get(floatTy, 10.0); 152 153 testSplat(floatTy, value); 154 } 155 156 TEST(DenseSplatTest, StringSplat) { 157 MLIRContext context; 158 context.allowUnregisteredDialects(); 159 Type stringType = 160 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 161 StringRef value = "test-string"; 162 testSplat(stringType, value); 163 } 164 165 TEST(DenseSplatTest, StringAttrSplat) { 166 MLIRContext context; 167 context.allowUnregisteredDialects(); 168 Type stringType = 169 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 170 Attribute stringAttr = StringAttr::get("test-string", stringType); 171 testSplat(stringType, stringAttr); 172 } 173 174 TEST(DenseComplexTest, ComplexFloatSplat) { 175 MLIRContext context; 176 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 177 std::complex<float> value(10.0, 15.0); 178 testSplat(complexType, value); 179 } 180 181 TEST(DenseComplexTest, ComplexIntSplat) { 182 MLIRContext context; 183 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 184 std::complex<int64_t> value(10, 15); 185 testSplat(complexType, value); 186 } 187 188 TEST(DenseComplexTest, ComplexAPFloatSplat) { 189 MLIRContext context; 190 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 191 std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); 192 testSplat(complexType, value); 193 } 194 195 TEST(DenseComplexTest, ComplexAPIntSplat) { 196 MLIRContext context; 197 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 198 std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); 199 testSplat(complexType, value); 200 } 201 202 TEST(DenseScalarTest, ExtractZeroRankElement) { 203 MLIRContext context; 204 const int elementValue = 12; 205 IntegerType intTy = IntegerType::get(&context, 32); 206 Attribute value = IntegerAttr::get(intTy, elementValue); 207 RankedTensorType shape = RankedTensorType::get({}, intTy); 208 209 auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); 210 EXPECT_TRUE(attr.getValues<Attribute>()[0] == value); 211 } 212 } // namespace 213 214 //===----------------------------------------------------------------------===// 215 // DenseResourceElementsAttr 216 //===----------------------------------------------------------------------===// 217 218 template <typename AttrT, typename T> 219 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data, 220 Type elementType) { 221 auto type = RankedTensorType::get(data.size(), elementType); 222 auto attr = AttrT::get(type, "resource", 223 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 224 225 // Check that we can access and iterate the data properly. 226 Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef(); 227 EXPECT_TRUE(attrData.has_value()); 228 EXPECT_EQ(*attrData, data); 229 230 // Check that we cast to this attribute when possible. 231 Attribute genericAttr = attr; 232 EXPECT_TRUE(genericAttr.template isa<AttrT>()); 233 } 234 template <typename AttrT, typename T> 235 static void checkNativeIntAccess(Builder &builder, size_t intWidth) { 236 T data[] = {0, 1, 2}; 237 checkNativeAccess<AttrT, T>(builder.getContext(), llvm::makeArrayRef(data), 238 builder.getIntegerType(intWidth)); 239 } 240 241 namespace { 242 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) { 243 MLIRContext context; 244 Builder builder(&context); 245 246 // Bool 247 bool boolData[] = {true, false, true}; 248 checkNativeAccess<DenseBoolResourceElementsAttr>( 249 &context, llvm::makeArrayRef(boolData), builder.getI1Type()); 250 251 // Unsigned integers 252 checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8); 253 checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16); 254 checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32); 255 checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64); 256 257 // Signed integers 258 checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8); 259 checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16); 260 checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32); 261 checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64); 262 263 // Float 264 float floatData[] = {0, 1, 2}; 265 checkNativeAccess<DenseF32ResourceElementsAttr>( 266 &context, llvm::makeArrayRef(floatData), builder.getF32Type()); 267 268 // Double 269 double doubleData[] = {0, 1, 2}; 270 checkNativeAccess<DenseF64ResourceElementsAttr>( 271 &context, llvm::makeArrayRef(doubleData), builder.getF64Type()); 272 } 273 274 TEST(DenseResourceElementsAttrTest, CheckNoCast) { 275 MLIRContext context; 276 Builder builder(&context); 277 278 // Create a i32 attribute. 279 ArrayRef<uint32_t> data; 280 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 281 Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( 282 type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); 283 284 EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>()); 285 EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>()); 286 EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>()); 287 } 288 289 TEST(DenseResourceElementsAttrTest, CheckInvalidData) { 290 MLIRContext context; 291 Builder builder(&context); 292 293 // Create a bool attribute with data of the incorrect type. 294 ArrayRef<uint32_t> data; 295 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 296 EXPECT_DEBUG_DEATH( 297 { 298 DenseBoolResourceElementsAttr::get( 299 type, "resource", 300 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 301 }, 302 "alignment mismatch between expected alignment and blob alignment"); 303 } 304 305 TEST(DenseResourceElementsAttrTest, CheckInvalidType) { 306 MLIRContext context; 307 Builder builder(&context); 308 309 // Create a bool attribute with incorrect type. 310 ArrayRef<bool> data; 311 auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 312 EXPECT_DEBUG_DEATH( 313 { 314 DenseBoolResourceElementsAttr::get( 315 type, "resource", 316 UnmanagedAsmResourceBlob::allocateInferAlign(data)); 317 }, 318 "invalid shape element type for provided type `T`"); 319 } 320 } // namespace 321 322 //===----------------------------------------------------------------------===// 323 // SparseElementsAttr 324 //===----------------------------------------------------------------------===// 325 326 namespace { 327 TEST(SparseElementsAttrTest, GetZero) { 328 MLIRContext context; 329 context.allowUnregisteredDialects(); 330 331 IntegerType intTy = IntegerType::get(&context, 32); 332 FloatType floatTy = FloatType::getF32(&context); 333 Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); 334 335 ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); 336 ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); 337 ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); 338 339 auto indicesType = 340 RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); 341 auto indices = 342 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 343 344 RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); 345 auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); 346 347 RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); 348 auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); 349 350 RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); 351 auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); 352 353 auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); 354 auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); 355 auto sparseString = 356 SparseElementsAttr::get(tensorString, indices, stringValue); 357 358 // Only index (0, 0) contains an element, others are supposed to return 359 // the zero/empty value. 360 auto zeroIntValue = 361 sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>(); 362 EXPECT_EQ(zeroIntValue.getInt(), 0); 363 EXPECT_TRUE(zeroIntValue.getType() == intTy); 364 365 auto zeroFloatValue = 366 sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>(); 367 EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); 368 EXPECT_TRUE(zeroFloatValue.getType() == floatTy); 369 370 auto zeroStringValue = 371 sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>(); 372 EXPECT_TRUE(zeroStringValue.getValue().empty()); 373 EXPECT_TRUE(zeroStringValue.getType() == stringTy); 374 } 375 376 } // namespace 377