xref: /llvm-project/mlir/unittests/IR/AttributeTest.cpp (revision ae60a4a0efff337425638d04005b33a73dc5792f)
1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AsmState.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "gtest/gtest.h"
14 
15 using namespace mlir;
16 using namespace mlir::detail;
17 
18 //===----------------------------------------------------------------------===//
19 // DenseElementsAttr
20 //===----------------------------------------------------------------------===//
21 
22 template <typename EltTy>
23 static void testSplat(Type eltType, const EltTy &splatElt) {
24   RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
25 
26   // Check that the generated splat is the same for 1 element and N elements.
27   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
28   EXPECT_TRUE(splat.isSplat());
29 
30   auto detectedSplat =
31       DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
32   EXPECT_EQ(detectedSplat, splat);
33 
34   for (auto newValue : detectedSplat.template getValues<EltTy>())
35     EXPECT_TRUE(newValue == splatElt);
36 }
37 
38 namespace {
39 TEST(DenseSplatTest, BoolSplat) {
40   MLIRContext context;
41   IntegerType boolTy = IntegerType::get(&context, 1);
42   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
43 
44   // Check that splat is automatically detected for boolean values.
45   /// True.
46   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
47   EXPECT_TRUE(trueSplat.isSplat());
48   /// False.
49   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
50   EXPECT_TRUE(falseSplat.isSplat());
51   EXPECT_NE(falseSplat, trueSplat);
52 
53   /// Detect and handle splat within 8 elements (bool values are bit-packed).
54   /// True.
55   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
56   EXPECT_EQ(detectedSplat, trueSplat);
57   /// False.
58   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
59   EXPECT_EQ(detectedSplat, falseSplat);
60 }
61 
62 TEST(DenseSplatTest, LargeBoolSplat) {
63   constexpr int64_t boolCount = 56;
64 
65   MLIRContext context;
66   IntegerType boolTy = IntegerType::get(&context, 1);
67   RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
68 
69   // Check that splat is automatically detected for boolean values.
70   /// True.
71   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
72   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
73   EXPECT_TRUE(trueSplat.isSplat());
74   EXPECT_TRUE(falseSplat.isSplat());
75 
76   /// Detect that the large boolean arrays are properly splatted.
77   /// True.
78   SmallVector<bool, 64> trueValues(boolCount, true);
79   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
80   EXPECT_EQ(detectedSplat, trueSplat);
81   /// False.
82   SmallVector<bool, 64> falseValues(boolCount, false);
83   detectedSplat = DenseElementsAttr::get(shape, falseValues);
84   EXPECT_EQ(detectedSplat, falseSplat);
85 }
86 
87 TEST(DenseSplatTest, BoolNonSplat) {
88   MLIRContext context;
89   IntegerType boolTy = IntegerType::get(&context, 1);
90   RankedTensorType shape = RankedTensorType::get({6}, boolTy);
91 
92   // Check that we properly handle non-splat values.
93   DenseElementsAttr nonSplat =
94       DenseElementsAttr::get(shape, {false, false, true, false, false, true});
95   EXPECT_FALSE(nonSplat.isSplat());
96 }
97 
98 TEST(DenseSplatTest, OddIntSplat) {
99   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
100   MLIRContext context;
101   constexpr size_t intWidth = 19;
102   IntegerType intTy = IntegerType::get(&context, intWidth);
103   APInt value(intWidth, 10);
104 
105   testSplat(intTy, value);
106 }
107 
108 TEST(DenseSplatTest, Int32Splat) {
109   MLIRContext context;
110   IntegerType intTy = IntegerType::get(&context, 32);
111   int value = 64;
112 
113   testSplat(intTy, value);
114 }
115 
116 TEST(DenseSplatTest, IntAttrSplat) {
117   MLIRContext context;
118   IntegerType intTy = IntegerType::get(&context, 85);
119   Attribute value = IntegerAttr::get(intTy, 109);
120 
121   testSplat(intTy, value);
122 }
123 
124 TEST(DenseSplatTest, F32Splat) {
125   MLIRContext context;
126   FloatType floatTy = FloatType::getF32(&context);
127   float value = 10.0;
128 
129   testSplat(floatTy, value);
130 }
131 
132 TEST(DenseSplatTest, F64Splat) {
133   MLIRContext context;
134   FloatType floatTy = FloatType::getF64(&context);
135   double value = 10.0;
136 
137   testSplat(floatTy, APFloat(value));
138 }
139 
140 TEST(DenseSplatTest, FloatAttrSplat) {
141   MLIRContext context;
142   FloatType floatTy = FloatType::getF32(&context);
143   Attribute value = FloatAttr::get(floatTy, 10.0);
144 
145   testSplat(floatTy, value);
146 }
147 
148 TEST(DenseSplatTest, BF16Splat) {
149   MLIRContext context;
150   FloatType floatTy = FloatType::getBF16(&context);
151   Attribute value = FloatAttr::get(floatTy, 10.0);
152 
153   testSplat(floatTy, value);
154 }
155 
156 TEST(DenseSplatTest, StringSplat) {
157   MLIRContext context;
158   context.allowUnregisteredDialects();
159   Type stringType =
160       OpaqueType::get(StringAttr::get(&context, "test"), "string");
161   StringRef value = "test-string";
162   testSplat(stringType, value);
163 }
164 
165 TEST(DenseSplatTest, StringAttrSplat) {
166   MLIRContext context;
167   context.allowUnregisteredDialects();
168   Type stringType =
169       OpaqueType::get(StringAttr::get(&context, "test"), "string");
170   Attribute stringAttr = StringAttr::get("test-string", stringType);
171   testSplat(stringType, stringAttr);
172 }
173 
174 TEST(DenseComplexTest, ComplexFloatSplat) {
175   MLIRContext context;
176   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
177   std::complex<float> value(10.0, 15.0);
178   testSplat(complexType, value);
179 }
180 
181 TEST(DenseComplexTest, ComplexIntSplat) {
182   MLIRContext context;
183   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
184   std::complex<int64_t> value(10, 15);
185   testSplat(complexType, value);
186 }
187 
188 TEST(DenseComplexTest, ComplexAPFloatSplat) {
189   MLIRContext context;
190   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
191   std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
192   testSplat(complexType, value);
193 }
194 
195 TEST(DenseComplexTest, ComplexAPIntSplat) {
196   MLIRContext context;
197   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
198   std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
199   testSplat(complexType, value);
200 }
201 
202 TEST(DenseScalarTest, ExtractZeroRankElement) {
203   MLIRContext context;
204   const int elementValue = 12;
205   IntegerType intTy = IntegerType::get(&context, 32);
206   Attribute value = IntegerAttr::get(intTy, elementValue);
207   RankedTensorType shape = RankedTensorType::get({}, intTy);
208 
209   auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
210   EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
211 }
212 
213 TEST(DenseSplatMapValuesTest, I32ToTrue) {
214   MLIRContext context;
215   const int elementValue = 12;
216   IntegerType boolTy = IntegerType::get(&context, 1);
217   IntegerType intTy = IntegerType::get(&context, 32);
218   RankedTensorType shape = RankedTensorType::get({4}, intTy);
219 
220   auto attr =
221       DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
222           .mapValues(boolTy, [](const APInt &x) {
223             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
224           });
225   EXPECT_EQ(attr.getNumElements(), 4);
226   EXPECT_TRUE(attr.isSplat());
227   EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
228 }
229 
230 TEST(DenseSplatMapValuesTest, I32ToFalse) {
231   MLIRContext context;
232   const int elementValue = 0;
233   IntegerType boolTy = IntegerType::get(&context, 1);
234   IntegerType intTy = IntegerType::get(&context, 32);
235   RankedTensorType shape = RankedTensorType::get({4}, intTy);
236 
237   auto attr =
238       DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
239           .mapValues(boolTy, [](const APInt &x) {
240             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
241           });
242   EXPECT_EQ(attr.getNumElements(), 4);
243   EXPECT_TRUE(attr.isSplat());
244   EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
245 }
246 } // namespace
247 
248 //===----------------------------------------------------------------------===//
249 // DenseResourceElementsAttr
250 //===----------------------------------------------------------------------===//
251 
252 template <typename AttrT, typename T>
253 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
254                               Type elementType) {
255   auto type = RankedTensorType::get(data.size(), elementType);
256   auto attr = AttrT::get(type, "resource",
257                          UnmanagedAsmResourceBlob::allocateInferAlign(data));
258 
259   // Check that we can access and iterate the data properly.
260   Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
261   EXPECT_TRUE(attrData.has_value());
262   EXPECT_EQ(*attrData, data);
263 
264   // Check that we cast to this attribute when possible.
265   Attribute genericAttr = attr;
266   EXPECT_TRUE(genericAttr.template isa<AttrT>());
267 }
268 template <typename AttrT, typename T>
269 static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
270   T data[] = {0, 1, 2};
271   checkNativeAccess<AttrT, T>(builder.getContext(), llvm::makeArrayRef(data),
272                               builder.getIntegerType(intWidth));
273 }
274 
275 namespace {
276 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
277   MLIRContext context;
278   Builder builder(&context);
279 
280   // Bool
281   bool boolData[] = {true, false, true};
282   checkNativeAccess<DenseBoolResourceElementsAttr>(
283       &context, llvm::makeArrayRef(boolData), builder.getI1Type());
284 
285   // Unsigned integers
286   checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
287   checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
288   checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
289   checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
290 
291   // Signed integers
292   checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
293   checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
294   checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
295   checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
296 
297   // Float
298   float floatData[] = {0, 1, 2};
299   checkNativeAccess<DenseF32ResourceElementsAttr>(
300       &context, llvm::makeArrayRef(floatData), builder.getF32Type());
301 
302   // Double
303   double doubleData[] = {0, 1, 2};
304   checkNativeAccess<DenseF64ResourceElementsAttr>(
305       &context, llvm::makeArrayRef(doubleData), builder.getF64Type());
306 }
307 
308 TEST(DenseResourceElementsAttrTest, CheckNoCast) {
309   MLIRContext context;
310   Builder builder(&context);
311 
312   // Create a i32 attribute.
313   ArrayRef<uint32_t> data;
314   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
315   Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
316       type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
317 
318   EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
319   EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
320   EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
321 }
322 
323 TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
324   MLIRContext context;
325   Builder builder(&context);
326 
327   // Create a bool attribute with data of the incorrect type.
328   ArrayRef<uint32_t> data;
329   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
330   EXPECT_DEBUG_DEATH(
331       {
332         DenseBoolResourceElementsAttr::get(
333             type, "resource",
334             UnmanagedAsmResourceBlob::allocateInferAlign(data));
335       },
336       "alignment mismatch between expected alignment and blob alignment");
337 }
338 
339 TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
340   MLIRContext context;
341   Builder builder(&context);
342 
343   // Create a bool attribute with incorrect type.
344   ArrayRef<bool> data;
345   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
346   EXPECT_DEBUG_DEATH(
347       {
348         DenseBoolResourceElementsAttr::get(
349             type, "resource",
350             UnmanagedAsmResourceBlob::allocateInferAlign(data));
351       },
352       "invalid shape element type for provided type `T`");
353 }
354 } // namespace
355 
356 //===----------------------------------------------------------------------===//
357 // SparseElementsAttr
358 //===----------------------------------------------------------------------===//
359 
360 namespace {
361 TEST(SparseElementsAttrTest, GetZero) {
362   MLIRContext context;
363   context.allowUnregisteredDialects();
364 
365   IntegerType intTy = IntegerType::get(&context, 32);
366   FloatType floatTy = FloatType::getF32(&context);
367   Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
368 
369   ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
370   ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
371   ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
372 
373   auto indicesType =
374       RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
375   auto indices =
376       DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
377 
378   RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
379   auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
380 
381   RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
382   auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
383 
384   RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
385   auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
386 
387   auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
388   auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
389   auto sparseString =
390       SparseElementsAttr::get(tensorString, indices, stringValue);
391 
392   // Only index (0, 0) contains an element, others are supposed to return
393   // the zero/empty value.
394   auto zeroIntValue =
395       sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>();
396   EXPECT_EQ(zeroIntValue.getInt(), 0);
397   EXPECT_TRUE(zeroIntValue.getType() == intTy);
398 
399   auto zeroFloatValue =
400       sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>();
401   EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
402   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
403 
404   auto zeroStringValue =
405       sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>();
406   EXPECT_TRUE(zeroStringValue.getValue().empty());
407   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
408 }
409 
410 } // namespace
411