xref: /llvm-project/mlir/unittests/IR/ShapedTypeTest.cpp (revision f023da12d12635f5fba436e825cbfc999e28e623)
1 //===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AffineMap.h"
10 #include "mlir/IR/BuiltinAttributes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/Support/LLVM.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "gtest/gtest.h"
17 #include <cstdint>
18 
19 using namespace mlir;
20 using namespace mlir::detail;
21 
22 namespace {
23 TEST(ShapedTypeTest, CloneMemref) {
24   MLIRContext context;
25 
26   Type i32 = IntegerType::get(&context, 32);
27   Type f32 = Float32Type::get(&context);
28   Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7);
29   Type memrefOriginalType = i32;
30   llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
31   AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);
32 
33   ShapedType memrefType =
34       (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
35           .setMemorySpace(memSpace)
36           .setLayout(AffineMapAttr::get(map));
37   // Update shape.
38   llvm::SmallVector<int64_t> memrefNewShape({30, 40});
39   ASSERT_NE(memrefOriginalShape, memrefNewShape);
40   ASSERT_EQ(memrefType.clone(memrefNewShape),
41             (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
42                 .setMemorySpace(memSpace)
43                 .setLayout(AffineMapAttr::get(map)));
44   // Update type.
45   Type memrefNewType = f32;
46   ASSERT_NE(memrefOriginalType, memrefNewType);
47   ASSERT_EQ(memrefType.clone(memrefNewType),
48             (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
49                 .setMemorySpace(memSpace)
50                 .setLayout(AffineMapAttr::get(map)));
51   // Update both.
52   ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
53             (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
54                 .setMemorySpace(memSpace)
55                 .setLayout(AffineMapAttr::get(map)));
56 
57   // Test unranked memref cloning.
58   ShapedType unrankedTensorType =
59       UnrankedMemRefType::get(memrefOriginalType, memSpace);
60   ASSERT_EQ(unrankedTensorType.clone(memrefNewShape),
61             (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
62                 .setMemorySpace(memSpace));
63   ASSERT_EQ(unrankedTensorType.clone(memrefNewType),
64             UnrankedMemRefType::get(memrefNewType, memSpace));
65   ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType),
66             (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
67                 .setMemorySpace(memSpace));
68 }
69 
70 TEST(ShapedTypeTest, CloneTensor) {
71   MLIRContext context;
72 
73   Type i32 = IntegerType::get(&context, 32);
74   Type f32 = Float32Type::get(&context);
75 
76   Type tensorOriginalType = i32;
77   llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});
78 
79   // Test ranked tensor cloning.
80   ShapedType tensorType =
81       RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
82   // Update shape.
83   llvm::SmallVector<int64_t> tensorNewShape({30, 40});
84   ASSERT_NE(tensorOriginalShape, tensorNewShape);
85   ASSERT_EQ(
86       tensorType.clone(tensorNewShape),
87       (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
88   // Update type.
89   Type tensorNewType = f32;
90   ASSERT_NE(tensorOriginalType, tensorNewType);
91   ASSERT_EQ(
92       tensorType.clone(tensorNewType),
93       (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
94   // Update both.
95   ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
96             (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));
97 
98   // Test unranked tensor cloning.
99   ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
100   ASSERT_EQ(
101       unrankedTensorType.clone(tensorNewShape),
102       (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
103   ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
104             (ShapedType)UnrankedTensorType::get(tensorNewType));
105   ASSERT_EQ(
106       unrankedTensorType.clone(tensorNewShape),
107       (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
108 }
109 
110 TEST(ShapedTypeTest, CloneVector) {
111   MLIRContext context;
112 
113   Type i32 = IntegerType::get(&context, 32);
114   Type f32 = Float32Type::get(&context);
115 
116   Type vectorOriginalType = i32;
117   llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
118   ShapedType vectorType =
119       VectorType::get(vectorOriginalShape, vectorOriginalType);
120   // Update shape.
121   llvm::SmallVector<int64_t> vectorNewShape({30, 40});
122   ASSERT_NE(vectorOriginalShape, vectorNewShape);
123   ASSERT_EQ(vectorType.clone(vectorNewShape),
124             VectorType::get(vectorNewShape, vectorOriginalType));
125   // Update type.
126   Type vectorNewType = f32;
127   ASSERT_NE(vectorOriginalType, vectorNewType);
128   ASSERT_EQ(vectorType.clone(vectorNewType),
129             VectorType::get(vectorOriginalShape, vectorNewType));
130   // Update both.
131   ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType),
132             VectorType::get(vectorNewShape, vectorNewType));
133 }
134 
135 TEST(ShapedTypeTest, VectorTypeBuilder) {
136   MLIRContext context;
137   Type f32 = Float32Type::get(&context);
138 
139   SmallVector<int64_t> shape{2, 4, 8, 9, 1};
140   SmallVector<bool> scalableDims{true, false, true, false, false};
141   VectorType vectorType = VectorType::get(shape, f32, scalableDims);
142 
143   {
144     // Drop some dims.
145     VectorType dropFrontTwoDims =
146         VectorType::Builder(vectorType).dropDim(0).dropDim(0);
147     ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType());
148     ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
149     ASSERT_EQ(vectorType.getScalableDims().drop_front(2),
150               dropFrontTwoDims.getScalableDims());
151   }
152 
153   {
154     // Set some dims.
155     VectorType setTwoDims =
156         VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12);
157     ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1}));
158     ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType());
159     ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims());
160   }
161 
162   {
163     // Test for bug from:
164     // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
165     // Constructs a temporary builder, modifies it, copies it to `builder`.
166     // This used to lead to a use-after-free. Running under sanitizers will
167     // catch any issues.
168     VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
169     VectorType newVectorType = VectorType(builder);
170     ASSERT_EQ(newVectorType.getDimSize(0), 16);
171   }
172 
173   {
174     // Make builder from scratch (without scalable dims) -- this use to lead to
175     // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
176     // Running under sanitizers will catch any issues.
177     SmallVector<int64_t> shape{1, 2, 3, 4};
178     VectorType::Builder builder(shape, f32);
179     ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape));
180   }
181 
182   {
183     // Set vector shape (without scalable dims) -- this use to lead to
184     // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
185     // Running under sanitizers will catch any issues.
186     VectorType::Builder builder(vectorType);
187     SmallVector<int64_t> newShape{2, 2};
188     builder.setShape(newShape);
189     ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape));
190   }
191 }
192 
193 TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
194   MLIRContext context;
195   Type f32 = Float32Type::get(&context);
196 
197   SmallVector<int64_t> shape{2, 4, 8, 16, 32};
198   RankedTensorType tensorType = RankedTensorType::get(shape, f32);
199 
200   {
201     // Drop some dims.
202     RankedTensorType dropFrontTwoDims =
203         RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0);
204     ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType());
205     ASSERT_EQ(dropFrontTwoDims.getShape(), ArrayRef<int64_t>({16, 32}));
206   }
207 
208   {
209     // Insert some dims.
210     RankedTensorType insertTwoDims =
211         RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3);
212     ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType());
213     ASSERT_EQ(insertTwoDims.getShape(),
214               ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32}));
215   }
216 
217   {
218     // Test for bug from:
219     // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
220     // Constructs a temporary builder, modifies it, copies it to `builder`.
221     // This used to lead to a use-after-free. Running under sanitizers will
222     // catch any issues.
223     RankedTensorType::Builder builder =
224         RankedTensorType::Builder(tensorType).dropDim(0);
225     RankedTensorType newTensorType = RankedTensorType(builder);
226     ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape());
227   }
228 }
229 
230 /// Simple wrapper class to enable "isa querying" and simple accessing of
231 /// encoding.
232 class TensorWithString : public RankedTensorType {
233 public:
234   using RankedTensorType::RankedTensorType;
235 
236   static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
237                               StringRef name) {
238     return mlir::cast<TensorWithString>(RankedTensorType::get(
239         shape, elementType, StringAttr::get(elementType.getContext(), name)));
240   }
241 
242   StringRef getName() const {
243     if (Attribute enc = getEncoding())
244       return mlir::cast<StringAttr>(enc).getValue();
245     return {};
246   }
247 
248   static bool classof(Type type) {
249     if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
250       return mlir::isa_and_present<StringAttr>(rt.getEncoding());
251     return false;
252   }
253 };
254 
255 TEST(ShapedTypeTest, RankedTensorTypeView) {
256   MLIRContext context;
257   Type f32 = Float32Type::get(&context);
258 
259   Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);
260 
261   UnitAttr unitAttr = UnitAttr::get(&context);
262   Type unitEncodingRankedTensorType =
263       RankedTensorType::get({10, 20}, f32, unitAttr);
264 
265   StringAttr stringAttr = StringAttr::get(&context, "app");
266   Type stringEncodingRankedTensorType =
267       RankedTensorType::get({10, 20}, f32, stringAttr);
268 
269   EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
270   EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
271   ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));
272 
273   // Cast to TensorWithString view.
274   auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
275   ASSERT_TRUE(mlir::isa<TensorWithString>(view));
276   EXPECT_EQ(view.getName(), "app");
277   // Verify one could cast view type back to base type.
278   ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
279 
280   Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
281   ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
282   ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
283   view = mlir::cast<TensorWithString>(viewCreated);
284   EXPECT_EQ(view.getName(), "bob");
285 }
286 
287 } // namespace
288