1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/StandardTypes.h" 20 #include "gtest/gtest.h" 21 22 using namespace mlir; 23 using namespace mlir::detail; 24 25 template <typename EltTy> 26 static void testSplat(Type eltType, const EltTy &splatElt) { 27 VectorType shape = VectorType::get({2, 1}, eltType); 28 29 // Check that the generated splat is the same for 1 element and N elements. 30 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 31 EXPECT_TRUE(splat.isSplat()); 32 33 auto detectedSplat = 34 DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt})); 35 EXPECT_EQ(detectedSplat, splat); 36 } 37 38 namespace { 39 TEST(DenseSplatTest, BoolSplat) { 40 MLIRContext context; 41 IntegerType boolTy = IntegerType::get(1, &context); 42 VectorType shape = VectorType::get({2, 2}, boolTy); 43 44 // Check that splat is automatically detected for boolean values. 45 /// True. 46 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 47 EXPECT_TRUE(trueSplat.isSplat()); 48 /// False. 49 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 50 EXPECT_TRUE(falseSplat.isSplat()); 51 EXPECT_NE(falseSplat, trueSplat); 52 53 /// Detect and handle splat within 8 elements (bool values are bit-packed). 54 /// True. 55 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 56 EXPECT_EQ(detectedSplat, trueSplat); 57 /// False. 58 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 59 EXPECT_EQ(detectedSplat, falseSplat); 60 } 61 62 TEST(DenseSplatTest, LargeBoolSplat) { 63 constexpr int64_t boolCount = 56; 64 65 MLIRContext context; 66 IntegerType boolTy = IntegerType::get(1, &context); 67 VectorType shape = VectorType::get({boolCount}, boolTy); 68 69 // Check that splat is automatically detected for boolean values. 70 /// True. 71 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 72 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 73 EXPECT_TRUE(trueSplat.isSplat()); 74 EXPECT_TRUE(falseSplat.isSplat()); 75 76 /// Detect that the large boolean arrays are properly splatted. 77 /// True. 78 SmallVector<bool, 64> trueValues(boolCount, true); 79 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 80 EXPECT_EQ(detectedSplat, trueSplat); 81 /// False. 82 SmallVector<bool, 64> falseValues(boolCount, false); 83 detectedSplat = DenseElementsAttr::get(shape, falseValues); 84 EXPECT_EQ(detectedSplat, falseSplat); 85 } 86 87 TEST(DenseSplatTest, BoolNonSplat) { 88 MLIRContext context; 89 IntegerType boolTy = IntegerType::get(1, &context); 90 VectorType shape = VectorType::get({6}, boolTy); 91 92 // Check that we properly handle non-splat values. 93 DenseElementsAttr nonSplat = 94 DenseElementsAttr::get(shape, {false, false, true, false, false, true}); 95 EXPECT_FALSE(nonSplat.isSplat()); 96 } 97 98 TEST(DenseSplatTest, OddIntSplat) { 99 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 100 MLIRContext context; 101 constexpr size_t intWidth = 19; 102 IntegerType intTy = IntegerType::get(intWidth, &context); 103 APInt value(intWidth, 10); 104 105 testSplat(intTy, value); 106 } 107 108 TEST(DenseSplatTest, Int32Splat) { 109 MLIRContext context; 110 IntegerType intTy = IntegerType::get(32, &context); 111 int value = 64; 112 113 testSplat(intTy, value); 114 } 115 116 TEST(DenseSplatTest, IntAttrSplat) { 117 MLIRContext context; 118 IntegerType intTy = IntegerType::get(85, &context); 119 Attribute value = IntegerAttr::get(intTy, 109); 120 121 testSplat(intTy, value); 122 } 123 124 TEST(DenseSplatTest, F32Splat) { 125 MLIRContext context; 126 FloatType floatTy = FloatType::getF32(&context); 127 float value = 10.0; 128 129 testSplat(floatTy, value); 130 } 131 132 TEST(DenseSplatTest, F64Splat) { 133 MLIRContext context; 134 FloatType floatTy = FloatType::getF64(&context); 135 double value = 10.0; 136 137 testSplat(floatTy, APFloat(value)); 138 } 139 140 TEST(DenseSplatTest, FloatAttrSplat) { 141 MLIRContext context; 142 FloatType floatTy = FloatType::getBF16(&context); 143 Attribute value = FloatAttr::get(floatTy, 10.0); 144 145 testSplat(floatTy, value); 146 } 147 } // end namespace 148