xref: /llvm-project/mlir/unittests/IR/AttributeTest.cpp (revision b582338f629a6fec01aa3b28dc0fd0d37a72f08a)
1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/StandardTypes.h"
20 #include "gtest/gtest.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 template <typename EltTy>
26 static void testSplat(Type eltType, const EltTy &splatElt) {
27   VectorType shape = VectorType::get({2, 1}, eltType);
28 
29   // Check that the generated splat is the same for 1 element and N elements.
30   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
31   EXPECT_TRUE(splat.isSplat());
32 
33   auto detectedSplat =
34       DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
35   EXPECT_EQ(detectedSplat, splat);
36 }
37 
38 namespace {
39 TEST(DenseSplatTest, BoolSplat) {
40   MLIRContext context;
41   IntegerType boolTy = IntegerType::get(1, &context);
42   VectorType shape = VectorType::get({2, 2}, boolTy);
43 
44   // Check that splat is automatically detected for boolean values.
45   /// True.
46   DenseElementsAttr trueSplat =
47       DenseElementsAttr::get(shape, llvm::ArrayRef<bool>(true));
48   EXPECT_TRUE(trueSplat.isSplat());
49   /// False.
50   DenseElementsAttr falseSplat =
51       DenseElementsAttr::get(shape, llvm::ArrayRef<bool>(false));
52   EXPECT_TRUE(falseSplat.isSplat());
53   EXPECT_NE(falseSplat, trueSplat);
54 
55   /// Detect and handle splat within 8 elements (bool values are bit-packed).
56   /// True.
57   auto detectedSplat = DenseElementsAttr::get(
58       shape, llvm::ArrayRef<bool>({true, true, true, true}));
59   EXPECT_EQ(detectedSplat, trueSplat);
60   /// False.
61   detectedSplat = DenseElementsAttr::get(
62       shape, llvm::ArrayRef<bool>({false, false, false, false}));
63   EXPECT_EQ(detectedSplat, falseSplat);
64 }
65 
66 TEST(DenseSplatTest, LargeBoolSplat) {
67   constexpr int64_t boolCount = 56;
68 
69   MLIRContext context;
70   IntegerType boolTy = IntegerType::get(1, &context);
71   VectorType shape = VectorType::get({boolCount}, boolTy);
72 
73   // Check that splat is automatically detected for boolean values.
74   /// True.
75   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
76   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
77   EXPECT_TRUE(trueSplat.isSplat());
78   EXPECT_TRUE(falseSplat.isSplat());
79 
80   /// Detect that the large boolean arrays are properly splatted.
81   /// True.
82   SmallVector<bool, 64> trueValues(boolCount, true);
83   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
84   EXPECT_EQ(detectedSplat, trueSplat);
85   /// False.
86   SmallVector<bool, 64> falseValues(boolCount, false);
87   detectedSplat = DenseElementsAttr::get(shape, falseValues);
88   EXPECT_EQ(detectedSplat, falseSplat);
89 }
90 
91 TEST(DenseSplatTest, OddIntSplat) {
92   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
93   MLIRContext context;
94   constexpr size_t intWidth = 19;
95   IntegerType intTy = IntegerType::get(intWidth, &context);
96   APInt value(intWidth, 10);
97 
98   testSplat(intTy, value);
99 }
100 
101 TEST(DenseSplatTest, Int32Splat) {
102   MLIRContext context;
103   IntegerType intTy = IntegerType::get(32, &context);
104   int value = 64;
105 
106   testSplat(intTy, value);
107 }
108 
109 TEST(DenseSplatTest, IntAttrSplat) {
110   MLIRContext context;
111   IntegerType intTy = IntegerType::get(85, &context);
112   Attribute value = IntegerAttr::get(intTy, 109);
113 
114   testSplat(intTy, value);
115 }
116 
117 TEST(DenseSplatTest, F32Splat) {
118   MLIRContext context;
119   FloatType floatTy = FloatType::getF32(&context);
120   float value = 10.0;
121 
122   testSplat(floatTy, value);
123 }
124 
125 TEST(DenseSplatTest, F64Splat) {
126   MLIRContext context;
127   FloatType floatTy = FloatType::getF64(&context);
128   double value = 10.0;
129 
130   testSplat(floatTy, APFloat(value));
131 }
132 
133 TEST(DenseSplatTest, FloatAttrSplat) {
134   MLIRContext context;
135   FloatType floatTy = FloatType::getBF16(&context);
136   Attribute value = FloatAttr::get(floatTy, 10.0);
137 
138   testSplat(floatTy, value);
139 }
140 } // end namespace
141