xref: /llvm-project/mlir/unittests/IR/AttributeTest.cpp (revision 984b800a036fc61ccb129a8da7592af9cadc94dd)
1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AsmState.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "gtest/gtest.h"
14 
15 using namespace mlir;
16 using namespace mlir::detail;
17 
18 //===----------------------------------------------------------------------===//
19 // DenseElementsAttr
20 //===----------------------------------------------------------------------===//
21 
22 template <typename EltTy>
23 static void testSplat(Type eltType, const EltTy &splatElt) {
24   RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
25 
26   // Check that the generated splat is the same for 1 element and N elements.
27   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
28   EXPECT_TRUE(splat.isSplat());
29 
30   auto detectedSplat =
31       DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt}));
32   EXPECT_EQ(detectedSplat, splat);
33 
34   for (auto newValue : detectedSplat.template getValues<EltTy>())
35     EXPECT_TRUE(newValue == splatElt);
36 }
37 
38 namespace {
39 TEST(DenseSplatTest, BoolSplat) {
40   MLIRContext context;
41   IntegerType boolTy = IntegerType::get(&context, 1);
42   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
43 
44   // Check that splat is automatically detected for boolean values.
45   /// True.
46   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
47   EXPECT_TRUE(trueSplat.isSplat());
48   /// False.
49   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
50   EXPECT_TRUE(falseSplat.isSplat());
51   EXPECT_NE(falseSplat, trueSplat);
52 
53   /// Detect and handle splat within 8 elements (bool values are bit-packed).
54   /// True.
55   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
56   EXPECT_EQ(detectedSplat, trueSplat);
57   /// False.
58   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
59   EXPECT_EQ(detectedSplat, falseSplat);
60 }
61 TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
62   MLIRContext context;
63   IntegerType boolTy = IntegerType::get(&context, 1);
64   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
65 
66   // Check that splat booleans properly round trip via the raw API.
67   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
68   EXPECT_TRUE(trueSplat.isSplat());
69   DenseElementsAttr trueSplatFromRaw =
70       DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData());
71   EXPECT_TRUE(trueSplatFromRaw.isSplat());
72 
73   EXPECT_EQ(trueSplat, trueSplatFromRaw);
74 }
75 
76 TEST(DenseSplatTest, LargeBoolSplat) {
77   constexpr int64_t boolCount = 56;
78 
79   MLIRContext context;
80   IntegerType boolTy = IntegerType::get(&context, 1);
81   RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
82 
83   // Check that splat is automatically detected for boolean values.
84   /// True.
85   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
86   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
87   EXPECT_TRUE(trueSplat.isSplat());
88   EXPECT_TRUE(falseSplat.isSplat());
89 
90   /// Detect that the large boolean arrays are properly splatted.
91   /// True.
92   SmallVector<bool, 64> trueValues(boolCount, true);
93   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
94   EXPECT_EQ(detectedSplat, trueSplat);
95   /// False.
96   SmallVector<bool, 64> falseValues(boolCount, false);
97   detectedSplat = DenseElementsAttr::get(shape, falseValues);
98   EXPECT_EQ(detectedSplat, falseSplat);
99 }
100 
101 TEST(DenseSplatTest, BoolNonSplat) {
102   MLIRContext context;
103   IntegerType boolTy = IntegerType::get(&context, 1);
104   RankedTensorType shape = RankedTensorType::get({6}, boolTy);
105 
106   // Check that we properly handle non-splat values.
107   DenseElementsAttr nonSplat =
108       DenseElementsAttr::get(shape, {false, false, true, false, false, true});
109   EXPECT_FALSE(nonSplat.isSplat());
110 }
111 
112 TEST(DenseSplatTest, OddIntSplat) {
113   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
114   MLIRContext context;
115   constexpr size_t intWidth = 19;
116   IntegerType intTy = IntegerType::get(&context, intWidth);
117   APInt value(intWidth, 10);
118 
119   testSplat(intTy, value);
120 }
121 
122 TEST(DenseSplatTest, Int32Splat) {
123   MLIRContext context;
124   IntegerType intTy = IntegerType::get(&context, 32);
125   int value = 64;
126 
127   testSplat(intTy, value);
128 }
129 
130 TEST(DenseSplatTest, IntAttrSplat) {
131   MLIRContext context;
132   IntegerType intTy = IntegerType::get(&context, 85);
133   Attribute value = IntegerAttr::get(intTy, 109);
134 
135   testSplat(intTy, value);
136 }
137 
138 TEST(DenseSplatTest, F32Splat) {
139   MLIRContext context;
140   FloatType floatTy = FloatType::getF32(&context);
141   float value = 10.0;
142 
143   testSplat(floatTy, value);
144 }
145 
146 TEST(DenseSplatTest, F64Splat) {
147   MLIRContext context;
148   FloatType floatTy = FloatType::getF64(&context);
149   double value = 10.0;
150 
151   testSplat(floatTy, APFloat(value));
152 }
153 
154 TEST(DenseSplatTest, FloatAttrSplat) {
155   MLIRContext context;
156   FloatType floatTy = FloatType::getF32(&context);
157   Attribute value = FloatAttr::get(floatTy, 10.0);
158 
159   testSplat(floatTy, value);
160 }
161 
162 TEST(DenseSplatTest, BF16Splat) {
163   MLIRContext context;
164   FloatType floatTy = FloatType::getBF16(&context);
165   Attribute value = FloatAttr::get(floatTy, 10.0);
166 
167   testSplat(floatTy, value);
168 }
169 
170 TEST(DenseSplatTest, StringSplat) {
171   MLIRContext context;
172   context.allowUnregisteredDialects();
173   Type stringType =
174       OpaqueType::get(StringAttr::get(&context, "test"), "string");
175   StringRef value = "test-string";
176   testSplat(stringType, value);
177 }
178 
179 TEST(DenseSplatTest, StringAttrSplat) {
180   MLIRContext context;
181   context.allowUnregisteredDialects();
182   Type stringType =
183       OpaqueType::get(StringAttr::get(&context, "test"), "string");
184   Attribute stringAttr = StringAttr::get("test-string", stringType);
185   testSplat(stringType, stringAttr);
186 }
187 
188 TEST(DenseComplexTest, ComplexFloatSplat) {
189   MLIRContext context;
190   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
191   std::complex<float> value(10.0, 15.0);
192   testSplat(complexType, value);
193 }
194 
195 TEST(DenseComplexTest, ComplexIntSplat) {
196   MLIRContext context;
197   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
198   std::complex<int64_t> value(10, 15);
199   testSplat(complexType, value);
200 }
201 
202 TEST(DenseComplexTest, ComplexAPFloatSplat) {
203   MLIRContext context;
204   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
205   std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
206   testSplat(complexType, value);
207 }
208 
209 TEST(DenseComplexTest, ComplexAPIntSplat) {
210   MLIRContext context;
211   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
212   std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
213   testSplat(complexType, value);
214 }
215 
216 TEST(DenseScalarTest, ExtractZeroRankElement) {
217   MLIRContext context;
218   const int elementValue = 12;
219   IntegerType intTy = IntegerType::get(&context, 32);
220   Attribute value = IntegerAttr::get(intTy, elementValue);
221   RankedTensorType shape = RankedTensorType::get({}, intTy);
222 
223   auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}));
224   EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
225 }
226 
227 TEST(DenseSplatMapValuesTest, I32ToTrue) {
228   MLIRContext context;
229   const int elementValue = 12;
230   IntegerType boolTy = IntegerType::get(&context, 1);
231   IntegerType intTy = IntegerType::get(&context, 32);
232   RankedTensorType shape = RankedTensorType::get({4}, intTy);
233 
234   auto attr =
235       DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
236           .mapValues(boolTy, [](const APInt &x) {
237             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
238           });
239   EXPECT_EQ(attr.getNumElements(), 4);
240   EXPECT_TRUE(attr.isSplat());
241   EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
242 }
243 
244 TEST(DenseSplatMapValuesTest, I32ToFalse) {
245   MLIRContext context;
246   const int elementValue = 0;
247   IntegerType boolTy = IntegerType::get(&context, 1);
248   IntegerType intTy = IntegerType::get(&context, 32);
249   RankedTensorType shape = RankedTensorType::get({4}, intTy);
250 
251   auto attr =
252       DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
253           .mapValues(boolTy, [](const APInt &x) {
254             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
255           });
256   EXPECT_EQ(attr.getNumElements(), 4);
257   EXPECT_TRUE(attr.isSplat());
258   EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
259 }
260 } // namespace
261 
262 //===----------------------------------------------------------------------===//
263 // DenseResourceElementsAttr
264 //===----------------------------------------------------------------------===//
265 
266 template <typename AttrT, typename T>
267 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
268                               Type elementType) {
269   auto type = RankedTensorType::get(data.size(), elementType);
270   auto attr = AttrT::get(type, "resource",
271                          UnmanagedAsmResourceBlob::allocateInferAlign(data));
272 
273   // Check that we can access and iterate the data properly.
274   Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
275   EXPECT_TRUE(attrData.has_value());
276   EXPECT_EQ(*attrData, data);
277 
278   // Check that we cast to this attribute when possible.
279   Attribute genericAttr = attr;
280   EXPECT_TRUE(genericAttr.template isa<AttrT>());
281 }
282 template <typename AttrT, typename T>
283 static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
284   T data[] = {0, 1, 2};
285   checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data),
286                               builder.getIntegerType(intWidth));
287 }
288 
289 namespace {
290 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
291   MLIRContext context;
292   Builder builder(&context);
293 
294   // Bool
295   bool boolData[] = {true, false, true};
296   checkNativeAccess<DenseBoolResourceElementsAttr>(
297       &context, llvm::ArrayRef(boolData), builder.getI1Type());
298 
299   // Unsigned integers
300   checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
301   checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
302   checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
303   checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
304 
305   // Signed integers
306   checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
307   checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
308   checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
309   checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
310 
311   // Float
312   float floatData[] = {0, 1, 2};
313   checkNativeAccess<DenseF32ResourceElementsAttr>(
314       &context, llvm::ArrayRef(floatData), builder.getF32Type());
315 
316   // Double
317   double doubleData[] = {0, 1, 2};
318   checkNativeAccess<DenseF64ResourceElementsAttr>(
319       &context, llvm::ArrayRef(doubleData), builder.getF64Type());
320 }
321 
322 TEST(DenseResourceElementsAttrTest, CheckNoCast) {
323   MLIRContext context;
324   Builder builder(&context);
325 
326   // Create a i32 attribute.
327   ArrayRef<uint32_t> data;
328   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
329   Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
330       type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
331 
332   EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
333   EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
334   EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
335 }
336 
337 TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
338   MLIRContext context;
339   Builder builder(&context);
340 
341   // Create a bool attribute with data of the incorrect type.
342   ArrayRef<uint32_t> data;
343   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
344   EXPECT_DEBUG_DEATH(
345       {
346         DenseBoolResourceElementsAttr::get(
347             type, "resource",
348             UnmanagedAsmResourceBlob::allocateInferAlign(data));
349       },
350       "alignment mismatch between expected alignment and blob alignment");
351 }
352 
353 TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
354   MLIRContext context;
355   Builder builder(&context);
356 
357   // Create a bool attribute with incorrect type.
358   ArrayRef<bool> data;
359   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
360   EXPECT_DEBUG_DEATH(
361       {
362         DenseBoolResourceElementsAttr::get(
363             type, "resource",
364             UnmanagedAsmResourceBlob::allocateInferAlign(data));
365       },
366       "invalid shape element type for provided type `T`");
367 }
368 } // namespace
369 
370 //===----------------------------------------------------------------------===//
371 // SparseElementsAttr
372 //===----------------------------------------------------------------------===//
373 
374 namespace {
375 TEST(SparseElementsAttrTest, GetZero) {
376   MLIRContext context;
377   context.allowUnregisteredDialects();
378 
379   IntegerType intTy = IntegerType::get(&context, 32);
380   FloatType floatTy = FloatType::getF32(&context);
381   Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
382 
383   ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
384   ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
385   ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
386 
387   auto indicesType =
388       RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
389   auto indices =
390       DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
391 
392   RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
393   auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
394 
395   RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
396   auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
397 
398   RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
399   auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
400 
401   auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
402   auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
403   auto sparseString =
404       SparseElementsAttr::get(tensorString, indices, stringValue);
405 
406   // Only index (0, 0) contains an element, others are supposed to return
407   // the zero/empty value.
408   auto zeroIntValue =
409       sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>();
410   EXPECT_EQ(zeroIntValue.getInt(), 0);
411   EXPECT_TRUE(zeroIntValue.getType() == intTy);
412 
413   auto zeroFloatValue =
414       sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>();
415   EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
416   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
417 
418   auto zeroStringValue =
419       sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>();
420   EXPECT_TRUE(zeroStringValue.getValue().empty());
421   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
422 }
423 
424 } // namespace
425