xref: /llvm-project/mlir/unittests/IR/AttributeTest.cpp (revision f023da12d12635f5fba436e825cbfc999e28e623)
1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AsmState.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "gtest/gtest.h"
14 #include <optional>
15 
16 #include "../../test/lib/Dialect/Test/TestDialect.h"
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
21 //===----------------------------------------------------------------------===//
22 // DenseElementsAttr
23 //===----------------------------------------------------------------------===//
24 
25 template <typename EltTy>
26 static void testSplat(Type eltType, const EltTy &splatElt) {
27   RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
28 
29   // Check that the generated splat is the same for 1 element and N elements.
30   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
31   EXPECT_TRUE(splat.isSplat());
32 
33   auto detectedSplat =
34       DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt}));
35   EXPECT_EQ(detectedSplat, splat);
36 
37   for (auto newValue : detectedSplat.template getValues<EltTy>())
38     EXPECT_TRUE(newValue == splatElt);
39 }
40 
41 namespace {
42 TEST(DenseSplatTest, BoolSplat) {
43   MLIRContext context;
44   IntegerType boolTy = IntegerType::get(&context, 1);
45   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
46 
47   // Check that splat is automatically detected for boolean values.
48   /// True.
49   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
50   EXPECT_TRUE(trueSplat.isSplat());
51   /// False.
52   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
53   EXPECT_TRUE(falseSplat.isSplat());
54   EXPECT_NE(falseSplat, trueSplat);
55 
56   /// Detect and handle splat within 8 elements (bool values are bit-packed).
57   /// True.
58   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
59   EXPECT_EQ(detectedSplat, trueSplat);
60   /// False.
61   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
62   EXPECT_EQ(detectedSplat, falseSplat);
63 }
64 TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
65   MLIRContext context;
66   IntegerType boolTy = IntegerType::get(&context, 1);
67   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
68 
69   // Check that splat booleans properly round trip via the raw API.
70   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
71   EXPECT_TRUE(trueSplat.isSplat());
72   DenseElementsAttr trueSplatFromRaw =
73       DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData());
74   EXPECT_TRUE(trueSplatFromRaw.isSplat());
75 
76   EXPECT_EQ(trueSplat, trueSplatFromRaw);
77 }
78 
79 TEST(DenseSplatTest, BoolSplatSmall) {
80   MLIRContext context;
81   Builder builder(&context);
82 
83   // Check that splats that don't fill entire byte are handled properly.
84   auto tensorType = RankedTensorType::get({4}, builder.getI1Type());
85   std::vector<char> data{0b00001111};
86   auto trueSplatFromRaw =
87       DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data);
88   EXPECT_TRUE(trueSplatFromRaw.isSplat());
89   DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true);
90   EXPECT_EQ(trueSplat, trueSplatFromRaw);
91 }
92 
93 TEST(DenseSplatTest, LargeBoolSplat) {
94   constexpr int64_t boolCount = 56;
95 
96   MLIRContext context;
97   IntegerType boolTy = IntegerType::get(&context, 1);
98   RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
99 
100   // Check that splat is automatically detected for boolean values.
101   /// True.
102   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
103   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
104   EXPECT_TRUE(trueSplat.isSplat());
105   EXPECT_TRUE(falseSplat.isSplat());
106 
107   /// Detect that the large boolean arrays are properly splatted.
108   /// True.
109   SmallVector<bool, 64> trueValues(boolCount, true);
110   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
111   EXPECT_EQ(detectedSplat, trueSplat);
112   /// False.
113   SmallVector<bool, 64> falseValues(boolCount, false);
114   detectedSplat = DenseElementsAttr::get(shape, falseValues);
115   EXPECT_EQ(detectedSplat, falseSplat);
116 }
117 
118 TEST(DenseSplatTest, BoolNonSplat) {
119   MLIRContext context;
120   IntegerType boolTy = IntegerType::get(&context, 1);
121   RankedTensorType shape = RankedTensorType::get({6}, boolTy);
122 
123   // Check that we properly handle non-splat values.
124   DenseElementsAttr nonSplat =
125       DenseElementsAttr::get(shape, {false, false, true, false, false, true});
126   EXPECT_FALSE(nonSplat.isSplat());
127 }
128 
129 TEST(DenseSplatTest, OddIntSplat) {
130   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
131   MLIRContext context;
132   constexpr size_t intWidth = 19;
133   IntegerType intTy = IntegerType::get(&context, intWidth);
134   APInt value(intWidth, 10);
135 
136   testSplat(intTy, value);
137 }
138 
139 TEST(DenseSplatTest, Int32Splat) {
140   MLIRContext context;
141   IntegerType intTy = IntegerType::get(&context, 32);
142   int value = 64;
143 
144   testSplat(intTy, value);
145 }
146 
147 TEST(DenseSplatTest, IntAttrSplat) {
148   MLIRContext context;
149   IntegerType intTy = IntegerType::get(&context, 85);
150   Attribute value = IntegerAttr::get(intTy, 109);
151 
152   testSplat(intTy, value);
153 }
154 
155 TEST(DenseSplatTest, F32Splat) {
156   MLIRContext context;
157   FloatType floatTy = Float32Type::get(&context);
158   float value = 10.0;
159 
160   testSplat(floatTy, value);
161 }
162 
163 TEST(DenseSplatTest, F64Splat) {
164   MLIRContext context;
165   FloatType floatTy = Float64Type::get(&context);
166   double value = 10.0;
167 
168   testSplat(floatTy, APFloat(value));
169 }
170 
171 TEST(DenseSplatTest, FloatAttrSplat) {
172   MLIRContext context;
173   FloatType floatTy = Float32Type::get(&context);
174   Attribute value = FloatAttr::get(floatTy, 10.0);
175 
176   testSplat(floatTy, value);
177 }
178 
179 TEST(DenseSplatTest, BF16Splat) {
180   MLIRContext context;
181   FloatType floatTy = BFloat16Type::get(&context);
182   Attribute value = FloatAttr::get(floatTy, 10.0);
183 
184   testSplat(floatTy, value);
185 }
186 
187 TEST(DenseSplatTest, StringSplat) {
188   MLIRContext context;
189   context.allowUnregisteredDialects();
190   Type stringType =
191       OpaqueType::get(StringAttr::get(&context, "test"), "string");
192   StringRef value = "test-string";
193   testSplat(stringType, value);
194 }
195 
196 TEST(DenseSplatTest, StringAttrSplat) {
197   MLIRContext context;
198   context.allowUnregisteredDialects();
199   Type stringType =
200       OpaqueType::get(StringAttr::get(&context, "test"), "string");
201   Attribute stringAttr = StringAttr::get("test-string", stringType);
202   testSplat(stringType, stringAttr);
203 }
204 
205 TEST(DenseComplexTest, ComplexFloatSplat) {
206   MLIRContext context;
207   ComplexType complexType = ComplexType::get(Float32Type::get(&context));
208   std::complex<float> value(10.0, 15.0);
209   testSplat(complexType, value);
210 }
211 
212 TEST(DenseComplexTest, ComplexIntSplat) {
213   MLIRContext context;
214   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
215   std::complex<int64_t> value(10, 15);
216   testSplat(complexType, value);
217 }
218 
219 TEST(DenseComplexTest, ComplexAPFloatSplat) {
220   MLIRContext context;
221   ComplexType complexType = ComplexType::get(Float32Type::get(&context));
222   std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
223   testSplat(complexType, value);
224 }
225 
226 TEST(DenseComplexTest, ComplexAPIntSplat) {
227   MLIRContext context;
228   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
229   std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
230   testSplat(complexType, value);
231 }
232 
233 TEST(DenseScalarTest, ExtractZeroRankElement) {
234   MLIRContext context;
235   const int elementValue = 12;
236   IntegerType intTy = IntegerType::get(&context, 32);
237   Attribute value = IntegerAttr::get(intTy, elementValue);
238   RankedTensorType shape = RankedTensorType::get({}, intTy);
239 
240   auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}));
241   EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
242 }
243 
244 TEST(DenseSplatMapValuesTest, I32ToTrue) {
245   MLIRContext context;
246   const int elementValue = 12;
247   IntegerType boolTy = IntegerType::get(&context, 1);
248   IntegerType intTy = IntegerType::get(&context, 32);
249   RankedTensorType shape = RankedTensorType::get({4}, intTy);
250 
251   auto attr =
252       DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
253           .mapValues(boolTy, [](const APInt &x) {
254             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
255           });
256   EXPECT_EQ(attr.getNumElements(), 4);
257   EXPECT_TRUE(attr.isSplat());
258   EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
259 }
260 
261 TEST(DenseSplatMapValuesTest, I32ToFalse) {
262   MLIRContext context;
263   const int elementValue = 0;
264   IntegerType boolTy = IntegerType::get(&context, 1);
265   IntegerType intTy = IntegerType::get(&context, 32);
266   RankedTensorType shape = RankedTensorType::get({4}, intTy);
267 
268   auto attr =
269       DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
270           .mapValues(boolTy, [](const APInt &x) {
271             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
272           });
273   EXPECT_EQ(attr.getNumElements(), 4);
274   EXPECT_TRUE(attr.isSplat());
275   EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
276 }
277 } // namespace
278 
279 //===----------------------------------------------------------------------===//
280 // DenseResourceElementsAttr
281 //===----------------------------------------------------------------------===//
282 
283 template <typename AttrT, typename T>
284 static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
285                               Type elementType) {
286   auto type = RankedTensorType::get(data.size(), elementType);
287   auto attr = AttrT::get(type, "resource",
288                          UnmanagedAsmResourceBlob::allocateInferAlign(data));
289 
290   // Check that we can access and iterate the data properly.
291   std::optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
292   EXPECT_TRUE(attrData.has_value());
293   EXPECT_EQ(*attrData, data);
294 
295   // Check that we cast to this attribute when possible.
296   Attribute genericAttr = attr;
297   EXPECT_TRUE(isa<AttrT>(genericAttr));
298 }
299 template <typename AttrT, typename T>
300 static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
301   T data[] = {0, 1, 2};
302   checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data),
303                               builder.getIntegerType(intWidth));
304 }
305 
306 namespace {
307 TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
308   MLIRContext context;
309   Builder builder(&context);
310 
311   // Bool
312   bool boolData[] = {true, false, true};
313   checkNativeAccess<DenseBoolResourceElementsAttr>(
314       &context, llvm::ArrayRef(boolData), builder.getI1Type());
315 
316   // Unsigned integers
317   checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
318   checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
319   checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
320   checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
321 
322   // Signed integers
323   checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
324   checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
325   checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
326   checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
327 
328   // Float
329   float floatData[] = {0, 1, 2};
330   checkNativeAccess<DenseF32ResourceElementsAttr>(
331       &context, llvm::ArrayRef(floatData), builder.getF32Type());
332 
333   // Double
334   double doubleData[] = {0, 1, 2};
335   checkNativeAccess<DenseF64ResourceElementsAttr>(
336       &context, llvm::ArrayRef(doubleData), builder.getF64Type());
337 }
338 
339 TEST(DenseResourceElementsAttrTest, CheckNoCast) {
340   MLIRContext context;
341   Builder builder(&context);
342 
343   // Create a i32 attribute.
344   ArrayRef<uint32_t> data;
345   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
346   Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
347       type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
348 
349   EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr));
350   EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr));
351   EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr));
352 }
353 
354 TEST(DenseResourceElementsAttrTest, CheckNotMutableAllocateAndCopy) {
355   MLIRContext context;
356   Builder builder(&context);
357 
358   // Create a i32 attribute.
359   std::vector<int32_t> data = {10, 20, 30};
360   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
361   Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
362       type, "resource",
363       HeapAsmResourceBlob::allocateAndCopyInferAlign<int32_t>(
364           data, /*is_mutable=*/false));
365 
366   EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr));
367 }
368 
369 TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
370   MLIRContext context;
371   Builder builder(&context);
372 
373   // Create a bool attribute with data of the incorrect type.
374   ArrayRef<uint32_t> data;
375   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
376   EXPECT_DEBUG_DEATH(
377       {
378         DenseBoolResourceElementsAttr::get(
379             type, "resource",
380             UnmanagedAsmResourceBlob::allocateInferAlign(data));
381       },
382       "alignment mismatch between expected alignment and blob alignment");
383 }
384 
385 TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
386   MLIRContext context;
387   Builder builder(&context);
388 
389   // Create a bool attribute with incorrect type.
390   ArrayRef<bool> data;
391   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
392   EXPECT_DEBUG_DEATH(
393       {
394         DenseBoolResourceElementsAttr::get(
395             type, "resource",
396             UnmanagedAsmResourceBlob::allocateInferAlign(data));
397       },
398       "invalid shape element type for provided type `T`");
399 }
400 } // namespace
401 
402 //===----------------------------------------------------------------------===//
403 // SparseElementsAttr
404 //===----------------------------------------------------------------------===//
405 
406 namespace {
407 TEST(SparseElementsAttrTest, GetZero) {
408   MLIRContext context;
409   context.allowUnregisteredDialects();
410 
411   IntegerType intTy = IntegerType::get(&context, 32);
412   FloatType floatTy = Float32Type::get(&context);
413   Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
414 
415   ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
416   ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
417   ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
418 
419   auto indicesType =
420       RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
421   auto indices =
422       DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
423 
424   RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
425   auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
426 
427   RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
428   auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
429 
430   RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
431   auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
432 
433   auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
434   auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
435   auto sparseString =
436       SparseElementsAttr::get(tensorString, indices, stringValue);
437 
438   // Only index (0, 0) contains an element, others are supposed to return
439   // the zero/empty value.
440   auto zeroIntValue =
441       cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]);
442   EXPECT_EQ(zeroIntValue.getInt(), 0);
443   EXPECT_TRUE(zeroIntValue.getType() == intTy);
444 
445   auto zeroFloatValue =
446       cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]);
447   EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
448   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
449 
450   auto zeroStringValue =
451       cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]);
452   EXPECT_TRUE(zeroStringValue.empty());
453   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
454 }
455 
456 //===----------------------------------------------------------------------===//
457 // SubElements
458 //===----------------------------------------------------------------------===//
459 
460 TEST(SubElementTest, Nested) {
461   MLIRContext context;
462   Builder builder(&context);
463 
464   BoolAttr trueAttr = builder.getBoolAttr(true);
465   BoolAttr falseAttr = builder.getBoolAttr(false);
466   ArrayAttr boolArrayAttr =
467       builder.getArrayAttr({trueAttr, falseAttr, trueAttr});
468   StringAttr strAttr = builder.getStringAttr("array");
469   DictionaryAttr dictAttr =
470       builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr));
471 
472   SmallVector<Attribute> subAttrs;
473   dictAttr.walk([&](Attribute attr) { subAttrs.push_back(attr); });
474   // Note that trueAttr appears only once, identical subattributes are skipped.
475   EXPECT_EQ(llvm::ArrayRef(subAttrs),
476             ArrayRef<Attribute>(
477                 {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
478 }
479 
480 // Test how many times we call copy-ctor when building an attribute.
481 TEST(CopyCountAttr, CopyCount) {
482   MLIRContext context;
483   context.loadDialect<test::TestDialect>();
484 
485   test::CopyCount::counter = 0;
486   test::CopyCount copyCount("hello");
487   test::TestCopyCountAttr::get(&context, std::move(copyCount));
488   int counter1 = test::CopyCount::counter;
489   test::CopyCount::counter = 0;
490   test::TestCopyCountAttr::get(&context, std::move(copyCount));
491 #ifndef NDEBUG
492   // One verification enabled only in assert-mode requires a copy.
493   EXPECT_EQ(counter1, 1);
494   EXPECT_EQ(test::CopyCount::counter, 1);
495 #else
496   EXPECT_EQ(counter1, 0);
497   EXPECT_EQ(test::CopyCount::counter, 0);
498 #endif
499 }
500 
501 // Test stripped printing using test dialect attribute.
502 TEST(CopyCountAttr, PrintStripped) {
503   MLIRContext context;
504   context.loadDialect<test::TestDialect>();
505   // Doesn't matter which dialect attribute is used, just chose TestCopyCount
506   // given proximity.
507   test::CopyCount::counter = 0;
508   test::CopyCount copyCount("hello");
509   Attribute res = test::TestCopyCountAttr::get(&context, std::move(copyCount));
510 
511   std::string str;
512   llvm::raw_string_ostream os(str);
513   os << "|" << res << "|";
514   res.printStripped(os << "[");
515   os << "]";
516   EXPECT_EQ(str, "|#test.copy_count<hello>|[copy_count<hello>]");
517 }
518 
519 } // namespace
520