1d8cd96bcSRiver Riddle //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2d8cd96bcSRiver Riddle // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d8cd96bcSRiver Riddle // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 8d8cd96bcSRiver Riddle 9995ab929SRiver Riddle #include "mlir/IR/AsmState.h" 10995ab929SRiver Riddle #include "mlir/IR/Builders.h" 11c7cae0e4SRiver Riddle #include "mlir/IR/BuiltinAttributes.h" 1209f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 13d8cd96bcSRiver Riddle #include "gtest/gtest.h" 14a1fe1f5fSKazu Hirata #include <optional> 15d8cd96bcSRiver Riddle 165fc28e7aSMehdi Amini #include "../../test/lib/Dialect/Test/TestDialect.h" 175fc28e7aSMehdi Amini 18d8cd96bcSRiver Riddle using namespace mlir; 19d8cd96bcSRiver Riddle using namespace mlir::detail; 20d8cd96bcSRiver Riddle 21995ab929SRiver Riddle //===----------------------------------------------------------------------===// 22995ab929SRiver Riddle // DenseElementsAttr 23995ab929SRiver Riddle //===----------------------------------------------------------------------===// 24995ab929SRiver Riddle 25d8cd96bcSRiver Riddle template <typename EltTy> 26d8cd96bcSRiver Riddle static void testSplat(Type eltType, const EltTy &splatElt) { 27910fff1cSRiver Riddle RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); 28d8cd96bcSRiver Riddle 29d8cd96bcSRiver Riddle // Check that the generated splat is the same for 1 element and N elements. 30d8cd96bcSRiver Riddle DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 31d8cd96bcSRiver Riddle EXPECT_TRUE(splat.isSplat()); 32d8cd96bcSRiver Riddle 33d8cd96bcSRiver Riddle auto detectedSplat = 34984b800aSserge-sans-paille DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt})); 35d8cd96bcSRiver Riddle EXPECT_EQ(detectedSplat, splat); 36da2a6f4eSRiver Riddle 37da2a6f4eSRiver Riddle for (auto newValue : detectedSplat.template getValues<EltTy>()) 3824ad3858SRiver Riddle EXPECT_TRUE(newValue == splatElt); 39d8cd96bcSRiver Riddle } 40d8cd96bcSRiver Riddle 41d8cd96bcSRiver Riddle namespace { 42d8cd96bcSRiver Riddle TEST(DenseSplatTest, BoolSplat) { 43e7021232SMehdi Amini MLIRContext context; 441b97cdf8SRiver Riddle IntegerType boolTy = IntegerType::get(&context, 1); 45910fff1cSRiver Riddle RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 46d8cd96bcSRiver Riddle 47d8cd96bcSRiver Riddle // Check that splat is automatically detected for boolean values. 48d8cd96bcSRiver Riddle /// True. 495624bc28SRiver Riddle DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 50d8cd96bcSRiver Riddle EXPECT_TRUE(trueSplat.isSplat()); 51d8cd96bcSRiver Riddle /// False. 525624bc28SRiver Riddle DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 53d8cd96bcSRiver Riddle EXPECT_TRUE(falseSplat.isSplat()); 54d8cd96bcSRiver Riddle EXPECT_NE(falseSplat, trueSplat); 55d8cd96bcSRiver Riddle 56d8cd96bcSRiver Riddle /// Detect and handle splat within 8 elements (bool values are bit-packed). 57d8cd96bcSRiver Riddle /// True. 585624bc28SRiver Riddle auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 59d8cd96bcSRiver Riddle EXPECT_EQ(detectedSplat, trueSplat); 60d8cd96bcSRiver Riddle /// False. 615624bc28SRiver Riddle detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 62d8cd96bcSRiver Riddle EXPECT_EQ(detectedSplat, falseSplat); 63d8cd96bcSRiver Riddle } 649e0900cbSRiver Riddle TEST(DenseSplatTest, BoolSplatRawRoundtrip) { 659e0900cbSRiver Riddle MLIRContext context; 669e0900cbSRiver Riddle IntegerType boolTy = IntegerType::get(&context, 1); 679e0900cbSRiver Riddle RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); 689e0900cbSRiver Riddle 699e0900cbSRiver Riddle // Check that splat booleans properly round trip via the raw API. 709e0900cbSRiver Riddle DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 719e0900cbSRiver Riddle EXPECT_TRUE(trueSplat.isSplat()); 729e0900cbSRiver Riddle DenseElementsAttr trueSplatFromRaw = 739e0900cbSRiver Riddle DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData()); 749e0900cbSRiver Riddle EXPECT_TRUE(trueSplatFromRaw.isSplat()); 759e0900cbSRiver Riddle 769e0900cbSRiver Riddle EXPECT_EQ(trueSplat, trueSplatFromRaw); 779e0900cbSRiver Riddle } 78d8cd96bcSRiver Riddle 79998003eaSKevin Gleason TEST(DenseSplatTest, BoolSplatSmall) { 80998003eaSKevin Gleason MLIRContext context; 81998003eaSKevin Gleason Builder builder(&context); 82998003eaSKevin Gleason 83998003eaSKevin Gleason // Check that splats that don't fill entire byte are handled properly. 84998003eaSKevin Gleason auto tensorType = RankedTensorType::get({4}, builder.getI1Type()); 85998003eaSKevin Gleason std::vector<char> data{0b00001111}; 86998003eaSKevin Gleason auto trueSplatFromRaw = 87998003eaSKevin Gleason DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data); 88998003eaSKevin Gleason EXPECT_TRUE(trueSplatFromRaw.isSplat()); 89998003eaSKevin Gleason DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true); 90998003eaSKevin Gleason EXPECT_EQ(trueSplat, trueSplatFromRaw); 91998003eaSKevin Gleason } 92998003eaSKevin Gleason 93d8cd96bcSRiver Riddle TEST(DenseSplatTest, LargeBoolSplat) { 942c926912SRiver Riddle constexpr int64_t boolCount = 56; 95d8cd96bcSRiver Riddle 96e7021232SMehdi Amini MLIRContext context; 971b97cdf8SRiver Riddle IntegerType boolTy = IntegerType::get(&context, 1); 98910fff1cSRiver Riddle RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); 99d8cd96bcSRiver Riddle 100d8cd96bcSRiver Riddle // Check that splat is automatically detected for boolean values. 101d8cd96bcSRiver Riddle /// True. 102d8cd96bcSRiver Riddle DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 103d8cd96bcSRiver Riddle DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 104d8cd96bcSRiver Riddle EXPECT_TRUE(trueSplat.isSplat()); 105d8cd96bcSRiver Riddle EXPECT_TRUE(falseSplat.isSplat()); 106d8cd96bcSRiver Riddle 107d8cd96bcSRiver Riddle /// Detect that the large boolean arrays are properly splatted. 108d8cd96bcSRiver Riddle /// True. 109d8cd96bcSRiver Riddle SmallVector<bool, 64> trueValues(boolCount, true); 110d8cd96bcSRiver Riddle auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 111d8cd96bcSRiver Riddle EXPECT_EQ(detectedSplat, trueSplat); 112d8cd96bcSRiver Riddle /// False. 113d8cd96bcSRiver Riddle SmallVector<bool, 64> falseValues(boolCount, false); 114d8cd96bcSRiver Riddle detectedSplat = DenseElementsAttr::get(shape, falseValues); 115d8cd96bcSRiver Riddle EXPECT_EQ(detectedSplat, falseSplat); 116d8cd96bcSRiver Riddle } 117d8cd96bcSRiver Riddle 1182b67821bSRiver Riddle TEST(DenseSplatTest, BoolNonSplat) { 119e7021232SMehdi Amini MLIRContext context; 1201b97cdf8SRiver Riddle IntegerType boolTy = IntegerType::get(&context, 1); 121910fff1cSRiver Riddle RankedTensorType shape = RankedTensorType::get({6}, boolTy); 1222b67821bSRiver Riddle 1232b67821bSRiver Riddle // Check that we properly handle non-splat values. 1242b67821bSRiver Riddle DenseElementsAttr nonSplat = 1252b67821bSRiver Riddle DenseElementsAttr::get(shape, {false, false, true, false, false, true}); 1262b67821bSRiver Riddle EXPECT_FALSE(nonSplat.isSplat()); 1272b67821bSRiver Riddle } 1282b67821bSRiver Riddle 129d8cd96bcSRiver Riddle TEST(DenseSplatTest, OddIntSplat) { 130d8cd96bcSRiver Riddle // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 131e7021232SMehdi Amini MLIRContext context; 132d8cd96bcSRiver Riddle constexpr size_t intWidth = 19; 1331b97cdf8SRiver Riddle IntegerType intTy = IntegerType::get(&context, intWidth); 134d8cd96bcSRiver Riddle APInt value(intWidth, 10); 135d8cd96bcSRiver Riddle 136d8cd96bcSRiver Riddle testSplat(intTy, value); 137d8cd96bcSRiver Riddle } 138d8cd96bcSRiver Riddle 139d8cd96bcSRiver Riddle TEST(DenseSplatTest, Int32Splat) { 140e7021232SMehdi Amini MLIRContext context; 1411b97cdf8SRiver Riddle IntegerType intTy = IntegerType::get(&context, 32); 142d8cd96bcSRiver Riddle int value = 64; 143d8cd96bcSRiver Riddle 144d8cd96bcSRiver Riddle testSplat(intTy, value); 145d8cd96bcSRiver Riddle } 146d8cd96bcSRiver Riddle 147d8cd96bcSRiver Riddle TEST(DenseSplatTest, IntAttrSplat) { 148e7021232SMehdi Amini MLIRContext context; 1491b97cdf8SRiver Riddle IntegerType intTy = IntegerType::get(&context, 85); 150d8cd96bcSRiver Riddle Attribute value = IntegerAttr::get(intTy, 109); 151d8cd96bcSRiver Riddle 152d8cd96bcSRiver Riddle testSplat(intTy, value); 153d8cd96bcSRiver Riddle } 154d8cd96bcSRiver Riddle 155d8cd96bcSRiver Riddle TEST(DenseSplatTest, F32Splat) { 156e7021232SMehdi Amini MLIRContext context; 157*f023da12SMatthias Springer FloatType floatTy = Float32Type::get(&context); 158d8cd96bcSRiver Riddle float value = 10.0; 159d8cd96bcSRiver Riddle 160d8cd96bcSRiver Riddle testSplat(floatTy, value); 161d8cd96bcSRiver Riddle } 162d8cd96bcSRiver Riddle 163d8cd96bcSRiver Riddle TEST(DenseSplatTest, F64Splat) { 164e7021232SMehdi Amini MLIRContext context; 165*f023da12SMatthias Springer FloatType floatTy = Float64Type::get(&context); 166d8cd96bcSRiver Riddle double value = 10.0; 167d8cd96bcSRiver Riddle 168d8cd96bcSRiver Riddle testSplat(floatTy, APFloat(value)); 169d8cd96bcSRiver Riddle } 170d8cd96bcSRiver Riddle 171d8cd96bcSRiver Riddle TEST(DenseSplatTest, FloatAttrSplat) { 172e7021232SMehdi Amini MLIRContext context; 173*f023da12SMatthias Springer FloatType floatTy = Float32Type::get(&context); 174d8cd96bcSRiver Riddle Attribute value = FloatAttr::get(floatTy, 10.0); 175d8cd96bcSRiver Riddle 176d8cd96bcSRiver Riddle testSplat(floatTy, value); 177d8cd96bcSRiver Riddle } 17868c8b6c4SRiver Riddle 17968c8b6c4SRiver Riddle TEST(DenseSplatTest, BF16Splat) { 180e7021232SMehdi Amini MLIRContext context; 181*f023da12SMatthias Springer FloatType floatTy = BFloat16Type::get(&context); 1827d59f49bSDiego Caballero Attribute value = FloatAttr::get(floatTy, 10.0); 18368c8b6c4SRiver Riddle 18468c8b6c4SRiver Riddle testSplat(floatTy, value); 18568c8b6c4SRiver Riddle } 18668c8b6c4SRiver Riddle 187910fff1cSRiver Riddle TEST(DenseSplatTest, StringSplat) { 188e7021232SMehdi Amini MLIRContext context; 189109305e1SRiver Riddle context.allowUnregisteredDialects(); 190910fff1cSRiver Riddle Type stringType = 191195730a6SRiver Riddle OpaqueType::get(StringAttr::get(&context, "test"), "string"); 192910fff1cSRiver Riddle StringRef value = "test-string"; 193910fff1cSRiver Riddle testSplat(stringType, value); 194910fff1cSRiver Riddle } 195910fff1cSRiver Riddle 1960d5caa89SRiver Riddle TEST(DenseSplatTest, StringAttrSplat) { 197e7021232SMehdi Amini MLIRContext context; 198109305e1SRiver Riddle context.allowUnregisteredDialects(); 1990d5caa89SRiver Riddle Type stringType = 200195730a6SRiver Riddle OpaqueType::get(StringAttr::get(&context, "test"), "string"); 2010d5caa89SRiver Riddle Attribute stringAttr = StringAttr::get("test-string", stringType); 2020d5caa89SRiver Riddle testSplat(stringType, stringAttr); 2030d5caa89SRiver Riddle } 2040d5caa89SRiver Riddle 205da2a6f4eSRiver Riddle TEST(DenseComplexTest, ComplexFloatSplat) { 206e7021232SMehdi Amini MLIRContext context; 207*f023da12SMatthias Springer ComplexType complexType = ComplexType::get(Float32Type::get(&context)); 208da2a6f4eSRiver Riddle std::complex<float> value(10.0, 15.0); 209da2a6f4eSRiver Riddle testSplat(complexType, value); 210da2a6f4eSRiver Riddle } 211da2a6f4eSRiver Riddle 212da2a6f4eSRiver Riddle TEST(DenseComplexTest, ComplexIntSplat) { 213e7021232SMehdi Amini MLIRContext context; 2141b97cdf8SRiver Riddle ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 215da2a6f4eSRiver Riddle std::complex<int64_t> value(10, 15); 216da2a6f4eSRiver Riddle testSplat(complexType, value); 217da2a6f4eSRiver Riddle } 218da2a6f4eSRiver Riddle 21924ad3858SRiver Riddle TEST(DenseComplexTest, ComplexAPFloatSplat) { 220e7021232SMehdi Amini MLIRContext context; 221*f023da12SMatthias Springer ComplexType complexType = ComplexType::get(Float32Type::get(&context)); 22224ad3858SRiver Riddle std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); 22324ad3858SRiver Riddle testSplat(complexType, value); 22424ad3858SRiver Riddle } 22524ad3858SRiver Riddle 22624ad3858SRiver Riddle TEST(DenseComplexTest, ComplexAPIntSplat) { 227e7021232SMehdi Amini MLIRContext context; 2281b97cdf8SRiver Riddle ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); 22924ad3858SRiver Riddle std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); 23024ad3858SRiver Riddle testSplat(complexType, value); 23124ad3858SRiver Riddle } 23224ad3858SRiver Riddle 2330af25275Skarimnosseir TEST(DenseScalarTest, ExtractZeroRankElement) { 2340af25275Skarimnosseir MLIRContext context; 2350af25275Skarimnosseir const int elementValue = 12; 2360af25275Skarimnosseir IntegerType intTy = IntegerType::get(&context, 32); 2370af25275Skarimnosseir Attribute value = IntegerAttr::get(intTy, elementValue); 2380af25275Skarimnosseir RankedTensorType shape = RankedTensorType::get({}, intTy); 2390af25275Skarimnosseir 240984b800aSserge-sans-paille auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})); 241ae40d625SRiver Riddle EXPECT_TRUE(attr.getValues<Attribute>()[0] == value); 2420af25275Skarimnosseir } 243ae60a4a0SChenguang Wang 244ae60a4a0SChenguang Wang TEST(DenseSplatMapValuesTest, I32ToTrue) { 245ae60a4a0SChenguang Wang MLIRContext context; 246ae60a4a0SChenguang Wang const int elementValue = 12; 247ae60a4a0SChenguang Wang IntegerType boolTy = IntegerType::get(&context, 1); 248ae60a4a0SChenguang Wang IntegerType intTy = IntegerType::get(&context, 32); 249ae60a4a0SChenguang Wang RankedTensorType shape = RankedTensorType::get({4}, intTy); 250ae60a4a0SChenguang Wang 251ae60a4a0SChenguang Wang auto attr = 252984b800aSserge-sans-paille DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) 253ae60a4a0SChenguang Wang .mapValues(boolTy, [](const APInt &x) { 254ae60a4a0SChenguang Wang return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 255ae60a4a0SChenguang Wang }); 256ae60a4a0SChenguang Wang EXPECT_EQ(attr.getNumElements(), 4); 257ae60a4a0SChenguang Wang EXPECT_TRUE(attr.isSplat()); 258ae60a4a0SChenguang Wang EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue()); 259ae60a4a0SChenguang Wang } 260ae60a4a0SChenguang Wang 261ae60a4a0SChenguang Wang TEST(DenseSplatMapValuesTest, I32ToFalse) { 262ae60a4a0SChenguang Wang MLIRContext context; 263ae60a4a0SChenguang Wang const int elementValue = 0; 264ae60a4a0SChenguang Wang IntegerType boolTy = IntegerType::get(&context, 1); 265ae60a4a0SChenguang Wang IntegerType intTy = IntegerType::get(&context, 32); 266ae60a4a0SChenguang Wang RankedTensorType shape = RankedTensorType::get({4}, intTy); 267ae60a4a0SChenguang Wang 268ae60a4a0SChenguang Wang auto attr = 269984b800aSserge-sans-paille DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) 270ae60a4a0SChenguang Wang .mapValues(boolTy, [](const APInt &x) { 271ae60a4a0SChenguang Wang return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); 272ae60a4a0SChenguang Wang }); 273ae60a4a0SChenguang Wang EXPECT_EQ(attr.getNumElements(), 4); 274ae60a4a0SChenguang Wang EXPECT_TRUE(attr.isSplat()); 275ae60a4a0SChenguang Wang EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue()); 276ae60a4a0SChenguang Wang } 277995ab929SRiver Riddle } // namespace 2780af25275Skarimnosseir 279995ab929SRiver Riddle //===----------------------------------------------------------------------===// 280995ab929SRiver Riddle // DenseResourceElementsAttr 281995ab929SRiver Riddle //===----------------------------------------------------------------------===// 282995ab929SRiver Riddle 283995ab929SRiver Riddle template <typename AttrT, typename T> 284995ab929SRiver Riddle static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data, 285995ab929SRiver Riddle Type elementType) { 286995ab929SRiver Riddle auto type = RankedTensorType::get(data.size(), elementType); 2878847d9a2SRainer Orth auto attr = AttrT::get(type, "resource", 2888847d9a2SRainer Orth UnmanagedAsmResourceBlob::allocateInferAlign(data)); 289995ab929SRiver Riddle 290995ab929SRiver Riddle // Check that we can access and iterate the data properly. 2910a81ace0SKazu Hirata std::optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef(); 2929750648cSKazu Hirata EXPECT_TRUE(attrData.has_value()); 293995ab929SRiver Riddle EXPECT_EQ(*attrData, data); 294995ab929SRiver Riddle 295995ab929SRiver Riddle // Check that we cast to this attribute when possible. 296995ab929SRiver Riddle Attribute genericAttr = attr; 2975550c821STres Popp EXPECT_TRUE(isa<AttrT>(genericAttr)); 298995ab929SRiver Riddle } 299995ab929SRiver Riddle template <typename AttrT, typename T> 300995ab929SRiver Riddle static void checkNativeIntAccess(Builder &builder, size_t intWidth) { 301995ab929SRiver Riddle T data[] = {0, 1, 2}; 302984b800aSserge-sans-paille checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data), 303995ab929SRiver Riddle builder.getIntegerType(intWidth)); 304995ab929SRiver Riddle } 305995ab929SRiver Riddle 306995ab929SRiver Riddle namespace { 307995ab929SRiver Riddle TEST(DenseResourceElementsAttrTest, CheckNativeAccess) { 308995ab929SRiver Riddle MLIRContext context; 309995ab929SRiver Riddle Builder builder(&context); 310995ab929SRiver Riddle 311995ab929SRiver Riddle // Bool 312995ab929SRiver Riddle bool boolData[] = {true, false, true}; 313995ab929SRiver Riddle checkNativeAccess<DenseBoolResourceElementsAttr>( 314984b800aSserge-sans-paille &context, llvm::ArrayRef(boolData), builder.getI1Type()); 315995ab929SRiver Riddle 316995ab929SRiver Riddle // Unsigned integers 317995ab929SRiver Riddle checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8); 318995ab929SRiver Riddle checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16); 319995ab929SRiver Riddle checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32); 320995ab929SRiver Riddle checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64); 321995ab929SRiver Riddle 322995ab929SRiver Riddle // Signed integers 323995ab929SRiver Riddle checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8); 324995ab929SRiver Riddle checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16); 325995ab929SRiver Riddle checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32); 326995ab929SRiver Riddle checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64); 327995ab929SRiver Riddle 328995ab929SRiver Riddle // Float 329995ab929SRiver Riddle float floatData[] = {0, 1, 2}; 330995ab929SRiver Riddle checkNativeAccess<DenseF32ResourceElementsAttr>( 331984b800aSserge-sans-paille &context, llvm::ArrayRef(floatData), builder.getF32Type()); 332995ab929SRiver Riddle 333995ab929SRiver Riddle // Double 334995ab929SRiver Riddle double doubleData[] = {0, 1, 2}; 335995ab929SRiver Riddle checkNativeAccess<DenseF64ResourceElementsAttr>( 336984b800aSserge-sans-paille &context, llvm::ArrayRef(doubleData), builder.getF64Type()); 337995ab929SRiver Riddle } 338995ab929SRiver Riddle 339995ab929SRiver Riddle TEST(DenseResourceElementsAttrTest, CheckNoCast) { 340995ab929SRiver Riddle MLIRContext context; 341995ab929SRiver Riddle Builder builder(&context); 342995ab929SRiver Riddle 343995ab929SRiver Riddle // Create a i32 attribute. 344995ab929SRiver Riddle ArrayRef<uint32_t> data; 345995ab929SRiver Riddle auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 346995ab929SRiver Riddle Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( 3478847d9a2SRainer Orth type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); 348995ab929SRiver Riddle 3495550c821STres Popp EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr)); 3505550c821STres Popp EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr)); 3515550c821STres Popp EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr)); 352995ab929SRiver Riddle } 353995ab929SRiver Riddle 354b96ebee1SJacques Pienaar TEST(DenseResourceElementsAttrTest, CheckNotMutableAllocateAndCopy) { 355b96ebee1SJacques Pienaar MLIRContext context; 356b96ebee1SJacques Pienaar Builder builder(&context); 357b96ebee1SJacques Pienaar 358b96ebee1SJacques Pienaar // Create a i32 attribute. 359b96ebee1SJacques Pienaar std::vector<int32_t> data = {10, 20, 30}; 360b96ebee1SJacques Pienaar auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 361b96ebee1SJacques Pienaar Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( 362b96ebee1SJacques Pienaar type, "resource", 363b96ebee1SJacques Pienaar HeapAsmResourceBlob::allocateAndCopyInferAlign<int32_t>( 364b96ebee1SJacques Pienaar data, /*is_mutable=*/false)); 365b96ebee1SJacques Pienaar 366b96ebee1SJacques Pienaar EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr)); 367b96ebee1SJacques Pienaar } 368b96ebee1SJacques Pienaar 369995ab929SRiver Riddle TEST(DenseResourceElementsAttrTest, CheckInvalidData) { 370995ab929SRiver Riddle MLIRContext context; 371995ab929SRiver Riddle Builder builder(&context); 372995ab929SRiver Riddle 373995ab929SRiver Riddle // Create a bool attribute with data of the incorrect type. 374995ab929SRiver Riddle ArrayRef<uint32_t> data; 375995ab929SRiver Riddle auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 37609ca1c06SStephan Herhut EXPECT_DEBUG_DEATH( 377995ab929SRiver Riddle { 378995ab929SRiver Riddle DenseBoolResourceElementsAttr::get( 3798847d9a2SRainer Orth type, "resource", 3808847d9a2SRainer Orth UnmanagedAsmResourceBlob::allocateInferAlign(data)); 381995ab929SRiver Riddle }, 382995ab929SRiver Riddle "alignment mismatch between expected alignment and blob alignment"); 383995ab929SRiver Riddle } 384995ab929SRiver Riddle 385995ab929SRiver Riddle TEST(DenseResourceElementsAttrTest, CheckInvalidType) { 386995ab929SRiver Riddle MLIRContext context; 387995ab929SRiver Riddle Builder builder(&context); 388995ab929SRiver Riddle 389995ab929SRiver Riddle // Create a bool attribute with incorrect type. 390995ab929SRiver Riddle ArrayRef<bool> data; 391995ab929SRiver Riddle auto type = RankedTensorType::get(data.size(), builder.getI32Type()); 39209ca1c06SStephan Herhut EXPECT_DEBUG_DEATH( 393995ab929SRiver Riddle { 394995ab929SRiver Riddle DenseBoolResourceElementsAttr::get( 3958847d9a2SRainer Orth type, "resource", 3968847d9a2SRainer Orth UnmanagedAsmResourceBlob::allocateInferAlign(data)); 397995ab929SRiver Riddle }, 398995ab929SRiver Riddle "invalid shape element type for provided type `T`"); 399995ab929SRiver Riddle } 400995ab929SRiver Riddle } // namespace 401995ab929SRiver Riddle 402995ab929SRiver Riddle //===----------------------------------------------------------------------===// 403995ab929SRiver Riddle // SparseElementsAttr 404995ab929SRiver Riddle //===----------------------------------------------------------------------===// 405995ab929SRiver Riddle 406995ab929SRiver Riddle namespace { 40764ce74a6SChia-hung Duan TEST(SparseElementsAttrTest, GetZero) { 40864ce74a6SChia-hung Duan MLIRContext context; 40964ce74a6SChia-hung Duan context.allowUnregisteredDialects(); 41064ce74a6SChia-hung Duan 41164ce74a6SChia-hung Duan IntegerType intTy = IntegerType::get(&context, 32); 412*f023da12SMatthias Springer FloatType floatTy = Float32Type::get(&context); 413195730a6SRiver Riddle Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); 41464ce74a6SChia-hung Duan 41564ce74a6SChia-hung Duan ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); 41664ce74a6SChia-hung Duan ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); 41764ce74a6SChia-hung Duan ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); 41864ce74a6SChia-hung Duan 41964ce74a6SChia-hung Duan auto indicesType = 42064ce74a6SChia-hung Duan RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); 42164ce74a6SChia-hung Duan auto indices = 42264ce74a6SChia-hung Duan DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); 42364ce74a6SChia-hung Duan 42464ce74a6SChia-hung Duan RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); 42564ce74a6SChia-hung Duan auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); 42664ce74a6SChia-hung Duan 42764ce74a6SChia-hung Duan RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); 42864ce74a6SChia-hung Duan auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); 42964ce74a6SChia-hung Duan 43064ce74a6SChia-hung Duan RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); 43164ce74a6SChia-hung Duan auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); 43264ce74a6SChia-hung Duan 43364ce74a6SChia-hung Duan auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); 43464ce74a6SChia-hung Duan auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); 43564ce74a6SChia-hung Duan auto sparseString = 43664ce74a6SChia-hung Duan SparseElementsAttr::get(tensorString, indices, stringValue); 43764ce74a6SChia-hung Duan 43864ce74a6SChia-hung Duan // Only index (0, 0) contains an element, others are supposed to return 43964ce74a6SChia-hung Duan // the zero/empty value. 440e1795322SJeff Niu auto zeroIntValue = 4415550c821STres Popp cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]); 442e1795322SJeff Niu EXPECT_EQ(zeroIntValue.getInt(), 0); 44364ce74a6SChia-hung Duan EXPECT_TRUE(zeroIntValue.getType() == intTy); 44464ce74a6SChia-hung Duan 445e1795322SJeff Niu auto zeroFloatValue = 4465550c821STres Popp cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]); 447e1795322SJeff Niu EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); 44864ce74a6SChia-hung Duan EXPECT_TRUE(zeroFloatValue.getType() == floatTy); 44964ce74a6SChia-hung Duan 450e1795322SJeff Niu auto zeroStringValue = 4515550c821STres Popp cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]); 452b3a5f539SMogball EXPECT_TRUE(zeroStringValue.empty()); 45364ce74a6SChia-hung Duan EXPECT_TRUE(zeroStringValue.getType() == stringTy); 45464ce74a6SChia-hung Duan } 45564ce74a6SChia-hung Duan 45603d136cfSRiver Riddle //===----------------------------------------------------------------------===// 45703d136cfSRiver Riddle // SubElements 45803d136cfSRiver Riddle //===----------------------------------------------------------------------===// 45903d136cfSRiver Riddle 46003d136cfSRiver Riddle TEST(SubElementTest, Nested) { 46103d136cfSRiver Riddle MLIRContext context; 46203d136cfSRiver Riddle Builder builder(&context); 46303d136cfSRiver Riddle 46403d136cfSRiver Riddle BoolAttr trueAttr = builder.getBoolAttr(true); 46503d136cfSRiver Riddle BoolAttr falseAttr = builder.getBoolAttr(false); 466c1fa8179SAdrian Kuegel ArrayAttr boolArrayAttr = 467c1fa8179SAdrian Kuegel builder.getArrayAttr({trueAttr, falseAttr, trueAttr}); 46803d136cfSRiver Riddle StringAttr strAttr = builder.getStringAttr("array"); 46903d136cfSRiver Riddle DictionaryAttr dictAttr = 47003d136cfSRiver Riddle builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr)); 47103d136cfSRiver Riddle 47203d136cfSRiver Riddle SmallVector<Attribute> subAttrs; 47303d136cfSRiver Riddle dictAttr.walk([&](Attribute attr) { subAttrs.push_back(attr); }); 474c1fa8179SAdrian Kuegel // Note that trueAttr appears only once, identical subattributes are skipped. 47503d136cfSRiver Riddle EXPECT_EQ(llvm::ArrayRef(subAttrs), 47603d136cfSRiver Riddle ArrayRef<Attribute>( 47703d136cfSRiver Riddle {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr})); 47803d136cfSRiver Riddle } 4795fc28e7aSMehdi Amini 4805fc28e7aSMehdi Amini // Test how many times we call copy-ctor when building an attribute. 4815fc28e7aSMehdi Amini TEST(CopyCountAttr, CopyCount) { 4825fc28e7aSMehdi Amini MLIRContext context; 4835fc28e7aSMehdi Amini context.loadDialect<test::TestDialect>(); 4845fc28e7aSMehdi Amini 4855fc28e7aSMehdi Amini test::CopyCount::counter = 0; 4865fc28e7aSMehdi Amini test::CopyCount copyCount("hello"); 4875fc28e7aSMehdi Amini test::TestCopyCountAttr::get(&context, std::move(copyCount)); 4885fc28e7aSMehdi Amini int counter1 = test::CopyCount::counter; 4895fc28e7aSMehdi Amini test::CopyCount::counter = 0; 4905fc28e7aSMehdi Amini test::TestCopyCountAttr::get(&context, std::move(copyCount)); 4915fc28e7aSMehdi Amini #ifndef NDEBUG 4925fc28e7aSMehdi Amini // One verification enabled only in assert-mode requires a copy. 4935fc28e7aSMehdi Amini EXPECT_EQ(counter1, 1); 4945fc28e7aSMehdi Amini EXPECT_EQ(test::CopyCount::counter, 1); 4955fc28e7aSMehdi Amini #else 4965fc28e7aSMehdi Amini EXPECT_EQ(counter1, 0); 4975fc28e7aSMehdi Amini EXPECT_EQ(test::CopyCount::counter, 0); 4985fc28e7aSMehdi Amini #endif 4995fc28e7aSMehdi Amini } 5005fc28e7aSMehdi Amini 501f6ff7574SJacques Pienaar // Test stripped printing using test dialect attribute. 502f6ff7574SJacques Pienaar TEST(CopyCountAttr, PrintStripped) { 503f6ff7574SJacques Pienaar MLIRContext context; 504f6ff7574SJacques Pienaar context.loadDialect<test::TestDialect>(); 505f6ff7574SJacques Pienaar // Doesn't matter which dialect attribute is used, just chose TestCopyCount 506f6ff7574SJacques Pienaar // given proximity. 507f6ff7574SJacques Pienaar test::CopyCount::counter = 0; 508f6ff7574SJacques Pienaar test::CopyCount copyCount("hello"); 509f6ff7574SJacques Pienaar Attribute res = test::TestCopyCountAttr::get(&context, std::move(copyCount)); 510f6ff7574SJacques Pienaar 511f6ff7574SJacques Pienaar std::string str; 512f6ff7574SJacques Pienaar llvm::raw_string_ostream os(str); 513f6ff7574SJacques Pienaar os << "|" << res << "|"; 514f6ff7574SJacques Pienaar res.printStripped(os << "["); 515f6ff7574SJacques Pienaar os << "]"; 516ffc80de8SJOE1994 EXPECT_EQ(str, "|#test.copy_count<hello>|[copy_count<hello>]"); 517f6ff7574SJacques Pienaar } 518f6ff7574SJacques Pienaar 519be0a7e9fSMehdi Amini } // namespace 520