// xref: /llvm-project/mlir/unittests/IR/AttributeTest.cpp (revision a1fe1f5f77d48b03b76884a9b9b91a6795193ac1)
//===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AsmState.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "gtest/gtest.h"
14 #include <optional>
15 
16 using namespace mlir;
17 using namespace mlir::detail;
18 
//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//
22 
23 template <typename EltTy>
24 static void testSplat(Type eltType, const EltTy &splatElt) {
25   RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
26 
27   // Check that the generated splat is the same for 1 element and N elements.
28   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
29   EXPECT_TRUE(splat.isSplat());
30 
31   auto detectedSplat =
32       DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt}));
33   EXPECT_EQ(detectedSplat, splat);
34 
35   for (auto newValue : detectedSplat.template getValues<EltTy>())
36     EXPECT_TRUE(newValue == splatElt);
37 }
38 
39 namespace {
40 TEST(DenseSplatTest, BoolSplat) {
41   MLIRContext context;
42   IntegerType boolTy = IntegerType::get(&context, 1);
43   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
44 
45   // Check that splat is automatically detected for boolean values.
46   /// True.
47   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
48   EXPECT_TRUE(trueSplat.isSplat());
49   /// False.
50   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
51   EXPECT_TRUE(falseSplat.isSplat());
52   EXPECT_NE(falseSplat, trueSplat);
53 
54   /// Detect and handle splat within 8 elements (bool values are bit-packed).
55   /// True.
56   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
57   EXPECT_EQ(detectedSplat, trueSplat);
58   /// False.
59   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
60   EXPECT_EQ(detectedSplat, falseSplat);
61 }
62 TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
63   MLIRContext context;
64   IntegerType boolTy = IntegerType::get(&context, 1);
65   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
66 
67   // Check that splat booleans properly round trip via the raw API.
68   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
69   EXPECT_TRUE(trueSplat.isSplat());
70   DenseElementsAttr trueSplatFromRaw =
71       DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData());
72   EXPECT_TRUE(trueSplatFromRaw.isSplat());
73 
74   EXPECT_EQ(trueSplat, trueSplatFromRaw);
75 }
76 
77 TEST(DenseSplatTest, LargeBoolSplat) {
78   constexpr int64_t boolCount = 56;
79 
80   MLIRContext context;
81   IntegerType boolTy = IntegerType::get(&context, 1);
82   RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
83 
84   // Check that splat is automatically detected for boolean values.
85   /// True.
86   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
87   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
88   EXPECT_TRUE(trueSplat.isSplat());
89   EXPECT_TRUE(falseSplat.isSplat());
90 
91   /// Detect that the large boolean arrays are properly splatted.
92   /// True.
93   SmallVector<bool, 64> trueValues(boolCount, true);
94   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
95   EXPECT_EQ(detectedSplat, trueSplat);
96   /// False.
97   SmallVector<bool, 64> falseValues(boolCount, false);
98   detectedSplat = DenseElementsAttr::get(shape, falseValues);
99   EXPECT_EQ(detectedSplat, falseSplat);
100 }
101 
102 TEST(DenseSplatTest, BoolNonSplat) {
103   MLIRContext context;
104   IntegerType boolTy = IntegerType::get(&context, 1);
105   RankedTensorType shape = RankedTensorType::get({6}, boolTy);
106 
107   // Check that we properly handle non-splat values.
108   DenseElementsAttr nonSplat =
109       DenseElementsAttr::get(shape, {false, false, true, false, false, true});
110   EXPECT_FALSE(nonSplat.isSplat());
111 }
112 
113 TEST(DenseSplatTest, OddIntSplat) {
114   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
115   MLIRContext context;
116   constexpr size_t intWidth = 19;
117   IntegerType intTy = IntegerType::get(&context, intWidth);
118   APInt value(intWidth, 10);
119 
120   testSplat(intTy, value);
121 }
122 
123 TEST(DenseSplatTest, Int32Splat) {
124   MLIRContext context;
125   IntegerType intTy = IntegerType::get(&context, 32);
126   int value = 64;
127 
128   testSplat(intTy, value);
129 }
130 
131 TEST(DenseSplatTest, IntAttrSplat) {
132   MLIRContext context;
133   IntegerType intTy = IntegerType::get(&context, 85);
134   Attribute value = IntegerAttr::get(intTy, 109);
135 
136   testSplat(intTy, value);
137 }
138 
139 TEST(DenseSplatTest, F32Splat) {
140   MLIRContext context;
141   FloatType floatTy = FloatType::getF32(&context);
142   float value = 10.0;
143 
144   testSplat(floatTy, value);
145 }
146 
147 TEST(DenseSplatTest, F64Splat) {
148   MLIRContext context;
149   FloatType floatTy = FloatType::getF64(&context);
150   double value = 10.0;
151 
152   testSplat(floatTy, APFloat(value));
153 }
154 
155 TEST(DenseSplatTest, FloatAttrSplat) {
156   MLIRContext context;
157   FloatType floatTy = FloatType::getF32(&context);
158   Attribute value = FloatAttr::get(floatTy, 10.0);
159 
160   testSplat(floatTy, value);
161 }
162 
163 TEST(DenseSplatTest, BF16Splat) {
164   MLIRContext context;
165   FloatType floatTy = FloatType::getBF16(&context);
166   Attribute value = FloatAttr::get(floatTy, 10.0);
167 
168   testSplat(floatTy, value);
169 }
170 
171 TEST(DenseSplatTest, StringSplat) {
172   MLIRContext context;
173   context.allowUnregisteredDialects();
174   Type stringType =
175       OpaqueType::get(StringAttr::get(&context, "test"), "string");
176   StringRef value = "test-string";
177   testSplat(stringType, value);
178 }
179 
180 TEST(DenseSplatTest, StringAttrSplat) {
181   MLIRContext context;
182   context.allowUnregisteredDialects();
183   Type stringType =
184       OpaqueType::get(StringAttr::get(&context, "test"), "string");
185   Attribute stringAttr = StringAttr::get("test-string", stringType);
186   testSplat(stringType, stringAttr);
187 }
188 
189 TEST(DenseComplexTest, ComplexFloatSplat) {
190   MLIRContext context;
191   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
192   std::complex<float> value(10.0, 15.0);
193   testSplat(complexType, value);
194 }
195 
196 TEST(DenseComplexTest, ComplexIntSplat) {
197   MLIRContext context;
198   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
199   std::complex<int64_t> value(10, 15);
200   testSplat(complexType, value);
201 }
202 
203 TEST(DenseComplexTest, ComplexAPFloatSplat) {
204   MLIRContext context;
205   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
206   std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
207   testSplat(complexType, value);
208 }
209 
210 TEST(DenseComplexTest, ComplexAPIntSplat) {
211   MLIRContext context;
212   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
213   std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
214   testSplat(complexType, value);
215 }
216 
217 TEST(DenseScalarTest, ExtractZeroRankElement) {
218   MLIRContext context;
219   const int elementValue = 12;
220   IntegerType intTy = IntegerType::get(&context, 32);
221   Attribute value = IntegerAttr::get(intTy, elementValue);
222   RankedTensorType shape = RankedTensorType::get({}, intTy);
223 
224   auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}));
225   EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
226 }
227 
228 TEST(DenseSplatMapValuesTest, I32ToTrue) {
229   MLIRContext context;
230   const int elementValue = 12;
231   IntegerType boolTy = IntegerType::get(&context, 1);
232   IntegerType intTy = IntegerType::get(&context, 32);
233   RankedTensorType shape = RankedTensorType::get({4}, intTy);
234 
235   auto attr =
236       DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
237           .mapValues(boolTy, [](const APInt &x) {
238             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
239           });
240   EXPECT_EQ(attr.getNumElements(), 4);
241   EXPECT_TRUE(attr.isSplat());
242   EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
243 }
244 
245 TEST(DenseSplatMapValuesTest, I32ToFalse) {
246   MLIRContext context;
247   const int elementValue = 0;
248   IntegerType boolTy = IntegerType::get(&context, 1);
249   IntegerType intTy = IntegerType::get(&context, 32);
250   RankedTensorType shape = RankedTensorType::get({4}, intTy);
251 
252   auto attr =
253       DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
254           .mapValues(boolTy, [](const APInt &x) {
255             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
256           });
257   EXPECT_EQ(attr.getNumElements(), 4);
258   EXPECT_TRUE(attr.isSplat());
259   EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
260 }
261 } // namespace
262 
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//
266 
267 template <typename AttrT, typename T>
268 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
269                               Type elementType) {
270   auto type = RankedTensorType::get(data.size(), elementType);
271   auto attr = AttrT::get(type, "resource",
272                          UnmanagedAsmResourceBlob::allocateInferAlign(data));
273 
274   // Check that we can access and iterate the data properly.
275   Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
276   EXPECT_TRUE(attrData.has_value());
277   EXPECT_EQ(*attrData, data);
278 
279   // Check that we cast to this attribute when possible.
280   Attribute genericAttr = attr;
281   EXPECT_TRUE(genericAttr.template isa<AttrT>());
282 }
283 template <typename AttrT, typename T>
284 static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
285   T data[] = {0, 1, 2};
286   checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data),
287                               builder.getIntegerType(intWidth));
288 }
289 
290 namespace {
291 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
292   MLIRContext context;
293   Builder builder(&context);
294 
295   // Bool
296   bool boolData[] = {true, false, true};
297   checkNativeAccess<DenseBoolResourceElementsAttr>(
298       &context, llvm::ArrayRef(boolData), builder.getI1Type());
299 
300   // Unsigned integers
301   checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
302   checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
303   checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
304   checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
305 
306   // Signed integers
307   checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
308   checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
309   checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
310   checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
311 
312   // Float
313   float floatData[] = {0, 1, 2};
314   checkNativeAccess<DenseF32ResourceElementsAttr>(
315       &context, llvm::ArrayRef(floatData), builder.getF32Type());
316 
317   // Double
318   double doubleData[] = {0, 1, 2};
319   checkNativeAccess<DenseF64ResourceElementsAttr>(
320       &context, llvm::ArrayRef(doubleData), builder.getF64Type());
321 }
322 
323 TEST(DenseResourceElementsAttrTest, CheckNoCast) {
324   MLIRContext context;
325   Builder builder(&context);
326 
327   // Create a i32 attribute.
328   ArrayRef<uint32_t> data;
329   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
330   Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
331       type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
332 
333   EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
334   EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
335   EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
336 }
337 
338 TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
339   MLIRContext context;
340   Builder builder(&context);
341 
342   // Create a bool attribute with data of the incorrect type.
343   ArrayRef<uint32_t> data;
344   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
345   EXPECT_DEBUG_DEATH(
346       {
347         DenseBoolResourceElementsAttr::get(
348             type, "resource",
349             UnmanagedAsmResourceBlob::allocateInferAlign(data));
350       },
351       "alignment mismatch between expected alignment and blob alignment");
352 }
353 
354 TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
355   MLIRContext context;
356   Builder builder(&context);
357 
358   // Create a bool attribute with incorrect type.
359   ArrayRef<bool> data;
360   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
361   EXPECT_DEBUG_DEATH(
362       {
363         DenseBoolResourceElementsAttr::get(
364             type, "resource",
365             UnmanagedAsmResourceBlob::allocateInferAlign(data));
366       },
367       "invalid shape element type for provided type `T`");
368 }
369 } // namespace
370 
//===----------------------------------------------------------------------===//
// SparseElementsAttr
//===----------------------------------------------------------------------===//
374 
375 namespace {
376 TEST(SparseElementsAttrTest, GetZero) {
377   MLIRContext context;
378   context.allowUnregisteredDialects();
379 
380   IntegerType intTy = IntegerType::get(&context, 32);
381   FloatType floatTy = FloatType::getF32(&context);
382   Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
383 
384   ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
385   ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
386   ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
387 
388   auto indicesType =
389       RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
390   auto indices =
391       DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
392 
393   RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
394   auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
395 
396   RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
397   auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
398 
399   RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
400   auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
401 
402   auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
403   auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
404   auto sparseString =
405       SparseElementsAttr::get(tensorString, indices, stringValue);
406 
407   // Only index (0, 0) contains an element, others are supposed to return
408   // the zero/empty value.
409   auto zeroIntValue =
410       sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>();
411   EXPECT_EQ(zeroIntValue.getInt(), 0);
412   EXPECT_TRUE(zeroIntValue.getType() == intTy);
413 
414   auto zeroFloatValue =
415       sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>();
416   EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
417   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
418 
419   auto zeroStringValue =
420       sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>();
421   EXPECT_TRUE(zeroStringValue.getValue().empty());
422   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
423 }
424 
425 } // namespace
426