xref: /llvm-project/mlir/unittests/IR/AttributeTest.cpp (revision 998003eaf7b91b5620eeb0b77c81654b9f36efb7)
1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AsmState.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "gtest/gtest.h"
14 #include <optional>
15 
16 using namespace mlir;
17 using namespace mlir::detail;
18 
19 //===----------------------------------------------------------------------===//
20 // DenseElementsAttr
21 //===----------------------------------------------------------------------===//
22 
23 template <typename EltTy>
24 static void testSplat(Type eltType, const EltTy &splatElt) {
25   RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
26 
27   // Check that the generated splat is the same for 1 element and N elements.
28   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
29   EXPECT_TRUE(splat.isSplat());
30 
31   auto detectedSplat =
32       DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt}));
33   EXPECT_EQ(detectedSplat, splat);
34 
35   for (auto newValue : detectedSplat.template getValues<EltTy>())
36     EXPECT_TRUE(newValue == splatElt);
37 }
38 
39 namespace {
40 TEST(DenseSplatTest, BoolSplat) {
41   MLIRContext context;
42   IntegerType boolTy = IntegerType::get(&context, 1);
43   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
44 
45   // Check that splat is automatically detected for boolean values.
46   /// True.
47   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
48   EXPECT_TRUE(trueSplat.isSplat());
49   /// False.
50   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
51   EXPECT_TRUE(falseSplat.isSplat());
52   EXPECT_NE(falseSplat, trueSplat);
53 
54   /// Detect and handle splat within 8 elements (bool values are bit-packed).
55   /// True.
56   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
57   EXPECT_EQ(detectedSplat, trueSplat);
58   /// False.
59   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
60   EXPECT_EQ(detectedSplat, falseSplat);
61 }
62 TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
63   MLIRContext context;
64   IntegerType boolTy = IntegerType::get(&context, 1);
65   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
66 
67   // Check that splat booleans properly round trip via the raw API.
68   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
69   EXPECT_TRUE(trueSplat.isSplat());
70   DenseElementsAttr trueSplatFromRaw =
71       DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData());
72   EXPECT_TRUE(trueSplatFromRaw.isSplat());
73 
74   EXPECT_EQ(trueSplat, trueSplatFromRaw);
75 }
76 
77 TEST(DenseSplatTest, BoolSplatSmall) {
78   MLIRContext context;
79   Builder builder(&context);
80 
81   // Check that splats that don't fill entire byte are handled properly.
82   auto tensorType = RankedTensorType::get({4}, builder.getI1Type());
83   std::vector<char> data{0b00001111};
84   auto trueSplatFromRaw =
85       DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data);
86   EXPECT_TRUE(trueSplatFromRaw.isSplat());
87   DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true);
88   EXPECT_EQ(trueSplat, trueSplatFromRaw);
89 }
90 
91 TEST(DenseSplatTest, LargeBoolSplat) {
92   constexpr int64_t boolCount = 56;
93 
94   MLIRContext context;
95   IntegerType boolTy = IntegerType::get(&context, 1);
96   RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
97 
98   // Check that splat is automatically detected for boolean values.
99   /// True.
100   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
101   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
102   EXPECT_TRUE(trueSplat.isSplat());
103   EXPECT_TRUE(falseSplat.isSplat());
104 
105   /// Detect that the large boolean arrays are properly splatted.
106   /// True.
107   SmallVector<bool, 64> trueValues(boolCount, true);
108   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
109   EXPECT_EQ(detectedSplat, trueSplat);
110   /// False.
111   SmallVector<bool, 64> falseValues(boolCount, false);
112   detectedSplat = DenseElementsAttr::get(shape, falseValues);
113   EXPECT_EQ(detectedSplat, falseSplat);
114 }
115 
116 TEST(DenseSplatTest, BoolNonSplat) {
117   MLIRContext context;
118   IntegerType boolTy = IntegerType::get(&context, 1);
119   RankedTensorType shape = RankedTensorType::get({6}, boolTy);
120 
121   // Check that we properly handle non-splat values.
122   DenseElementsAttr nonSplat =
123       DenseElementsAttr::get(shape, {false, false, true, false, false, true});
124   EXPECT_FALSE(nonSplat.isSplat());
125 }
126 
127 TEST(DenseSplatTest, OddIntSplat) {
128   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
129   MLIRContext context;
130   constexpr size_t intWidth = 19;
131   IntegerType intTy = IntegerType::get(&context, intWidth);
132   APInt value(intWidth, 10);
133 
134   testSplat(intTy, value);
135 }
136 
137 TEST(DenseSplatTest, Int32Splat) {
138   MLIRContext context;
139   IntegerType intTy = IntegerType::get(&context, 32);
140   int value = 64;
141 
142   testSplat(intTy, value);
143 }
144 
145 TEST(DenseSplatTest, IntAttrSplat) {
146   MLIRContext context;
147   IntegerType intTy = IntegerType::get(&context, 85);
148   Attribute value = IntegerAttr::get(intTy, 109);
149 
150   testSplat(intTy, value);
151 }
152 
153 TEST(DenseSplatTest, F32Splat) {
154   MLIRContext context;
155   FloatType floatTy = FloatType::getF32(&context);
156   float value = 10.0;
157 
158   testSplat(floatTy, value);
159 }
160 
161 TEST(DenseSplatTest, F64Splat) {
162   MLIRContext context;
163   FloatType floatTy = FloatType::getF64(&context);
164   double value = 10.0;
165 
166   testSplat(floatTy, APFloat(value));
167 }
168 
169 TEST(DenseSplatTest, FloatAttrSplat) {
170   MLIRContext context;
171   FloatType floatTy = FloatType::getF32(&context);
172   Attribute value = FloatAttr::get(floatTy, 10.0);
173 
174   testSplat(floatTy, value);
175 }
176 
177 TEST(DenseSplatTest, BF16Splat) {
178   MLIRContext context;
179   FloatType floatTy = FloatType::getBF16(&context);
180   Attribute value = FloatAttr::get(floatTy, 10.0);
181 
182   testSplat(floatTy, value);
183 }
184 
185 TEST(DenseSplatTest, StringSplat) {
186   MLIRContext context;
187   context.allowUnregisteredDialects();
188   Type stringType =
189       OpaqueType::get(StringAttr::get(&context, "test"), "string");
190   StringRef value = "test-string";
191   testSplat(stringType, value);
192 }
193 
194 TEST(DenseSplatTest, StringAttrSplat) {
195   MLIRContext context;
196   context.allowUnregisteredDialects();
197   Type stringType =
198       OpaqueType::get(StringAttr::get(&context, "test"), "string");
199   Attribute stringAttr = StringAttr::get("test-string", stringType);
200   testSplat(stringType, stringAttr);
201 }
202 
203 TEST(DenseComplexTest, ComplexFloatSplat) {
204   MLIRContext context;
205   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
206   std::complex<float> value(10.0, 15.0);
207   testSplat(complexType, value);
208 }
209 
210 TEST(DenseComplexTest, ComplexIntSplat) {
211   MLIRContext context;
212   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
213   std::complex<int64_t> value(10, 15);
214   testSplat(complexType, value);
215 }
216 
217 TEST(DenseComplexTest, ComplexAPFloatSplat) {
218   MLIRContext context;
219   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
220   std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
221   testSplat(complexType, value);
222 }
223 
224 TEST(DenseComplexTest, ComplexAPIntSplat) {
225   MLIRContext context;
226   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
227   std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
228   testSplat(complexType, value);
229 }
230 
231 TEST(DenseScalarTest, ExtractZeroRankElement) {
232   MLIRContext context;
233   const int elementValue = 12;
234   IntegerType intTy = IntegerType::get(&context, 32);
235   Attribute value = IntegerAttr::get(intTy, elementValue);
236   RankedTensorType shape = RankedTensorType::get({}, intTy);
237 
238   auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}));
239   EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
240 }
241 
242 TEST(DenseSplatMapValuesTest, I32ToTrue) {
243   MLIRContext context;
244   const int elementValue = 12;
245   IntegerType boolTy = IntegerType::get(&context, 1);
246   IntegerType intTy = IntegerType::get(&context, 32);
247   RankedTensorType shape = RankedTensorType::get({4}, intTy);
248 
249   auto attr =
250       DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
251           .mapValues(boolTy, [](const APInt &x) {
252             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
253           });
254   EXPECT_EQ(attr.getNumElements(), 4);
255   EXPECT_TRUE(attr.isSplat());
256   EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
257 }
258 
259 TEST(DenseSplatMapValuesTest, I32ToFalse) {
260   MLIRContext context;
261   const int elementValue = 0;
262   IntegerType boolTy = IntegerType::get(&context, 1);
263   IntegerType intTy = IntegerType::get(&context, 32);
264   RankedTensorType shape = RankedTensorType::get({4}, intTy);
265 
266   auto attr =
267       DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
268           .mapValues(boolTy, [](const APInt &x) {
269             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
270           });
271   EXPECT_EQ(attr.getNumElements(), 4);
272   EXPECT_TRUE(attr.isSplat());
273   EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
274 }
275 } // namespace
276 
277 //===----------------------------------------------------------------------===//
278 // DenseResourceElementsAttr
279 //===----------------------------------------------------------------------===//
280 
281 template <typename AttrT, typename T>
282 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
283                               Type elementType) {
284   auto type = RankedTensorType::get(data.size(), elementType);
285   auto attr = AttrT::get(type, "resource",
286                          UnmanagedAsmResourceBlob::allocateInferAlign(data));
287 
288   // Check that we can access and iterate the data properly.
289   std::optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
290   EXPECT_TRUE(attrData.has_value());
291   EXPECT_EQ(*attrData, data);
292 
293   // Check that we cast to this attribute when possible.
294   Attribute genericAttr = attr;
295   EXPECT_TRUE(isa<AttrT>(genericAttr));
296 }
297 template <typename AttrT, typename T>
298 static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
299   T data[] = {0, 1, 2};
300   checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data),
301                               builder.getIntegerType(intWidth));
302 }
303 
304 namespace {
305 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
306   MLIRContext context;
307   Builder builder(&context);
308 
309   // Bool
310   bool boolData[] = {true, false, true};
311   checkNativeAccess<DenseBoolResourceElementsAttr>(
312       &context, llvm::ArrayRef(boolData), builder.getI1Type());
313 
314   // Unsigned integers
315   checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
316   checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
317   checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
318   checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
319 
320   // Signed integers
321   checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
322   checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
323   checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
324   checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
325 
326   // Float
327   float floatData[] = {0, 1, 2};
328   checkNativeAccess<DenseF32ResourceElementsAttr>(
329       &context, llvm::ArrayRef(floatData), builder.getF32Type());
330 
331   // Double
332   double doubleData[] = {0, 1, 2};
333   checkNativeAccess<DenseF64ResourceElementsAttr>(
334       &context, llvm::ArrayRef(doubleData), builder.getF64Type());
335 }
336 
337 TEST(DenseResourceElementsAttrTest, CheckNoCast) {
338   MLIRContext context;
339   Builder builder(&context);
340 
341   // Create a i32 attribute.
342   ArrayRef<uint32_t> data;
343   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
344   Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
345       type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
346 
347   EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr));
348   EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr));
349   EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr));
350 }
351 
352 TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
353   MLIRContext context;
354   Builder builder(&context);
355 
356   // Create a bool attribute with data of the incorrect type.
357   ArrayRef<uint32_t> data;
358   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
359   EXPECT_DEBUG_DEATH(
360       {
361         DenseBoolResourceElementsAttr::get(
362             type, "resource",
363             UnmanagedAsmResourceBlob::allocateInferAlign(data));
364       },
365       "alignment mismatch between expected alignment and blob alignment");
366 }
367 
368 TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
369   MLIRContext context;
370   Builder builder(&context);
371 
372   // Create a bool attribute with incorrect type.
373   ArrayRef<bool> data;
374   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
375   EXPECT_DEBUG_DEATH(
376       {
377         DenseBoolResourceElementsAttr::get(
378             type, "resource",
379             UnmanagedAsmResourceBlob::allocateInferAlign(data));
380       },
381       "invalid shape element type for provided type `T`");
382 }
383 } // namespace
384 
385 //===----------------------------------------------------------------------===//
386 // SparseElementsAttr
387 //===----------------------------------------------------------------------===//
388 
389 namespace {
390 TEST(SparseElementsAttrTest, GetZero) {
391   MLIRContext context;
392   context.allowUnregisteredDialects();
393 
394   IntegerType intTy = IntegerType::get(&context, 32);
395   FloatType floatTy = FloatType::getF32(&context);
396   Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
397 
398   ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
399   ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
400   ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
401 
402   auto indicesType =
403       RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
404   auto indices =
405       DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
406 
407   RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
408   auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
409 
410   RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
411   auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
412 
413   RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
414   auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
415 
416   auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
417   auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
418   auto sparseString =
419       SparseElementsAttr::get(tensorString, indices, stringValue);
420 
421   // Only index (0, 0) contains an element, others are supposed to return
422   // the zero/empty value.
423   auto zeroIntValue =
424       cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]);
425   EXPECT_EQ(zeroIntValue.getInt(), 0);
426   EXPECT_TRUE(zeroIntValue.getType() == intTy);
427 
428   auto zeroFloatValue =
429       cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]);
430   EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
431   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
432 
433   auto zeroStringValue =
434       cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]);
435   EXPECT_TRUE(zeroStringValue.getValue().empty());
436   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // SubElements
441 //===----------------------------------------------------------------------===//
442 
443 TEST(SubElementTest, Nested) {
444   MLIRContext context;
445   Builder builder(&context);
446 
447   BoolAttr trueAttr = builder.getBoolAttr(true);
448   BoolAttr falseAttr = builder.getBoolAttr(false);
449   ArrayAttr boolArrayAttr =
450       builder.getArrayAttr({trueAttr, falseAttr, trueAttr});
451   StringAttr strAttr = builder.getStringAttr("array");
452   DictionaryAttr dictAttr =
453       builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr));
454 
455   SmallVector<Attribute> subAttrs;
456   dictAttr.walk([&](Attribute attr) { subAttrs.push_back(attr); });
457   // Note that trueAttr appears only once, identical subattributes are skipped.
458   EXPECT_EQ(llvm::ArrayRef(subAttrs),
459             ArrayRef<Attribute>(
460                 {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
461 }
462 } // namespace
463