1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/StandardTypes.h" 20 #include "gtest/gtest.h" 21 22 using namespace mlir; 23 using namespace mlir::detail; 24 25 template <typename EltTy> 26 static void testSplat(Type eltType, const EltTy &splatElt) { 27 VectorType shape = VectorType::get({2, 1}, eltType); 28 29 // Check that the generated splat is the same for 1 element and N elements. 30 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 31 EXPECT_TRUE(splat.isSplat()); 32 33 auto detectedSplat = 34 DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt})); 35 EXPECT_EQ(detectedSplat, splat); 36 } 37 38 namespace { 39 TEST(DenseSplatTest, BoolSplat) { 40 MLIRContext context; 41 IntegerType boolTy = IntegerType::get(1, &context); 42 VectorType shape = VectorType::get({2, 2}, boolTy); 43 44 // Check that splat is automatically detected for boolean values. 45 /// True. 46 DenseElementsAttr trueSplat = 47 DenseElementsAttr::get(shape, llvm::ArrayRef<bool>(true)); 48 EXPECT_TRUE(trueSplat.isSplat()); 49 /// False. 50 DenseElementsAttr falseSplat = 51 DenseElementsAttr::get(shape, llvm::ArrayRef<bool>(false)); 52 EXPECT_TRUE(falseSplat.isSplat()); 53 EXPECT_NE(falseSplat, trueSplat); 54 55 /// Detect and handle splat within 8 elements (bool values are bit-packed). 56 /// True. 57 auto detectedSplat = DenseElementsAttr::get( 58 shape, llvm::ArrayRef<bool>({true, true, true, true})); 59 EXPECT_EQ(detectedSplat, trueSplat); 60 /// False. 61 detectedSplat = DenseElementsAttr::get( 62 shape, llvm::ArrayRef<bool>({false, false, false, false})); 63 EXPECT_EQ(detectedSplat, falseSplat); 64 } 65 66 TEST(DenseSplatTest, LargeBoolSplat) { 67 constexpr int64_t boolCount = 56; 68 69 MLIRContext context; 70 IntegerType boolTy = IntegerType::get(1, &context); 71 VectorType shape = VectorType::get({boolCount}, boolTy); 72 73 // Check that splat is automatically detected for boolean values. 74 /// True. 75 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 76 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 77 EXPECT_TRUE(trueSplat.isSplat()); 78 EXPECT_TRUE(falseSplat.isSplat()); 79 80 /// Detect that the large boolean arrays are properly splatted. 81 /// True. 82 SmallVector<bool, 64> trueValues(boolCount, true); 83 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 84 EXPECT_EQ(detectedSplat, trueSplat); 85 /// False. 86 SmallVector<bool, 64> falseValues(boolCount, false); 87 detectedSplat = DenseElementsAttr::get(shape, falseValues); 88 EXPECT_EQ(detectedSplat, falseSplat); 89 } 90 91 TEST(DenseSplatTest, OddIntSplat) { 92 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 93 MLIRContext context; 94 constexpr size_t intWidth = 19; 95 IntegerType intTy = IntegerType::get(intWidth, &context); 96 APInt value(intWidth, 10); 97 98 testSplat(intTy, value); 99 } 100 101 TEST(DenseSplatTest, Int32Splat) { 102 MLIRContext context; 103 IntegerType intTy = IntegerType::get(32, &context); 104 int value = 64; 105 106 testSplat(intTy, value); 107 } 108 109 TEST(DenseSplatTest, IntAttrSplat) { 110 MLIRContext context; 111 IntegerType intTy = IntegerType::get(85, &context); 112 Attribute value = IntegerAttr::get(intTy, 109); 113 114 testSplat(intTy, value); 115 } 116 117 TEST(DenseSplatTest, F32Splat) { 118 MLIRContext context; 119 FloatType floatTy = FloatType::getF32(&context); 120 float value = 10.0; 121 122 testSplat(floatTy, value); 123 } 124 125 TEST(DenseSplatTest, F64Splat) { 126 MLIRContext context; 127 FloatType floatTy = FloatType::getF64(&context); 128 double value = 10.0; 129 130 testSplat(floatTy, APFloat(value)); 131 } 132 133 TEST(DenseSplatTest, FloatAttrSplat) { 134 MLIRContext context; 135 FloatType floatTy = FloatType::getBF16(&context); 136 Attribute value = FloatAttr::get(floatTy, 10.0); 137 138 testSplat(floatTy, value); 139 } 140 } // end namespace 141