xref: /llvm-project/mlir/unittests/IR/AttributeTest.cpp (revision f023da12d12635f5fba436e825cbfc999e28e623)
1d8cd96bcSRiver Riddle //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2d8cd96bcSRiver Riddle //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d8cd96bcSRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
8d8cd96bcSRiver Riddle 
9995ab929SRiver Riddle #include "mlir/IR/AsmState.h"
10995ab929SRiver Riddle #include "mlir/IR/Builders.h"
11c7cae0e4SRiver Riddle #include "mlir/IR/BuiltinAttributes.h"
1209f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
13d8cd96bcSRiver Riddle #include "gtest/gtest.h"
14a1fe1f5fSKazu Hirata #include <optional>
15d8cd96bcSRiver Riddle 
165fc28e7aSMehdi Amini #include "../../test/lib/Dialect/Test/TestDialect.h"
175fc28e7aSMehdi Amini 
18d8cd96bcSRiver Riddle using namespace mlir;
19d8cd96bcSRiver Riddle using namespace mlir::detail;
20d8cd96bcSRiver Riddle 
21995ab929SRiver Riddle //===----------------------------------------------------------------------===//
22995ab929SRiver Riddle // DenseElementsAttr
23995ab929SRiver Riddle //===----------------------------------------------------------------------===//
24995ab929SRiver Riddle 
25d8cd96bcSRiver Riddle template <typename EltTy>
26d8cd96bcSRiver Riddle static void testSplat(Type eltType, const EltTy &splatElt) {
27910fff1cSRiver Riddle   RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
28d8cd96bcSRiver Riddle 
29d8cd96bcSRiver Riddle   // Check that the generated splat is the same for 1 element and N elements.
30d8cd96bcSRiver Riddle   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
31d8cd96bcSRiver Riddle   EXPECT_TRUE(splat.isSplat());
32d8cd96bcSRiver Riddle 
33d8cd96bcSRiver Riddle   auto detectedSplat =
34984b800aSserge-sans-paille       DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt}));
35d8cd96bcSRiver Riddle   EXPECT_EQ(detectedSplat, splat);
36da2a6f4eSRiver Riddle 
37da2a6f4eSRiver Riddle   for (auto newValue : detectedSplat.template getValues<EltTy>())
3824ad3858SRiver Riddle     EXPECT_TRUE(newValue == splatElt);
39d8cd96bcSRiver Riddle }
40d8cd96bcSRiver Riddle 
41d8cd96bcSRiver Riddle namespace {
42d8cd96bcSRiver Riddle TEST(DenseSplatTest, BoolSplat) {
43e7021232SMehdi Amini   MLIRContext context;
441b97cdf8SRiver Riddle   IntegerType boolTy = IntegerType::get(&context, 1);
45910fff1cSRiver Riddle   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
46d8cd96bcSRiver Riddle 
47d8cd96bcSRiver Riddle   // Check that splat is automatically detected for boolean values.
48d8cd96bcSRiver Riddle   /// True.
495624bc28SRiver Riddle   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
50d8cd96bcSRiver Riddle   EXPECT_TRUE(trueSplat.isSplat());
51d8cd96bcSRiver Riddle   /// False.
525624bc28SRiver Riddle   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
53d8cd96bcSRiver Riddle   EXPECT_TRUE(falseSplat.isSplat());
54d8cd96bcSRiver Riddle   EXPECT_NE(falseSplat, trueSplat);
55d8cd96bcSRiver Riddle 
56d8cd96bcSRiver Riddle   /// Detect and handle splat within 8 elements (bool values are bit-packed).
57d8cd96bcSRiver Riddle   /// True.
585624bc28SRiver Riddle   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
59d8cd96bcSRiver Riddle   EXPECT_EQ(detectedSplat, trueSplat);
60d8cd96bcSRiver Riddle   /// False.
615624bc28SRiver Riddle   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
62d8cd96bcSRiver Riddle   EXPECT_EQ(detectedSplat, falseSplat);
63d8cd96bcSRiver Riddle }
649e0900cbSRiver Riddle TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
659e0900cbSRiver Riddle   MLIRContext context;
669e0900cbSRiver Riddle   IntegerType boolTy = IntegerType::get(&context, 1);
679e0900cbSRiver Riddle   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
689e0900cbSRiver Riddle 
699e0900cbSRiver Riddle   // Check that splat booleans properly round trip via the raw API.
709e0900cbSRiver Riddle   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
719e0900cbSRiver Riddle   EXPECT_TRUE(trueSplat.isSplat());
729e0900cbSRiver Riddle   DenseElementsAttr trueSplatFromRaw =
739e0900cbSRiver Riddle       DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData());
749e0900cbSRiver Riddle   EXPECT_TRUE(trueSplatFromRaw.isSplat());
759e0900cbSRiver Riddle 
769e0900cbSRiver Riddle   EXPECT_EQ(trueSplat, trueSplatFromRaw);
779e0900cbSRiver Riddle }
78d8cd96bcSRiver Riddle 
79998003eaSKevin Gleason TEST(DenseSplatTest, BoolSplatSmall) {
80998003eaSKevin Gleason   MLIRContext context;
81998003eaSKevin Gleason   Builder builder(&context);
82998003eaSKevin Gleason 
83998003eaSKevin Gleason   // Check that splats that don't fill entire byte are handled properly.
84998003eaSKevin Gleason   auto tensorType = RankedTensorType::get({4}, builder.getI1Type());
85998003eaSKevin Gleason   std::vector<char> data{0b00001111};
86998003eaSKevin Gleason   auto trueSplatFromRaw =
87998003eaSKevin Gleason       DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data);
88998003eaSKevin Gleason   EXPECT_TRUE(trueSplatFromRaw.isSplat());
89998003eaSKevin Gleason   DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true);
90998003eaSKevin Gleason   EXPECT_EQ(trueSplat, trueSplatFromRaw);
91998003eaSKevin Gleason }
92998003eaSKevin Gleason 
93d8cd96bcSRiver Riddle TEST(DenseSplatTest, LargeBoolSplat) {
942c926912SRiver Riddle   constexpr int64_t boolCount = 56;
95d8cd96bcSRiver Riddle 
96e7021232SMehdi Amini   MLIRContext context;
971b97cdf8SRiver Riddle   IntegerType boolTy = IntegerType::get(&context, 1);
98910fff1cSRiver Riddle   RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
99d8cd96bcSRiver Riddle 
100d8cd96bcSRiver Riddle   // Check that splat is automatically detected for boolean values.
101d8cd96bcSRiver Riddle   /// True.
102d8cd96bcSRiver Riddle   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
103d8cd96bcSRiver Riddle   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
104d8cd96bcSRiver Riddle   EXPECT_TRUE(trueSplat.isSplat());
105d8cd96bcSRiver Riddle   EXPECT_TRUE(falseSplat.isSplat());
106d8cd96bcSRiver Riddle 
107d8cd96bcSRiver Riddle   /// Detect that the large boolean arrays are properly splatted.
108d8cd96bcSRiver Riddle   /// True.
109d8cd96bcSRiver Riddle   SmallVector<bool, 64> trueValues(boolCount, true);
110d8cd96bcSRiver Riddle   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
111d8cd96bcSRiver Riddle   EXPECT_EQ(detectedSplat, trueSplat);
112d8cd96bcSRiver Riddle   /// False.
113d8cd96bcSRiver Riddle   SmallVector<bool, 64> falseValues(boolCount, false);
114d8cd96bcSRiver Riddle   detectedSplat = DenseElementsAttr::get(shape, falseValues);
115d8cd96bcSRiver Riddle   EXPECT_EQ(detectedSplat, falseSplat);
116d8cd96bcSRiver Riddle }
117d8cd96bcSRiver Riddle 
1182b67821bSRiver Riddle TEST(DenseSplatTest, BoolNonSplat) {
119e7021232SMehdi Amini   MLIRContext context;
1201b97cdf8SRiver Riddle   IntegerType boolTy = IntegerType::get(&context, 1);
121910fff1cSRiver Riddle   RankedTensorType shape = RankedTensorType::get({6}, boolTy);
1222b67821bSRiver Riddle 
1232b67821bSRiver Riddle   // Check that we properly handle non-splat values.
1242b67821bSRiver Riddle   DenseElementsAttr nonSplat =
1252b67821bSRiver Riddle       DenseElementsAttr::get(shape, {false, false, true, false, false, true});
1262b67821bSRiver Riddle   EXPECT_FALSE(nonSplat.isSplat());
1272b67821bSRiver Riddle }
1282b67821bSRiver Riddle 
129d8cd96bcSRiver Riddle TEST(DenseSplatTest, OddIntSplat) {
130d8cd96bcSRiver Riddle   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
131e7021232SMehdi Amini   MLIRContext context;
132d8cd96bcSRiver Riddle   constexpr size_t intWidth = 19;
1331b97cdf8SRiver Riddle   IntegerType intTy = IntegerType::get(&context, intWidth);
134d8cd96bcSRiver Riddle   APInt value(intWidth, 10);
135d8cd96bcSRiver Riddle 
136d8cd96bcSRiver Riddle   testSplat(intTy, value);
137d8cd96bcSRiver Riddle }
138d8cd96bcSRiver Riddle 
139d8cd96bcSRiver Riddle TEST(DenseSplatTest, Int32Splat) {
140e7021232SMehdi Amini   MLIRContext context;
1411b97cdf8SRiver Riddle   IntegerType intTy = IntegerType::get(&context, 32);
142d8cd96bcSRiver Riddle   int value = 64;
143d8cd96bcSRiver Riddle 
144d8cd96bcSRiver Riddle   testSplat(intTy, value);
145d8cd96bcSRiver Riddle }
146d8cd96bcSRiver Riddle 
147d8cd96bcSRiver Riddle TEST(DenseSplatTest, IntAttrSplat) {
148e7021232SMehdi Amini   MLIRContext context;
1491b97cdf8SRiver Riddle   IntegerType intTy = IntegerType::get(&context, 85);
150d8cd96bcSRiver Riddle   Attribute value = IntegerAttr::get(intTy, 109);
151d8cd96bcSRiver Riddle 
152d8cd96bcSRiver Riddle   testSplat(intTy, value);
153d8cd96bcSRiver Riddle }
154d8cd96bcSRiver Riddle 
155d8cd96bcSRiver Riddle TEST(DenseSplatTest, F32Splat) {
156e7021232SMehdi Amini   MLIRContext context;
157*f023da12SMatthias Springer   FloatType floatTy = Float32Type::get(&context);
158d8cd96bcSRiver Riddle   float value = 10.0;
159d8cd96bcSRiver Riddle 
160d8cd96bcSRiver Riddle   testSplat(floatTy, value);
161d8cd96bcSRiver Riddle }
162d8cd96bcSRiver Riddle 
163d8cd96bcSRiver Riddle TEST(DenseSplatTest, F64Splat) {
164e7021232SMehdi Amini   MLIRContext context;
165*f023da12SMatthias Springer   FloatType floatTy = Float64Type::get(&context);
166d8cd96bcSRiver Riddle   double value = 10.0;
167d8cd96bcSRiver Riddle 
168d8cd96bcSRiver Riddle   testSplat(floatTy, APFloat(value));
169d8cd96bcSRiver Riddle }
170d8cd96bcSRiver Riddle 
171d8cd96bcSRiver Riddle TEST(DenseSplatTest, FloatAttrSplat) {
172e7021232SMehdi Amini   MLIRContext context;
173*f023da12SMatthias Springer   FloatType floatTy = Float32Type::get(&context);
174d8cd96bcSRiver Riddle   Attribute value = FloatAttr::get(floatTy, 10.0);
175d8cd96bcSRiver Riddle 
176d8cd96bcSRiver Riddle   testSplat(floatTy, value);
177d8cd96bcSRiver Riddle }
17868c8b6c4SRiver Riddle 
17968c8b6c4SRiver Riddle TEST(DenseSplatTest, BF16Splat) {
180e7021232SMehdi Amini   MLIRContext context;
181*f023da12SMatthias Springer   FloatType floatTy = BFloat16Type::get(&context);
1827d59f49bSDiego Caballero   Attribute value = FloatAttr::get(floatTy, 10.0);
18368c8b6c4SRiver Riddle 
18468c8b6c4SRiver Riddle   testSplat(floatTy, value);
18568c8b6c4SRiver Riddle }
18668c8b6c4SRiver Riddle 
187910fff1cSRiver Riddle TEST(DenseSplatTest, StringSplat) {
188e7021232SMehdi Amini   MLIRContext context;
189109305e1SRiver Riddle   context.allowUnregisteredDialects();
190910fff1cSRiver Riddle   Type stringType =
191195730a6SRiver Riddle       OpaqueType::get(StringAttr::get(&context, "test"), "string");
192910fff1cSRiver Riddle   StringRef value = "test-string";
193910fff1cSRiver Riddle   testSplat(stringType, value);
194910fff1cSRiver Riddle }
195910fff1cSRiver Riddle 
1960d5caa89SRiver Riddle TEST(DenseSplatTest, StringAttrSplat) {
197e7021232SMehdi Amini   MLIRContext context;
198109305e1SRiver Riddle   context.allowUnregisteredDialects();
1990d5caa89SRiver Riddle   Type stringType =
200195730a6SRiver Riddle       OpaqueType::get(StringAttr::get(&context, "test"), "string");
2010d5caa89SRiver Riddle   Attribute stringAttr = StringAttr::get("test-string", stringType);
2020d5caa89SRiver Riddle   testSplat(stringType, stringAttr);
2030d5caa89SRiver Riddle }
2040d5caa89SRiver Riddle 
205da2a6f4eSRiver Riddle TEST(DenseComplexTest, ComplexFloatSplat) {
206e7021232SMehdi Amini   MLIRContext context;
207*f023da12SMatthias Springer   ComplexType complexType = ComplexType::get(Float32Type::get(&context));
208da2a6f4eSRiver Riddle   std::complex<float> value(10.0, 15.0);
209da2a6f4eSRiver Riddle   testSplat(complexType, value);
210da2a6f4eSRiver Riddle }
211da2a6f4eSRiver Riddle 
212da2a6f4eSRiver Riddle TEST(DenseComplexTest, ComplexIntSplat) {
213e7021232SMehdi Amini   MLIRContext context;
2141b97cdf8SRiver Riddle   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
215da2a6f4eSRiver Riddle   std::complex<int64_t> value(10, 15);
216da2a6f4eSRiver Riddle   testSplat(complexType, value);
217da2a6f4eSRiver Riddle }
218da2a6f4eSRiver Riddle 
21924ad3858SRiver Riddle TEST(DenseComplexTest, ComplexAPFloatSplat) {
220e7021232SMehdi Amini   MLIRContext context;
221*f023da12SMatthias Springer   ComplexType complexType = ComplexType::get(Float32Type::get(&context));
22224ad3858SRiver Riddle   std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
22324ad3858SRiver Riddle   testSplat(complexType, value);
22424ad3858SRiver Riddle }
22524ad3858SRiver Riddle 
22624ad3858SRiver Riddle TEST(DenseComplexTest, ComplexAPIntSplat) {
227e7021232SMehdi Amini   MLIRContext context;
2281b97cdf8SRiver Riddle   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
22924ad3858SRiver Riddle   std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
23024ad3858SRiver Riddle   testSplat(complexType, value);
23124ad3858SRiver Riddle }
23224ad3858SRiver Riddle 
2330af25275Skarimnosseir TEST(DenseScalarTest, ExtractZeroRankElement) {
2340af25275Skarimnosseir   MLIRContext context;
2350af25275Skarimnosseir   const int elementValue = 12;
2360af25275Skarimnosseir   IntegerType intTy = IntegerType::get(&context, 32);
2370af25275Skarimnosseir   Attribute value = IntegerAttr::get(intTy, elementValue);
2380af25275Skarimnosseir   RankedTensorType shape = RankedTensorType::get({}, intTy);
2390af25275Skarimnosseir 
240984b800aSserge-sans-paille   auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}));
241ae40d625SRiver Riddle   EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
2420af25275Skarimnosseir }
243ae60a4a0SChenguang Wang 
244ae60a4a0SChenguang Wang TEST(DenseSplatMapValuesTest, I32ToTrue) {
245ae60a4a0SChenguang Wang   MLIRContext context;
246ae60a4a0SChenguang Wang   const int elementValue = 12;
247ae60a4a0SChenguang Wang   IntegerType boolTy = IntegerType::get(&context, 1);
248ae60a4a0SChenguang Wang   IntegerType intTy = IntegerType::get(&context, 32);
249ae60a4a0SChenguang Wang   RankedTensorType shape = RankedTensorType::get({4}, intTy);
250ae60a4a0SChenguang Wang 
251ae60a4a0SChenguang Wang   auto attr =
252984b800aSserge-sans-paille       DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
253ae60a4a0SChenguang Wang           .mapValues(boolTy, [](const APInt &x) {
254ae60a4a0SChenguang Wang             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
255ae60a4a0SChenguang Wang           });
256ae60a4a0SChenguang Wang   EXPECT_EQ(attr.getNumElements(), 4);
257ae60a4a0SChenguang Wang   EXPECT_TRUE(attr.isSplat());
258ae60a4a0SChenguang Wang   EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
259ae60a4a0SChenguang Wang }
260ae60a4a0SChenguang Wang 
261ae60a4a0SChenguang Wang TEST(DenseSplatMapValuesTest, I32ToFalse) {
262ae60a4a0SChenguang Wang   MLIRContext context;
263ae60a4a0SChenguang Wang   const int elementValue = 0;
264ae60a4a0SChenguang Wang   IntegerType boolTy = IntegerType::get(&context, 1);
265ae60a4a0SChenguang Wang   IntegerType intTy = IntegerType::get(&context, 32);
266ae60a4a0SChenguang Wang   RankedTensorType shape = RankedTensorType::get({4}, intTy);
267ae60a4a0SChenguang Wang 
268ae60a4a0SChenguang Wang   auto attr =
269984b800aSserge-sans-paille       DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
270ae60a4a0SChenguang Wang           .mapValues(boolTy, [](const APInt &x) {
271ae60a4a0SChenguang Wang             return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
272ae60a4a0SChenguang Wang           });
273ae60a4a0SChenguang Wang   EXPECT_EQ(attr.getNumElements(), 4);
274ae60a4a0SChenguang Wang   EXPECT_TRUE(attr.isSplat());
275ae60a4a0SChenguang Wang   EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
276ae60a4a0SChenguang Wang }
277995ab929SRiver Riddle } // namespace
2780af25275Skarimnosseir 
279995ab929SRiver Riddle //===----------------------------------------------------------------------===//
280995ab929SRiver Riddle // DenseResourceElementsAttr
281995ab929SRiver Riddle //===----------------------------------------------------------------------===//
282995ab929SRiver Riddle 
283995ab929SRiver Riddle template <typename AttrT, typename T>
284995ab929SRiver Riddle static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
285995ab929SRiver Riddle                               Type elementType) {
286995ab929SRiver Riddle   auto type = RankedTensorType::get(data.size(), elementType);
2878847d9a2SRainer Orth   auto attr = AttrT::get(type, "resource",
2888847d9a2SRainer Orth                          UnmanagedAsmResourceBlob::allocateInferAlign(data));
289995ab929SRiver Riddle 
290995ab929SRiver Riddle   // Check that we can access and iterate the data properly.
2910a81ace0SKazu Hirata   std::optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
2929750648cSKazu Hirata   EXPECT_TRUE(attrData.has_value());
293995ab929SRiver Riddle   EXPECT_EQ(*attrData, data);
294995ab929SRiver Riddle 
295995ab929SRiver Riddle   // Check that we cast to this attribute when possible.
296995ab929SRiver Riddle   Attribute genericAttr = attr;
2975550c821STres Popp   EXPECT_TRUE(isa<AttrT>(genericAttr));
298995ab929SRiver Riddle }
299995ab929SRiver Riddle template <typename AttrT, typename T>
300995ab929SRiver Riddle static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
301995ab929SRiver Riddle   T data[] = {0, 1, 2};
302984b800aSserge-sans-paille   checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data),
303995ab929SRiver Riddle                               builder.getIntegerType(intWidth));
304995ab929SRiver Riddle }
305995ab929SRiver Riddle 
306995ab929SRiver Riddle namespace {
307995ab929SRiver Riddle TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
308995ab929SRiver Riddle   MLIRContext context;
309995ab929SRiver Riddle   Builder builder(&context);
310995ab929SRiver Riddle 
311995ab929SRiver Riddle   // Bool
312995ab929SRiver Riddle   bool boolData[] = {true, false, true};
313995ab929SRiver Riddle   checkNativeAccess<DenseBoolResourceElementsAttr>(
314984b800aSserge-sans-paille       &context, llvm::ArrayRef(boolData), builder.getI1Type());
315995ab929SRiver Riddle 
316995ab929SRiver Riddle   // Unsigned integers
317995ab929SRiver Riddle   checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
318995ab929SRiver Riddle   checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
319995ab929SRiver Riddle   checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
320995ab929SRiver Riddle   checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
321995ab929SRiver Riddle 
322995ab929SRiver Riddle   // Signed integers
323995ab929SRiver Riddle   checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
324995ab929SRiver Riddle   checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
325995ab929SRiver Riddle   checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
326995ab929SRiver Riddle   checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
327995ab929SRiver Riddle 
328995ab929SRiver Riddle   // Float
329995ab929SRiver Riddle   float floatData[] = {0, 1, 2};
330995ab929SRiver Riddle   checkNativeAccess<DenseF32ResourceElementsAttr>(
331984b800aSserge-sans-paille       &context, llvm::ArrayRef(floatData), builder.getF32Type());
332995ab929SRiver Riddle 
333995ab929SRiver Riddle   // Double
334995ab929SRiver Riddle   double doubleData[] = {0, 1, 2};
335995ab929SRiver Riddle   checkNativeAccess<DenseF64ResourceElementsAttr>(
336984b800aSserge-sans-paille       &context, llvm::ArrayRef(doubleData), builder.getF64Type());
337995ab929SRiver Riddle }
338995ab929SRiver Riddle 
339995ab929SRiver Riddle TEST(DenseResourceElementsAttrTest, CheckNoCast) {
340995ab929SRiver Riddle   MLIRContext context;
341995ab929SRiver Riddle   Builder builder(&context);
342995ab929SRiver Riddle 
343995ab929SRiver Riddle   // Create a i32 attribute.
344995ab929SRiver Riddle   ArrayRef<uint32_t> data;
345995ab929SRiver Riddle   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
346995ab929SRiver Riddle   Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
3478847d9a2SRainer Orth       type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
348995ab929SRiver Riddle 
3495550c821STres Popp   EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr));
3505550c821STres Popp   EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr));
3515550c821STres Popp   EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr));
352995ab929SRiver Riddle }
353995ab929SRiver Riddle 
354b96ebee1SJacques Pienaar TEST(DenseResourceElementsAttrTest, CheckNotMutableAllocateAndCopy) {
355b96ebee1SJacques Pienaar   MLIRContext context;
356b96ebee1SJacques Pienaar   Builder builder(&context);
357b96ebee1SJacques Pienaar 
358b96ebee1SJacques Pienaar   // Create a i32 attribute.
359b96ebee1SJacques Pienaar   std::vector<int32_t> data = {10, 20, 30};
360b96ebee1SJacques Pienaar   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
361b96ebee1SJacques Pienaar   Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
362b96ebee1SJacques Pienaar       type, "resource",
363b96ebee1SJacques Pienaar       HeapAsmResourceBlob::allocateAndCopyInferAlign<int32_t>(
364b96ebee1SJacques Pienaar           data, /*is_mutable=*/false));
365b96ebee1SJacques Pienaar 
366b96ebee1SJacques Pienaar   EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr));
367b96ebee1SJacques Pienaar }
368b96ebee1SJacques Pienaar 
369995ab929SRiver Riddle TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
370995ab929SRiver Riddle   MLIRContext context;
371995ab929SRiver Riddle   Builder builder(&context);
372995ab929SRiver Riddle 
373995ab929SRiver Riddle   // Create a bool attribute with data of the incorrect type.
374995ab929SRiver Riddle   ArrayRef<uint32_t> data;
375995ab929SRiver Riddle   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
37609ca1c06SStephan Herhut   EXPECT_DEBUG_DEATH(
377995ab929SRiver Riddle       {
378995ab929SRiver Riddle         DenseBoolResourceElementsAttr::get(
3798847d9a2SRainer Orth             type, "resource",
3808847d9a2SRainer Orth             UnmanagedAsmResourceBlob::allocateInferAlign(data));
381995ab929SRiver Riddle       },
382995ab929SRiver Riddle       "alignment mismatch between expected alignment and blob alignment");
383995ab929SRiver Riddle }
384995ab929SRiver Riddle 
385995ab929SRiver Riddle TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
386995ab929SRiver Riddle   MLIRContext context;
387995ab929SRiver Riddle   Builder builder(&context);
388995ab929SRiver Riddle 
389995ab929SRiver Riddle   // Create a bool attribute with incorrect type.
390995ab929SRiver Riddle   ArrayRef<bool> data;
391995ab929SRiver Riddle   auto type = RankedTensorType::get(data.size(), builder.getI32Type());
39209ca1c06SStephan Herhut   EXPECT_DEBUG_DEATH(
393995ab929SRiver Riddle       {
394995ab929SRiver Riddle         DenseBoolResourceElementsAttr::get(
3958847d9a2SRainer Orth             type, "resource",
3968847d9a2SRainer Orth             UnmanagedAsmResourceBlob::allocateInferAlign(data));
397995ab929SRiver Riddle       },
398995ab929SRiver Riddle       "invalid shape element type for provided type `T`");
399995ab929SRiver Riddle }
400995ab929SRiver Riddle } // namespace
401995ab929SRiver Riddle 
402995ab929SRiver Riddle //===----------------------------------------------------------------------===//
403995ab929SRiver Riddle // SparseElementsAttr
404995ab929SRiver Riddle //===----------------------------------------------------------------------===//
405995ab929SRiver Riddle 
406995ab929SRiver Riddle namespace {
40764ce74a6SChia-hung Duan TEST(SparseElementsAttrTest, GetZero) {
40864ce74a6SChia-hung Duan   MLIRContext context;
40964ce74a6SChia-hung Duan   context.allowUnregisteredDialects();
41064ce74a6SChia-hung Duan 
41164ce74a6SChia-hung Duan   IntegerType intTy = IntegerType::get(&context, 32);
412*f023da12SMatthias Springer   FloatType floatTy = Float32Type::get(&context);
413195730a6SRiver Riddle   Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
41464ce74a6SChia-hung Duan 
41564ce74a6SChia-hung Duan   ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
41664ce74a6SChia-hung Duan   ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
41764ce74a6SChia-hung Duan   ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
41864ce74a6SChia-hung Duan 
41964ce74a6SChia-hung Duan   auto indicesType =
42064ce74a6SChia-hung Duan       RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
42164ce74a6SChia-hung Duan   auto indices =
42264ce74a6SChia-hung Duan       DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
42364ce74a6SChia-hung Duan 
42464ce74a6SChia-hung Duan   RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
42564ce74a6SChia-hung Duan   auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
42664ce74a6SChia-hung Duan 
42764ce74a6SChia-hung Duan   RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
42864ce74a6SChia-hung Duan   auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
42964ce74a6SChia-hung Duan 
43064ce74a6SChia-hung Duan   RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
43164ce74a6SChia-hung Duan   auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
43264ce74a6SChia-hung Duan 
43364ce74a6SChia-hung Duan   auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
43464ce74a6SChia-hung Duan   auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
43564ce74a6SChia-hung Duan   auto sparseString =
43664ce74a6SChia-hung Duan       SparseElementsAttr::get(tensorString, indices, stringValue);
43764ce74a6SChia-hung Duan 
43864ce74a6SChia-hung Duan   // Only index (0, 0) contains an element, others are supposed to return
43964ce74a6SChia-hung Duan   // the zero/empty value.
440e1795322SJeff Niu   auto zeroIntValue =
4415550c821STres Popp       cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]);
442e1795322SJeff Niu   EXPECT_EQ(zeroIntValue.getInt(), 0);
44364ce74a6SChia-hung Duan   EXPECT_TRUE(zeroIntValue.getType() == intTy);
44464ce74a6SChia-hung Duan 
445e1795322SJeff Niu   auto zeroFloatValue =
4465550c821STres Popp       cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]);
447e1795322SJeff Niu   EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
44864ce74a6SChia-hung Duan   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
44964ce74a6SChia-hung Duan 
450e1795322SJeff Niu   auto zeroStringValue =
4515550c821STres Popp       cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]);
452b3a5f539SMogball   EXPECT_TRUE(zeroStringValue.empty());
45364ce74a6SChia-hung Duan   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
45464ce74a6SChia-hung Duan }
45564ce74a6SChia-hung Duan 
45603d136cfSRiver Riddle //===----------------------------------------------------------------------===//
45703d136cfSRiver Riddle // SubElements
45803d136cfSRiver Riddle //===----------------------------------------------------------------------===//
45903d136cfSRiver Riddle 
46003d136cfSRiver Riddle TEST(SubElementTest, Nested) {
46103d136cfSRiver Riddle   MLIRContext context;
46203d136cfSRiver Riddle   Builder builder(&context);
46303d136cfSRiver Riddle 
46403d136cfSRiver Riddle   BoolAttr trueAttr = builder.getBoolAttr(true);
46503d136cfSRiver Riddle   BoolAttr falseAttr = builder.getBoolAttr(false);
466c1fa8179SAdrian Kuegel   ArrayAttr boolArrayAttr =
467c1fa8179SAdrian Kuegel       builder.getArrayAttr({trueAttr, falseAttr, trueAttr});
46803d136cfSRiver Riddle   StringAttr strAttr = builder.getStringAttr("array");
46903d136cfSRiver Riddle   DictionaryAttr dictAttr =
47003d136cfSRiver Riddle       builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr));
47103d136cfSRiver Riddle 
47203d136cfSRiver Riddle   SmallVector<Attribute> subAttrs;
47303d136cfSRiver Riddle   dictAttr.walk([&](Attribute attr) { subAttrs.push_back(attr); });
474c1fa8179SAdrian Kuegel   // Note that trueAttr appears only once, identical subattributes are skipped.
47503d136cfSRiver Riddle   EXPECT_EQ(llvm::ArrayRef(subAttrs),
47603d136cfSRiver Riddle             ArrayRef<Attribute>(
47703d136cfSRiver Riddle                 {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
47803d136cfSRiver Riddle }
4795fc28e7aSMehdi Amini 
4805fc28e7aSMehdi Amini // Test how many times we call copy-ctor when building an attribute.
4815fc28e7aSMehdi Amini TEST(CopyCountAttr, CopyCount) {
4825fc28e7aSMehdi Amini   MLIRContext context;
4835fc28e7aSMehdi Amini   context.loadDialect<test::TestDialect>();
4845fc28e7aSMehdi Amini 
4855fc28e7aSMehdi Amini   test::CopyCount::counter = 0;
4865fc28e7aSMehdi Amini   test::CopyCount copyCount("hello");
4875fc28e7aSMehdi Amini   test::TestCopyCountAttr::get(&context, std::move(copyCount));
4885fc28e7aSMehdi Amini   int counter1 = test::CopyCount::counter;
4895fc28e7aSMehdi Amini   test::CopyCount::counter = 0;
4905fc28e7aSMehdi Amini   test::TestCopyCountAttr::get(&context, std::move(copyCount));
4915fc28e7aSMehdi Amini #ifndef NDEBUG
4925fc28e7aSMehdi Amini   // One verification enabled only in assert-mode requires a copy.
4935fc28e7aSMehdi Amini   EXPECT_EQ(counter1, 1);
4945fc28e7aSMehdi Amini   EXPECT_EQ(test::CopyCount::counter, 1);
4955fc28e7aSMehdi Amini #else
4965fc28e7aSMehdi Amini   EXPECT_EQ(counter1, 0);
4975fc28e7aSMehdi Amini   EXPECT_EQ(test::CopyCount::counter, 0);
4985fc28e7aSMehdi Amini #endif
4995fc28e7aSMehdi Amini }
5005fc28e7aSMehdi Amini 
501f6ff7574SJacques Pienaar // Test stripped printing using test dialect attribute.
502f6ff7574SJacques Pienaar TEST(CopyCountAttr, PrintStripped) {
503f6ff7574SJacques Pienaar   MLIRContext context;
504f6ff7574SJacques Pienaar   context.loadDialect<test::TestDialect>();
505f6ff7574SJacques Pienaar   // Doesn't matter which dialect attribute is used, just chose TestCopyCount
506f6ff7574SJacques Pienaar   // given proximity.
507f6ff7574SJacques Pienaar   test::CopyCount::counter = 0;
508f6ff7574SJacques Pienaar   test::CopyCount copyCount("hello");
509f6ff7574SJacques Pienaar   Attribute res = test::TestCopyCountAttr::get(&context, std::move(copyCount));
510f6ff7574SJacques Pienaar 
511f6ff7574SJacques Pienaar   std::string str;
512f6ff7574SJacques Pienaar   llvm::raw_string_ostream os(str);
513f6ff7574SJacques Pienaar   os << "|" << res << "|";
514f6ff7574SJacques Pienaar   res.printStripped(os << "[");
515f6ff7574SJacques Pienaar   os << "]";
516ffc80de8SJOE1994   EXPECT_EQ(str, "|#test.copy_count<hello>|[copy_count<hello>]");
517f6ff7574SJacques Pienaar }
518f6ff7574SJacques Pienaar 
519be0a7e9fSMehdi Amini } // namespace
520