1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/StandardTypes.h" 20 #include "gtest/gtest.h" 21 22 using namespace mlir; 23 using namespace mlir::detail; 24 25 template <typename EltTy> 26 static void testSplat(Type eltType, const EltTy &splatElt) { 27 VectorType shape = VectorType::get({2, 1}, eltType); 28 29 // Check that the generated splat is the same for 1 element and N elements. 30 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); 31 EXPECT_TRUE(splat.isSplat()); 32 33 auto detectedSplat = 34 DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt})); 35 EXPECT_EQ(detectedSplat, splat); 36 } 37 38 namespace { 39 TEST(DenseSplatTest, BoolSplat) { 40 MLIRContext context; 41 IntegerType boolTy = IntegerType::get(1, &context); 42 VectorType shape = VectorType::get({2, 2}, boolTy); 43 44 // Check that splat is automatically detected for boolean values. 45 /// True. 46 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 47 EXPECT_TRUE(trueSplat.isSplat()); 48 /// False. 49 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 50 EXPECT_TRUE(falseSplat.isSplat()); 51 EXPECT_NE(falseSplat, trueSplat); 52 53 /// Detect and handle splat within 8 elements (bool values are bit-packed). 54 /// True. 55 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); 56 EXPECT_EQ(detectedSplat, trueSplat); 57 /// False. 58 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); 59 EXPECT_EQ(detectedSplat, falseSplat); 60 } 61 62 TEST(DenseSplatTest, LargeBoolSplat) { 63 constexpr size_t boolCount = 56; 64 65 MLIRContext context; 66 IntegerType boolTy = IntegerType::get(1, &context); 67 VectorType shape = VectorType::get({boolCount}, boolTy); 68 69 // Check that splat is automatically detected for boolean values. 70 /// True. 71 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); 72 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); 73 EXPECT_TRUE(trueSplat.isSplat()); 74 EXPECT_TRUE(falseSplat.isSplat()); 75 76 /// Detect that the large boolean arrays are properly splatted. 77 /// True. 78 SmallVector<bool, 64> trueValues(boolCount, true); 79 auto detectedSplat = DenseElementsAttr::get(shape, trueValues); 80 EXPECT_EQ(detectedSplat, trueSplat); 81 /// False. 82 SmallVector<bool, 64> falseValues(boolCount, false); 83 detectedSplat = DenseElementsAttr::get(shape, falseValues); 84 EXPECT_EQ(detectedSplat, falseSplat); 85 } 86 87 TEST(DenseSplatTest, OddIntSplat) { 88 // Test detecting a splat with an odd(non 8-bit) integer bitwidth. 89 MLIRContext context; 90 constexpr size_t intWidth = 19; 91 IntegerType intTy = IntegerType::get(intWidth, &context); 92 APInt value(intWidth, 10); 93 94 testSplat(intTy, value); 95 } 96 97 TEST(DenseSplatTest, Int32Splat) { 98 MLIRContext context; 99 IntegerType intTy = IntegerType::get(32, &context); 100 int value = 64; 101 102 testSplat(intTy, value); 103 } 104 105 TEST(DenseSplatTest, IntAttrSplat) { 106 MLIRContext context; 107 IntegerType intTy = IntegerType::get(85, &context); 108 Attribute value = IntegerAttr::get(intTy, 109); 109 110 testSplat(intTy, value); 111 } 112 113 TEST(DenseSplatTest, F32Splat) { 114 MLIRContext context; 115 FloatType floatTy = FloatType::getF32(&context); 116 float value = 10.0; 117 118 testSplat(floatTy, value); 119 } 120 121 TEST(DenseSplatTest, F64Splat) { 122 MLIRContext context; 123 FloatType floatTy = FloatType::getF64(&context); 124 double value = 10.0; 125 126 testSplat(floatTy, APFloat(value)); 127 } 128 129 TEST(DenseSplatTest, FloatAttrSplat) { 130 MLIRContext context; 131 FloatType floatTy = FloatType::getBF16(&context); 132 Attribute value = FloatAttr::get(floatTy, 10.0); 133 134 testSplat(floatTy, value); 135 } 136 } // end namespace 137