1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "mlir/IR/Identifier.h" 12 #include "gtest/gtest.h" 13 14 using namespace mlir; 15 using namespace mlir::detail; 16 17 template <typename EltTy> 18 static void testSplat(Type eltType, const EltTy &splatElt) { 19 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); 20 21 // Check that the generated splat is the same for 1 element and N elements. 22 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 23 EXPECT_TRUE(splat.isSplat()); 24 25 auto detectedSplat = 26 DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt})); 27 EXPECT_EQ(detectedSplat, splat); 28 29 for (auto newValue : detectedSplat.template getValues<EltTy>()) 30 EXPECT_TRUE(newValue == splatElt); 31 } 32 33 namespace { 34 TEST(DenseSplatTest, BoolSplat) { 35 MLIRContext context; 36 IntegerType boolTy = IntegerType::get(&context, 1); 37 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 38 39 // Check that splat is automatically detected for boolean values. 40 /// True. 41 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 42 EXPECT_TRUE(trueSplat.isSplat()); 43 /// False. 44 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 45 EXPECT_TRUE(falseSplat.isSplat()); 46 EXPECT_NE(falseSplat, trueSplat); 47 48 /// Detect and handle splat within 8 elements (bool values are bit-packed). 49 /// True. 50 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 51 EXPECT_EQ(detectedSplat, trueSplat); 52 /// False. 53 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 54 EXPECT_EQ(detectedSplat, falseSplat); 55 } 56 57 TEST(DenseSplatTest, LargeBoolSplat) { 58 constexpr int64_t boolCount = 56; 59 60 MLIRContext context; 61 IntegerType boolTy = IntegerType::get(&context, 1); 62 RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); 63 64 // Check that splat is automatically detected for boolean values. 65 /// True. 66 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 67 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 68 EXPECT_TRUE(trueSplat.isSplat()); 69 EXPECT_TRUE(falseSplat.isSplat()); 70 71 /// Detect that the large boolean arrays are properly splatted. 72 /// True. 73 SmallVector<bool, 64> trueValues(boolCount, true); 74 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 75 EXPECT_EQ(detectedSplat, trueSplat); 76 /// False. 77 SmallVector<bool, 64> falseValues(boolCount, false); 78 detectedSplat = DenseElementsAttr::get(shape, falseValues); 79 EXPECT_EQ(detectedSplat, falseSplat); 80 } 81 82 TEST(DenseSplatTest, BoolNonSplat) { 83 MLIRContext context; 84 IntegerType boolTy = IntegerType::get(&context, 1); 85 RankedTensorType shape = RankedTensorType::get({6}, boolTy); 86 87 // Check that we properly handle non-splat values. 88 DenseElementsAttr nonSplat = 89 DenseElementsAttr::get(shape, {false, false, true, false, false, true}); 90 EXPECT_FALSE(nonSplat.isSplat()); 91 } 92 93 TEST(DenseSplatTest, OddIntSplat) { 94 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 95 MLIRContext context; 96 constexpr size_t intWidth = 19; 97 IntegerType intTy = IntegerType::get(&context, intWidth); 98 APInt value(intWidth, 10); 99 100 testSplat(intTy, value); 101 } 102 103 TEST(DenseSplatTest, Int32Splat) { 104 MLIRContext context; 105 IntegerType intTy = IntegerType::get(&context, 32); 106 int value = 64; 107 108 testSplat(intTy, value); 109 } 110 111 TEST(DenseSplatTest, IntAttrSplat) { 112 MLIRContext context; 113 IntegerType intTy = IntegerType::get(&context, 85); 114 Attribute value = IntegerAttr::get(intTy, 109); 115 116 testSplat(intTy, value); 117 } 118 119 TEST(DenseSplatTest, F32Splat) { 120 MLIRContext context; 121 FloatType floatTy = FloatType::getF32(&context); 122 float value = 10.0; 123 124 testSplat(floatTy, value); 125 } 126 127 TEST(DenseSplatTest, F64Splat) { 128 MLIRContext context; 129 FloatType floatTy = FloatType::getF64(&context); 130 double value = 10.0; 131 132 testSplat(floatTy, APFloat(value)); 133 } 134 135 TEST(DenseSplatTest, FloatAttrSplat) { 136 MLIRContext context; 137 FloatType floatTy = FloatType::getF32(&context); 138 Attribute value = FloatAttr::get(floatTy, 10.0); 139 140 testSplat(floatTy, value); 141 } 142 143 TEST(DenseSplatTest, BF16Splat) { 144 MLIRContext context; 145 FloatType floatTy = FloatType::getBF16(&context); 146 Attribute value = FloatAttr::get(floatTy, 10.0); 147 148 testSplat(floatTy, value); 149 } 150 151 TEST(DenseSplatTest, StringSplat) { 152 MLIRContext context; 153 context.allowUnregisteredDialects(); 154 Type stringType = 155 OpaqueType::get(Identifier::get("test", &context), "string"); 156 StringRef value = "test-string"; 157 testSplat(stringType, value); 158 } 159 160 TEST(DenseSplatTest, StringAttrSplat) { 161 MLIRContext context; 162 context.allowUnregisteredDialects(); 163 Type stringType = 164 OpaqueType::get(Identifier::get("test", &context), "string"); 165 Attribute stringAttr = StringAttr::get("test-string", stringType); 166 testSplat(stringType, stringAttr); 167 } 168 169 TEST(DenseComplexTest, ComplexFloatSplat) { 170 MLIRContext context; 171 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 172 std::complex<float> value(10.0, 15.0); 173 testSplat(complexType, value); 174 } 175 176 TEST(DenseComplexTest, ComplexIntSplat) { 177 MLIRContext context; 178 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 179 std::complex<int64_t> value(10, 15); 180 testSplat(complexType, value); 181 } 182 183 TEST(DenseComplexTest, ComplexAPFloatSplat) { 184 MLIRContext context; 185 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 186 std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); 187 testSplat(complexType, value); 188 } 189 190 TEST(DenseComplexTest, ComplexAPIntSplat) { 191 MLIRContext context; 192 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 193 std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); 194 testSplat(complexType, value); 195 } 196 197 TEST(DenseScalarTest, ExtractZeroRankElement) { 198 MLIRContext context; 199 const int elementValue = 12; 200 IntegerType intTy = IntegerType::get(&context, 32); 201 Attribute value = IntegerAttr::get(intTy, elementValue); 202 RankedTensorType shape = RankedTensorType::get({}, intTy); 203 204 auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); 205 EXPECT_TRUE(attr.getValues<Attribute>()[0] == value); 206 } 207 208 TEST(SparseElementsAttrTest, GetZero) { 209 MLIRContext context; 210 context.allowUnregisteredDialects(); 211 212 IntegerType intTy = IntegerType::get(&context, 32); 213 FloatType floatTy = FloatType::getF32(&context); 214 Type stringTy = OpaqueType::get(Identifier::get("test", &context), "string"); 215 216 ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); 217 ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); 218 ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); 219 220 auto indicesType = 221 RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); 222 auto indices = 223 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 224 225 RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); 226 auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); 227 228 RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); 229 auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); 230 231 RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); 232 auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); 233 234 auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); 235 auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); 236 auto sparseString = 237 SparseElementsAttr::get(tensorString, indices, stringValue); 238 239 // Only index (0, 0) contains an element, others are supposed to return 240 // the zero/empty value. 241 auto zeroIntValue = sparseInt.getValues<Attribute>()[{1, 1}]; 242 EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0); 243 EXPECT_TRUE(zeroIntValue.getType() == intTy); 244 245 auto zeroFloatValue = sparseFloat.getValues<Attribute>()[{1, 1}]; 246 EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f); 247 EXPECT_TRUE(zeroFloatValue.getType() == floatTy); 248 249 auto zeroStringValue = sparseString.getValues<Attribute>()[{1, 1}]; 250 EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty()); 251 EXPECT_TRUE(zeroStringValue.getType() == stringTy); 252 } 253 254 } // end namespace 255