xref: /llvm-project/mlir/unittests/IR/AttributeTest.cpp (revision d8cd96bc8b7ed55666d3460e906949056820f896)
//===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
17 
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/StandardTypes.h"
20 #include "gtest/gtest.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 template <typename EltTy>
26 static void testSplat(Type eltType, const EltTy &splatElt) {
27   VectorType shape = VectorType::get({2, 1}, eltType);
28 
29   // Check that the generated splat is the same for 1 element and N elements.
30   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
31   EXPECT_TRUE(splat.isSplat());
32 
33   auto detectedSplat =
34       DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
35   EXPECT_EQ(detectedSplat, splat);
36 }
37 
38 namespace {
39 TEST(DenseSplatTest, BoolSplat) {
40   MLIRContext context;
41   IntegerType boolTy = IntegerType::get(1, &context);
42   VectorType shape = VectorType::get({2, 2}, boolTy);
43 
44   // Check that splat is automatically detected for boolean values.
45   /// True.
46   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
47   EXPECT_TRUE(trueSplat.isSplat());
48   /// False.
49   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
50   EXPECT_TRUE(falseSplat.isSplat());
51   EXPECT_NE(falseSplat, trueSplat);
52 
53   /// Detect and handle splat within 8 elements (bool values are bit-packed).
54   /// True.
55   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
56   EXPECT_EQ(detectedSplat, trueSplat);
57   /// False.
58   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
59   EXPECT_EQ(detectedSplat, falseSplat);
60 }
61 
62 TEST(DenseSplatTest, LargeBoolSplat) {
63   constexpr size_t boolCount = 56;
64 
65   MLIRContext context;
66   IntegerType boolTy = IntegerType::get(1, &context);
67   VectorType shape = VectorType::get({boolCount}, boolTy);
68 
69   // Check that splat is automatically detected for boolean values.
70   /// True.
71   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
72   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
73   EXPECT_TRUE(trueSplat.isSplat());
74   EXPECT_TRUE(falseSplat.isSplat());
75 
76   /// Detect that the large boolean arrays are properly splatted.
77   /// True.
78   SmallVector<bool, 64> trueValues(boolCount, true);
79   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
80   EXPECT_EQ(detectedSplat, trueSplat);
81   /// False.
82   SmallVector<bool, 64> falseValues(boolCount, false);
83   detectedSplat = DenseElementsAttr::get(shape, falseValues);
84   EXPECT_EQ(detectedSplat, falseSplat);
85 }
86 
87 TEST(DenseSplatTest, OddIntSplat) {
88   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
89   MLIRContext context;
90   constexpr size_t intWidth = 19;
91   IntegerType intTy = IntegerType::get(intWidth, &context);
92   APInt value(intWidth, 10);
93 
94   testSplat(intTy, value);
95 }
96 
97 TEST(DenseSplatTest, Int32Splat) {
98   MLIRContext context;
99   IntegerType intTy = IntegerType::get(32, &context);
100   int value = 64;
101 
102   testSplat(intTy, value);
103 }
104 
105 TEST(DenseSplatTest, IntAttrSplat) {
106   MLIRContext context;
107   IntegerType intTy = IntegerType::get(85, &context);
108   Attribute value = IntegerAttr::get(intTy, 109);
109 
110   testSplat(intTy, value);
111 }
112 
113 TEST(DenseSplatTest, F32Splat) {
114   MLIRContext context;
115   FloatType floatTy = FloatType::getF32(&context);
116   float value = 10.0;
117 
118   testSplat(floatTy, value);
119 }
120 
121 TEST(DenseSplatTest, F64Splat) {
122   MLIRContext context;
123   FloatType floatTy = FloatType::getF64(&context);
124   double value = 10.0;
125 
126   testSplat(floatTy, APFloat(value));
127 }
128 
129 TEST(DenseSplatTest, FloatAttrSplat) {
130   MLIRContext context;
131   FloatType floatTy = FloatType::getBF16(&context);
132   Attribute value = FloatAttr::get(floatTy, 10.0);
133 
134   testSplat(floatTy, value);
135 }
136 } // end namespace
137