xref: /llvm-project/mlir/unittests/IR/AttributeTest.cpp (revision be0a7e9f27083ada6072fcc0711ffa5630daa5ec)
1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "gtest/gtest.h"
12 
13 using namespace mlir;
14 using namespace mlir::detail;
15 
16 template <typename EltTy>
17 static void testSplat(Type eltType, const EltTy &splatElt) {
18   RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
19 
20   // Check that the generated splat is the same for 1 element and N elements.
21   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
22   EXPECT_TRUE(splat.isSplat());
23 
24   auto detectedSplat =
25       DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
26   EXPECT_EQ(detectedSplat, splat);
27 
28   for (auto newValue : detectedSplat.template getValues<EltTy>())
29     EXPECT_TRUE(newValue == splatElt);
30 }
31 
32 namespace {
33 TEST(DenseSplatTest, BoolSplat) {
34   MLIRContext context;
35   IntegerType boolTy = IntegerType::get(&context, 1);
36   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
37 
38   // Check that splat is automatically detected for boolean values.
39   /// True.
40   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
41   EXPECT_TRUE(trueSplat.isSplat());
42   /// False.
43   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
44   EXPECT_TRUE(falseSplat.isSplat());
45   EXPECT_NE(falseSplat, trueSplat);
46 
47   /// Detect and handle splat within 8 elements (bool values are bit-packed).
48   /// True.
49   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
50   EXPECT_EQ(detectedSplat, trueSplat);
51   /// False.
52   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
53   EXPECT_EQ(detectedSplat, falseSplat);
54 }
55 
56 TEST(DenseSplatTest, LargeBoolSplat) {
57   constexpr int64_t boolCount = 56;
58 
59   MLIRContext context;
60   IntegerType boolTy = IntegerType::get(&context, 1);
61   RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
62 
63   // Check that splat is automatically detected for boolean values.
64   /// True.
65   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
66   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
67   EXPECT_TRUE(trueSplat.isSplat());
68   EXPECT_TRUE(falseSplat.isSplat());
69 
70   /// Detect that the large boolean arrays are properly splatted.
71   /// True.
72   SmallVector<bool, 64> trueValues(boolCount, true);
73   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
74   EXPECT_EQ(detectedSplat, trueSplat);
75   /// False.
76   SmallVector<bool, 64> falseValues(boolCount, false);
77   detectedSplat = DenseElementsAttr::get(shape, falseValues);
78   EXPECT_EQ(detectedSplat, falseSplat);
79 }
80 
81 TEST(DenseSplatTest, BoolNonSplat) {
82   MLIRContext context;
83   IntegerType boolTy = IntegerType::get(&context, 1);
84   RankedTensorType shape = RankedTensorType::get({6}, boolTy);
85 
86   // Check that we properly handle non-splat values.
87   DenseElementsAttr nonSplat =
88       DenseElementsAttr::get(shape, {false, false, true, false, false, true});
89   EXPECT_FALSE(nonSplat.isSplat());
90 }
91 
92 TEST(DenseSplatTest, OddIntSplat) {
93   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
94   MLIRContext context;
95   constexpr size_t intWidth = 19;
96   IntegerType intTy = IntegerType::get(&context, intWidth);
97   APInt value(intWidth, 10);
98 
99   testSplat(intTy, value);
100 }
101 
102 TEST(DenseSplatTest, Int32Splat) {
103   MLIRContext context;
104   IntegerType intTy = IntegerType::get(&context, 32);
105   int value = 64;
106 
107   testSplat(intTy, value);
108 }
109 
110 TEST(DenseSplatTest, IntAttrSplat) {
111   MLIRContext context;
112   IntegerType intTy = IntegerType::get(&context, 85);
113   Attribute value = IntegerAttr::get(intTy, 109);
114 
115   testSplat(intTy, value);
116 }
117 
118 TEST(DenseSplatTest, F32Splat) {
119   MLIRContext context;
120   FloatType floatTy = FloatType::getF32(&context);
121   float value = 10.0;
122 
123   testSplat(floatTy, value);
124 }
125 
126 TEST(DenseSplatTest, F64Splat) {
127   MLIRContext context;
128   FloatType floatTy = FloatType::getF64(&context);
129   double value = 10.0;
130 
131   testSplat(floatTy, APFloat(value));
132 }
133 
134 TEST(DenseSplatTest, FloatAttrSplat) {
135   MLIRContext context;
136   FloatType floatTy = FloatType::getF32(&context);
137   Attribute value = FloatAttr::get(floatTy, 10.0);
138 
139   testSplat(floatTy, value);
140 }
141 
142 TEST(DenseSplatTest, BF16Splat) {
143   MLIRContext context;
144   FloatType floatTy = FloatType::getBF16(&context);
145   Attribute value = FloatAttr::get(floatTy, 10.0);
146 
147   testSplat(floatTy, value);
148 }
149 
150 TEST(DenseSplatTest, StringSplat) {
151   MLIRContext context;
152   context.allowUnregisteredDialects();
153   Type stringType =
154       OpaqueType::get(StringAttr::get(&context, "test"), "string");
155   StringRef value = "test-string";
156   testSplat(stringType, value);
157 }
158 
159 TEST(DenseSplatTest, StringAttrSplat) {
160   MLIRContext context;
161   context.allowUnregisteredDialects();
162   Type stringType =
163       OpaqueType::get(StringAttr::get(&context, "test"), "string");
164   Attribute stringAttr = StringAttr::get("test-string", stringType);
165   testSplat(stringType, stringAttr);
166 }
167 
168 TEST(DenseComplexTest, ComplexFloatSplat) {
169   MLIRContext context;
170   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
171   std::complex<float> value(10.0, 15.0);
172   testSplat(complexType, value);
173 }
174 
175 TEST(DenseComplexTest, ComplexIntSplat) {
176   MLIRContext context;
177   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
178   std::complex<int64_t> value(10, 15);
179   testSplat(complexType, value);
180 }
181 
182 TEST(DenseComplexTest, ComplexAPFloatSplat) {
183   MLIRContext context;
184   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
185   std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
186   testSplat(complexType, value);
187 }
188 
189 TEST(DenseComplexTest, ComplexAPIntSplat) {
190   MLIRContext context;
191   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
192   std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
193   testSplat(complexType, value);
194 }
195 
196 TEST(DenseScalarTest, ExtractZeroRankElement) {
197   MLIRContext context;
198   const int elementValue = 12;
199   IntegerType intTy = IntegerType::get(&context, 32);
200   Attribute value = IntegerAttr::get(intTy, elementValue);
201   RankedTensorType shape = RankedTensorType::get({}, intTy);
202 
203   auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
204   EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
205 }
206 
207 TEST(SparseElementsAttrTest, GetZero) {
208   MLIRContext context;
209   context.allowUnregisteredDialects();
210 
211   IntegerType intTy = IntegerType::get(&context, 32);
212   FloatType floatTy = FloatType::getF32(&context);
213   Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
214 
215   ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
216   ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
217   ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
218 
219   auto indicesType =
220       RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
221   auto indices =
222       DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
223 
224   RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
225   auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
226 
227   RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
228   auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
229 
230   RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
231   auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
232 
233   auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
234   auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
235   auto sparseString =
236       SparseElementsAttr::get(tensorString, indices, stringValue);
237 
238   // Only index (0, 0) contains an element, others are supposed to return
239   // the zero/empty value.
240   auto zeroIntValue = sparseInt.getValues<Attribute>()[{1, 1}];
241   EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0);
242   EXPECT_TRUE(zeroIntValue.getType() == intTy);
243 
244   auto zeroFloatValue = sparseFloat.getValues<Attribute>()[{1, 1}];
245   EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f);
246   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
247 
248   auto zeroStringValue = sparseString.getValues<Attribute>()[{1, 1}];
249   EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty());
250   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
251 }
252 
253 } // namespace
254