1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributes.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "gtest/gtest.h" 12 13 using namespace mlir; 14 using namespace mlir::detail; 15 16 template <typename EltTy> 17 static void testSplat(Type eltType, const EltTy &splatElt) { 18 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); 19 20 // Check that the generated splat is the same for 1 element and N elements. 21 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 22 EXPECT_TRUE(splat.isSplat()); 23 24 auto detectedSplat = 25 DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt})); 26 EXPECT_EQ(detectedSplat, splat); 27 28 for (auto newValue : detectedSplat.template getValues<EltTy>()) 29 EXPECT_TRUE(newValue == splatElt); 30 } 31 32 namespace { 33 TEST(DenseSplatTest, BoolSplat) { 34 MLIRContext context; 35 IntegerType boolTy = IntegerType::get(&context, 1); 36 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 37 38 // Check that splat is automatically detected for boolean values. 39 /// True. 40 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 41 EXPECT_TRUE(trueSplat.isSplat()); 42 /// False. 43 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 44 EXPECT_TRUE(falseSplat.isSplat()); 45 EXPECT_NE(falseSplat, trueSplat); 46 47 /// Detect and handle splat within 8 elements (bool values are bit-packed). 48 /// True. 49 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 50 EXPECT_EQ(detectedSplat, trueSplat); 51 /// False. 52 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 53 EXPECT_EQ(detectedSplat, falseSplat); 54 } 55 56 TEST(DenseSplatTest, LargeBoolSplat) { 57 constexpr int64_t boolCount = 56; 58 59 MLIRContext context; 60 IntegerType boolTy = IntegerType::get(&context, 1); 61 RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); 62 63 // Check that splat is automatically detected for boolean values. 64 /// True. 65 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 66 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 67 EXPECT_TRUE(trueSplat.isSplat()); 68 EXPECT_TRUE(falseSplat.isSplat()); 69 70 /// Detect that the large boolean arrays are properly splatted. 71 /// True. 72 SmallVector<bool, 64> trueValues(boolCount, true); 73 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 74 EXPECT_EQ(detectedSplat, trueSplat); 75 /// False. 76 SmallVector<bool, 64> falseValues(boolCount, false); 77 detectedSplat = DenseElementsAttr::get(shape, falseValues); 78 EXPECT_EQ(detectedSplat, falseSplat); 79 } 80 81 TEST(DenseSplatTest, BoolNonSplat) { 82 MLIRContext context; 83 IntegerType boolTy = IntegerType::get(&context, 1); 84 RankedTensorType shape = RankedTensorType::get({6}, boolTy); 85 86 // Check that we properly handle non-splat values. 87 DenseElementsAttr nonSplat = 88 DenseElementsAttr::get(shape, {false, false, true, false, false, true}); 89 EXPECT_FALSE(nonSplat.isSplat()); 90 } 91 92 TEST(DenseSplatTest, OddIntSplat) { 93 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 94 MLIRContext context; 95 constexpr size_t intWidth = 19; 96 IntegerType intTy = IntegerType::get(&context, intWidth); 97 APInt value(intWidth, 10); 98 99 testSplat(intTy, value); 100 } 101 102 TEST(DenseSplatTest, Int32Splat) { 103 MLIRContext context; 104 IntegerType intTy = IntegerType::get(&context, 32); 105 int value = 64; 106 107 testSplat(intTy, value); 108 } 109 110 TEST(DenseSplatTest, IntAttrSplat) { 111 MLIRContext context; 112 IntegerType intTy = IntegerType::get(&context, 85); 113 Attribute value = IntegerAttr::get(intTy, 109); 114 115 testSplat(intTy, value); 116 } 117 118 TEST(DenseSplatTest, F32Splat) { 119 MLIRContext context; 120 FloatType floatTy = FloatType::getF32(&context); 121 float value = 10.0; 122 123 testSplat(floatTy, value); 124 } 125 126 TEST(DenseSplatTest, F64Splat) { 127 MLIRContext context; 128 FloatType floatTy = FloatType::getF64(&context); 129 double value = 10.0; 130 131 testSplat(floatTy, APFloat(value)); 132 } 133 134 TEST(DenseSplatTest, FloatAttrSplat) { 135 MLIRContext context; 136 FloatType floatTy = FloatType::getF32(&context); 137 Attribute value = FloatAttr::get(floatTy, 10.0); 138 139 testSplat(floatTy, value); 140 } 141 142 TEST(DenseSplatTest, BF16Splat) { 143 MLIRContext context; 144 FloatType floatTy = FloatType::getBF16(&context); 145 Attribute value = FloatAttr::get(floatTy, 10.0); 146 147 testSplat(floatTy, value); 148 } 149 150 TEST(DenseSplatTest, StringSplat) { 151 MLIRContext context; 152 context.allowUnregisteredDialects(); 153 Type stringType = 154 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 155 StringRef value = "test-string"; 156 testSplat(stringType, value); 157 } 158 159 TEST(DenseSplatTest, StringAttrSplat) { 160 MLIRContext context; 161 context.allowUnregisteredDialects(); 162 Type stringType = 163 OpaqueType::get(StringAttr::get(&context, "test"), "string"); 164 Attribute stringAttr = StringAttr::get("test-string", stringType); 165 testSplat(stringType, stringAttr); 166 } 167 168 TEST(DenseComplexTest, ComplexFloatSplat) { 169 MLIRContext context; 170 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 171 std::complex<float> value(10.0, 15.0); 172 testSplat(complexType, value); 173 } 174 175 TEST(DenseComplexTest, ComplexIntSplat) { 176 MLIRContext context; 177 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 178 std::complex<int64_t> value(10, 15); 179 testSplat(complexType, value); 180 } 181 182 TEST(DenseComplexTest, ComplexAPFloatSplat) { 183 MLIRContext context; 184 ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); 185 std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); 186 testSplat(complexType, value); 187 } 188 189 TEST(DenseComplexTest, ComplexAPIntSplat) { 190 MLIRContext context; 191 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 192 std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); 193 testSplat(complexType, value); 194 } 195 196 TEST(DenseScalarTest, ExtractZeroRankElement) { 197 MLIRContext context; 198 const int elementValue = 12; 199 IntegerType intTy = IntegerType::get(&context, 32); 200 Attribute value = IntegerAttr::get(intTy, elementValue); 201 RankedTensorType shape = RankedTensorType::get({}, intTy); 202 203 auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); 204 EXPECT_TRUE(attr.getValues<Attribute>()[0] == value); 205 } 206 207 TEST(SparseElementsAttrTest, GetZero) { 208 MLIRContext context; 209 context.allowUnregisteredDialects(); 210 211 IntegerType intTy = IntegerType::get(&context, 32); 212 FloatType floatTy = FloatType::getF32(&context); 213 Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); 214 215 ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); 216 ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); 217 ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); 218 219 auto indicesType = 220 RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); 221 auto indices = 222 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 223 224 RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); 225 auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); 226 227 RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); 228 auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); 229 230 RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); 231 auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); 232 233 auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); 234 auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); 235 auto sparseString = 236 SparseElementsAttr::get(tensorString, indices, stringValue); 237 238 // Only index (0, 0) contains an element, others are supposed to return 239 // the zero/empty value. 240 auto zeroIntValue = sparseInt.getValues<Attribute>()[{1, 1}]; 241 EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0); 242 EXPECT_TRUE(zeroIntValue.getType() == intTy); 243 244 auto zeroFloatValue = sparseFloat.getValues<Attribute>()[{1, 1}]; 245 EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f); 246 EXPECT_TRUE(zeroFloatValue.getType() == floatTy); 247 248 auto zeroStringValue = sparseString.getValues<Attribute>()[{1, 1}]; 249 EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty()); 250 EXPECT_TRUE(zeroStringValue.getType() == stringTy); 251 } 252 253 } // end namespace 254