xref: /llvm-project/mlir/unittests/IR/AttributeTest.cpp (revision 2b67821b909e9cdb19cfbb2d37f500e60c594824)
1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/StandardTypes.h"
20 #include "gtest/gtest.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 template <typename EltTy>
26 static void testSplat(Type eltType, const EltTy &splatElt) {
27   VectorType shape = VectorType::get({2, 1}, eltType);
28 
29   // Check that the generated splat is the same for 1 element and N elements.
30   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
31   EXPECT_TRUE(splat.isSplat());
32 
33   auto detectedSplat =
34       DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
35   EXPECT_EQ(detectedSplat, splat);
36 }
37 
38 namespace {
39 TEST(DenseSplatTest, BoolSplat) {
40   MLIRContext context;
41   IntegerType boolTy = IntegerType::get(1, &context);
42   VectorType shape = VectorType::get({2, 2}, boolTy);
43 
44   // Check that splat is automatically detected for boolean values.
45   /// True.
46   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
47   EXPECT_TRUE(trueSplat.isSplat());
48   /// False.
49   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
50   EXPECT_TRUE(falseSplat.isSplat());
51   EXPECT_NE(falseSplat, trueSplat);
52 
53   /// Detect and handle splat within 8 elements (bool values are bit-packed).
54   /// True.
55   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
56   EXPECT_EQ(detectedSplat, trueSplat);
57   /// False.
58   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
59   EXPECT_EQ(detectedSplat, falseSplat);
60 }
61 
62 TEST(DenseSplatTest, LargeBoolSplat) {
63   constexpr int64_t boolCount = 56;
64 
65   MLIRContext context;
66   IntegerType boolTy = IntegerType::get(1, &context);
67   VectorType shape = VectorType::get({boolCount}, boolTy);
68 
69   // Check that splat is automatically detected for boolean values.
70   /// True.
71   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
72   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
73   EXPECT_TRUE(trueSplat.isSplat());
74   EXPECT_TRUE(falseSplat.isSplat());
75 
76   /// Detect that the large boolean arrays are properly splatted.
77   /// True.
78   SmallVector<bool, 64> trueValues(boolCount, true);
79   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
80   EXPECT_EQ(detectedSplat, trueSplat);
81   /// False.
82   SmallVector<bool, 64> falseValues(boolCount, false);
83   detectedSplat = DenseElementsAttr::get(shape, falseValues);
84   EXPECT_EQ(detectedSplat, falseSplat);
85 }
86 
87 TEST(DenseSplatTest, BoolNonSplat) {
88   MLIRContext context;
89   IntegerType boolTy = IntegerType::get(1, &context);
90   VectorType shape = VectorType::get({6}, boolTy);
91 
92   // Check that we properly handle non-splat values.
93   DenseElementsAttr nonSplat =
94       DenseElementsAttr::get(shape, {false, false, true, false, false, true});
95   EXPECT_FALSE(nonSplat.isSplat());
96 }
97 
98 TEST(DenseSplatTest, OddIntSplat) {
99   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
100   MLIRContext context;
101   constexpr size_t intWidth = 19;
102   IntegerType intTy = IntegerType::get(intWidth, &context);
103   APInt value(intWidth, 10);
104 
105   testSplat(intTy, value);
106 }
107 
108 TEST(DenseSplatTest, Int32Splat) {
109   MLIRContext context;
110   IntegerType intTy = IntegerType::get(32, &context);
111   int value = 64;
112 
113   testSplat(intTy, value);
114 }
115 
116 TEST(DenseSplatTest, IntAttrSplat) {
117   MLIRContext context;
118   IntegerType intTy = IntegerType::get(85, &context);
119   Attribute value = IntegerAttr::get(intTy, 109);
120 
121   testSplat(intTy, value);
122 }
123 
124 TEST(DenseSplatTest, F32Splat) {
125   MLIRContext context;
126   FloatType floatTy = FloatType::getF32(&context);
127   float value = 10.0;
128 
129   testSplat(floatTy, value);
130 }
131 
132 TEST(DenseSplatTest, F64Splat) {
133   MLIRContext context;
134   FloatType floatTy = FloatType::getF64(&context);
135   double value = 10.0;
136 
137   testSplat(floatTy, APFloat(value));
138 }
139 
140 TEST(DenseSplatTest, FloatAttrSplat) {
141   MLIRContext context;
142   FloatType floatTy = FloatType::getBF16(&context);
143   Attribute value = FloatAttr::get(floatTy, 10.0);
144 
145   testSplat(floatTy, value);
146 }
147 } // end namespace
148