//===- AttributeTest.cpp - Attribute unit tests ---------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/AsmState.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "gtest/gtest.h" using namespace mlir; using namespace mlir::detail; //===----------------------------------------------------------------------===// // DenseElementsAttr //===----------------------------------------------------------------------===// template static void testSplat(Type eltType, const EltTy &splatElt) { RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); // Check that the generated splat is the same for 1 element and N elements. DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); EXPECT_TRUE(splat.isSplat()); auto detectedSplat = DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt})); EXPECT_EQ(detectedSplat, splat); for (auto newValue : detectedSplat.template getValues()) EXPECT_TRUE(newValue == splatElt); } namespace { TEST(DenseSplatTest, BoolSplat) { MLIRContext context; IntegerType boolTy = IntegerType::get(&context, 1); RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); // Check that splat is automatically detected for boolean values. /// True. DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); EXPECT_TRUE(trueSplat.isSplat()); /// False. DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); EXPECT_TRUE(falseSplat.isSplat()); EXPECT_NE(falseSplat, trueSplat); /// Detect and handle splat within 8 elements (bool values are bit-packed). /// True. auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); EXPECT_EQ(detectedSplat, trueSplat); /// False. detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); EXPECT_EQ(detectedSplat, falseSplat); } TEST(DenseSplatTest, LargeBoolSplat) { constexpr int64_t boolCount = 56; MLIRContext context; IntegerType boolTy = IntegerType::get(&context, 1); RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); // Check that splat is automatically detected for boolean values. /// True. DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); EXPECT_TRUE(trueSplat.isSplat()); EXPECT_TRUE(falseSplat.isSplat()); /// Detect that the large boolean arrays are properly splatted. /// True. SmallVector trueValues(boolCount, true); auto detectedSplat = DenseElementsAttr::get(shape, trueValues); EXPECT_EQ(detectedSplat, trueSplat); /// False. SmallVector falseValues(boolCount, false); detectedSplat = DenseElementsAttr::get(shape, falseValues); EXPECT_EQ(detectedSplat, falseSplat); } TEST(DenseSplatTest, BoolNonSplat) { MLIRContext context; IntegerType boolTy = IntegerType::get(&context, 1); RankedTensorType shape = RankedTensorType::get({6}, boolTy); // Check that we properly handle non-splat values. DenseElementsAttr nonSplat = DenseElementsAttr::get(shape, {false, false, true, false, false, true}); EXPECT_FALSE(nonSplat.isSplat()); } TEST(DenseSplatTest, OddIntSplat) { // Test detecting a splat with an odd(non 8-bit) integer bitwidth. MLIRContext context; constexpr size_t intWidth = 19; IntegerType intTy = IntegerType::get(&context, intWidth); APInt value(intWidth, 10); testSplat(intTy, value); } TEST(DenseSplatTest, Int32Splat) { MLIRContext context; IntegerType intTy = IntegerType::get(&context, 32); int value = 64; testSplat(intTy, value); } TEST(DenseSplatTest, IntAttrSplat) { MLIRContext context; IntegerType intTy = IntegerType::get(&context, 85); Attribute value = IntegerAttr::get(intTy, 109); testSplat(intTy, value); } TEST(DenseSplatTest, F32Splat) { MLIRContext context; FloatType floatTy = FloatType::getF32(&context); float value = 10.0; testSplat(floatTy, value); } TEST(DenseSplatTest, F64Splat) { MLIRContext context; FloatType floatTy = FloatType::getF64(&context); double value = 10.0; testSplat(floatTy, APFloat(value)); } TEST(DenseSplatTest, FloatAttrSplat) { MLIRContext context; FloatType floatTy = FloatType::getF32(&context); Attribute value = FloatAttr::get(floatTy, 10.0); testSplat(floatTy, value); } TEST(DenseSplatTest, BF16Splat) { MLIRContext context; FloatType floatTy = FloatType::getBF16(&context); Attribute value = FloatAttr::get(floatTy, 10.0); testSplat(floatTy, value); } TEST(DenseSplatTest, StringSplat) { MLIRContext context; context.allowUnregisteredDialects(); Type stringType = OpaqueType::get(StringAttr::get(&context, "test"), "string"); StringRef value = "test-string"; testSplat(stringType, value); } TEST(DenseSplatTest, StringAttrSplat) { MLIRContext context; context.allowUnregisteredDialects(); Type stringType = OpaqueType::get(StringAttr::get(&context, "test"), "string"); Attribute stringAttr = StringAttr::get("test-string", stringType); testSplat(stringType, stringAttr); } TEST(DenseComplexTest, ComplexFloatSplat) { MLIRContext context; ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); std::complex value(10.0, 15.0); testSplat(complexType, value); } TEST(DenseComplexTest, ComplexIntSplat) { MLIRContext context; ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); std::complex value(10, 15); testSplat(complexType, value); } TEST(DenseComplexTest, ComplexAPFloatSplat) { MLIRContext context; ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); std::complex value(APFloat(10.0f), APFloat(15.0f)); testSplat(complexType, value); } TEST(DenseComplexTest, ComplexAPIntSplat) { MLIRContext context; ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); std::complex value(APInt(64, 10), APInt(64, 15)); testSplat(complexType, value); } TEST(DenseScalarTest, ExtractZeroRankElement) { MLIRContext context; const int elementValue = 12; IntegerType intTy = IntegerType::get(&context, 32); Attribute value = IntegerAttr::get(intTy, elementValue); RankedTensorType shape = RankedTensorType::get({}, intTy); auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); EXPECT_TRUE(attr.getValues()[0] == value); } } // namespace //===----------------------------------------------------------------------===// // DenseResourceElementsAttr //===----------------------------------------------------------------------===// template static void checkNativeAccess(MLIRContext *ctx, ArrayRef data, Type elementType) { auto type = RankedTensorType::get(data.size(), elementType); auto attr = AttrT::get(type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); // Check that we can access and iterate the data properly. Optional> attrData = attr.tryGetAsArrayRef(); EXPECT_TRUE(attrData.has_value()); EXPECT_EQ(*attrData, data); // Check that we cast to this attribute when possible. Attribute genericAttr = attr; EXPECT_TRUE(genericAttr.template isa()); } template static void checkNativeIntAccess(Builder &builder, size_t intWidth) { T data[] = {0, 1, 2}; checkNativeAccess(builder.getContext(), llvm::makeArrayRef(data), builder.getIntegerType(intWidth)); } namespace { TEST(DenseResourceElementsAttrTest, CheckNativeAccess) { MLIRContext context; Builder builder(&context); // Bool bool boolData[] = {true, false, true}; checkNativeAccess( &context, llvm::makeArrayRef(boolData), builder.getI1Type()); // Unsigned integers checkNativeIntAccess(builder, 8); checkNativeIntAccess(builder, 16); checkNativeIntAccess(builder, 32); checkNativeIntAccess(builder, 64); // Signed integers checkNativeIntAccess(builder, 8); checkNativeIntAccess(builder, 16); checkNativeIntAccess(builder, 32); checkNativeIntAccess(builder, 64); // Float float floatData[] = {0, 1, 2}; checkNativeAccess( &context, llvm::makeArrayRef(floatData), builder.getF32Type()); // Double double doubleData[] = {0, 1, 2}; checkNativeAccess( &context, llvm::makeArrayRef(doubleData), builder.getF64Type()); } TEST(DenseResourceElementsAttrTest, CheckNoCast) { MLIRContext context; Builder builder(&context); // Create a i32 attribute. ArrayRef data; auto type = RankedTensorType::get(data.size(), builder.getI32Type()); Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); EXPECT_TRUE(i32ResourceAttr.isa()); EXPECT_FALSE(i32ResourceAttr.isa()); EXPECT_FALSE(i32ResourceAttr.isa()); } TEST(DenseResourceElementsAttrTest, CheckInvalidData) { MLIRContext context; Builder builder(&context); // Create a bool attribute with data of the incorrect type. ArrayRef data; auto type = RankedTensorType::get(data.size(), builder.getI32Type()); EXPECT_DEBUG_DEATH( { DenseBoolResourceElementsAttr::get( type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); }, "alignment mismatch between expected alignment and blob alignment"); } TEST(DenseResourceElementsAttrTest, CheckInvalidType) { MLIRContext context; Builder builder(&context); // Create a bool attribute with incorrect type. ArrayRef data; auto type = RankedTensorType::get(data.size(), builder.getI32Type()); EXPECT_DEBUG_DEATH( { DenseBoolResourceElementsAttr::get( type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data)); }, "invalid shape element type for provided type `T`"); } } // namespace //===----------------------------------------------------------------------===// // SparseElementsAttr //===----------------------------------------------------------------------===// namespace { TEST(SparseElementsAttrTest, GetZero) { MLIRContext context; context.allowUnregisteredDialects(); IntegerType intTy = IntegerType::get(&context, 32); FloatType floatTy = FloatType::getF32(&context); Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); auto indices = DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); auto sparseString = SparseElementsAttr::get(tensorString, indices, stringValue); // Only index (0, 0) contains an element, others are supposed to return // the zero/empty value. auto zeroIntValue = sparseInt.getValues()[{1, 1}].cast(); EXPECT_EQ(zeroIntValue.getInt(), 0); EXPECT_TRUE(zeroIntValue.getType() == intTy); auto zeroFloatValue = sparseFloat.getValues()[{1, 1}].cast(); EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); EXPECT_TRUE(zeroFloatValue.getType() == floatTy); auto zeroStringValue = sparseString.getValues()[{1, 1}].cast(); EXPECT_TRUE(zeroStringValue.getValue().empty()); EXPECT_TRUE(zeroStringValue.getType() == stringTy); } } // namespace