xref: /llvm-project/mlir/unittests/IR/ShapedTypeTest.cpp (revision f023da12d12635f5fba436e825cbfc999e28e623)
1381a65faSJacques Pienaar //===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
2381a65faSJacques Pienaar //
3381a65faSJacques Pienaar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4381a65faSJacques Pienaar // See https://llvm.org/LICENSE.txt for license information.
5381a65faSJacques Pienaar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6381a65faSJacques Pienaar //
7381a65faSJacques Pienaar //===----------------------------------------------------------------------===//
8381a65faSJacques Pienaar 
9381a65faSJacques Pienaar #include "mlir/IR/AffineMap.h"
10f3bf5c05SVladislav Vinogradov #include "mlir/IR/BuiltinAttributes.h"
11381a65faSJacques Pienaar #include "mlir/IR/BuiltinTypes.h"
12381a65faSJacques Pienaar #include "mlir/IR/Dialect.h"
13381a65faSJacques Pienaar #include "mlir/IR/DialectInterface.h"
14d2f42c73SJacques Pienaar #include "mlir/Support/LLVM.h"
15381a65faSJacques Pienaar #include "llvm/ADT/SmallVector.h"
16381a65faSJacques Pienaar #include "gtest/gtest.h"
17381a65faSJacques Pienaar #include <cstdint>
18381a65faSJacques Pienaar 
19381a65faSJacques Pienaar using namespace mlir;
20381a65faSJacques Pienaar using namespace mlir::detail;
21381a65faSJacques Pienaar 
22381a65faSJacques Pienaar namespace {
23381a65faSJacques Pienaar TEST(ShapedTypeTest, CloneMemref) {
24381a65faSJacques Pienaar   MLIRContext context;
25381a65faSJacques Pienaar 
26381a65faSJacques Pienaar   Type i32 = IntegerType::get(&context, 32);
27*f023da12SMatthias Springer   Type f32 = Float32Type::get(&context);
28f3bf5c05SVladislav Vinogradov   Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7);
29381a65faSJacques Pienaar   Type memrefOriginalType = i32;
30381a65faSJacques Pienaar   llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
31381a65faSJacques Pienaar   AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);
32381a65faSJacques Pienaar 
33381a65faSJacques Pienaar   ShapedType memrefType =
34676bfb2aSRiver Riddle       (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
35381a65faSJacques Pienaar           .setMemorySpace(memSpace)
36e41ebbecSVladislav Vinogradov           .setLayout(AffineMapAttr::get(map));
37381a65faSJacques Pienaar   // Update shape.
38381a65faSJacques Pienaar   llvm::SmallVector<int64_t> memrefNewShape({30, 40});
39381a65faSJacques Pienaar   ASSERT_NE(memrefOriginalShape, memrefNewShape);
40381a65faSJacques Pienaar   ASSERT_EQ(memrefType.clone(memrefNewShape),
41676bfb2aSRiver Riddle             (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
42381a65faSJacques Pienaar                 .setMemorySpace(memSpace)
43e41ebbecSVladislav Vinogradov                 .setLayout(AffineMapAttr::get(map)));
44381a65faSJacques Pienaar   // Update type.
45381a65faSJacques Pienaar   Type memrefNewType = f32;
46381a65faSJacques Pienaar   ASSERT_NE(memrefOriginalType, memrefNewType);
47381a65faSJacques Pienaar   ASSERT_EQ(memrefType.clone(memrefNewType),
48381a65faSJacques Pienaar             (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
49381a65faSJacques Pienaar                 .setMemorySpace(memSpace)
50e41ebbecSVladislav Vinogradov                 .setLayout(AffineMapAttr::get(map)));
51381a65faSJacques Pienaar   // Update both.
52381a65faSJacques Pienaar   ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
53381a65faSJacques Pienaar             (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
54381a65faSJacques Pienaar                 .setMemorySpace(memSpace)
55e41ebbecSVladislav Vinogradov                 .setLayout(AffineMapAttr::get(map)));
56381a65faSJacques Pienaar 
57381a65faSJacques Pienaar   // Test unranked memref cloning.
58381a65faSJacques Pienaar   ShapedType unrankedTensorType =
59381a65faSJacques Pienaar       UnrankedMemRefType::get(memrefOriginalType, memSpace);
60381a65faSJacques Pienaar   ASSERT_EQ(unrankedTensorType.clone(memrefNewShape),
61381a65faSJacques Pienaar             (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
62381a65faSJacques Pienaar                 .setMemorySpace(memSpace));
63381a65faSJacques Pienaar   ASSERT_EQ(unrankedTensorType.clone(memrefNewType),
64381a65faSJacques Pienaar             UnrankedMemRefType::get(memrefNewType, memSpace));
65381a65faSJacques Pienaar   ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType),
66381a65faSJacques Pienaar             (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
67381a65faSJacques Pienaar                 .setMemorySpace(memSpace));
68381a65faSJacques Pienaar }
69381a65faSJacques Pienaar 
70381a65faSJacques Pienaar TEST(ShapedTypeTest, CloneTensor) {
71381a65faSJacques Pienaar   MLIRContext context;
72381a65faSJacques Pienaar 
73381a65faSJacques Pienaar   Type i32 = IntegerType::get(&context, 32);
74*f023da12SMatthias Springer   Type f32 = Float32Type::get(&context);
75381a65faSJacques Pienaar 
76381a65faSJacques Pienaar   Type tensorOriginalType = i32;
77381a65faSJacques Pienaar   llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});
78381a65faSJacques Pienaar 
79381a65faSJacques Pienaar   // Test ranked tensor cloning.
80381a65faSJacques Pienaar   ShapedType tensorType =
81381a65faSJacques Pienaar       RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
82381a65faSJacques Pienaar   // Update shape.
83381a65faSJacques Pienaar   llvm::SmallVector<int64_t> tensorNewShape({30, 40});
84381a65faSJacques Pienaar   ASSERT_NE(tensorOriginalShape, tensorNewShape);
85676bfb2aSRiver Riddle   ASSERT_EQ(
86676bfb2aSRiver Riddle       tensorType.clone(tensorNewShape),
87676bfb2aSRiver Riddle       (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
88381a65faSJacques Pienaar   // Update type.
89381a65faSJacques Pienaar   Type tensorNewType = f32;
90381a65faSJacques Pienaar   ASSERT_NE(tensorOriginalType, tensorNewType);
91676bfb2aSRiver Riddle   ASSERT_EQ(
92676bfb2aSRiver Riddle       tensorType.clone(tensorNewType),
93676bfb2aSRiver Riddle       (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
94381a65faSJacques Pienaar   // Update both.
95381a65faSJacques Pienaar   ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
96676bfb2aSRiver Riddle             (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));
97381a65faSJacques Pienaar 
98381a65faSJacques Pienaar   // Test unranked tensor cloning.
99381a65faSJacques Pienaar   ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
100676bfb2aSRiver Riddle   ASSERT_EQ(
101676bfb2aSRiver Riddle       unrankedTensorType.clone(tensorNewShape),
102676bfb2aSRiver Riddle       (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
103381a65faSJacques Pienaar   ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
104676bfb2aSRiver Riddle             (ShapedType)UnrankedTensorType::get(tensorNewType));
105676bfb2aSRiver Riddle   ASSERT_EQ(
106676bfb2aSRiver Riddle       unrankedTensorType.clone(tensorNewShape),
107676bfb2aSRiver Riddle       (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
108381a65faSJacques Pienaar }
109381a65faSJacques Pienaar 
110381a65faSJacques Pienaar TEST(ShapedTypeTest, CloneVector) {
111381a65faSJacques Pienaar   MLIRContext context;
112381a65faSJacques Pienaar 
113381a65faSJacques Pienaar   Type i32 = IntegerType::get(&context, 32);
114*f023da12SMatthias Springer   Type f32 = Float32Type::get(&context);
115381a65faSJacques Pienaar 
116381a65faSJacques Pienaar   Type vectorOriginalType = i32;
117381a65faSJacques Pienaar   llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
118381a65faSJacques Pienaar   ShapedType vectorType =
119381a65faSJacques Pienaar       VectorType::get(vectorOriginalShape, vectorOriginalType);
120381a65faSJacques Pienaar   // Update shape.
121381a65faSJacques Pienaar   llvm::SmallVector<int64_t> vectorNewShape({30, 40});
122381a65faSJacques Pienaar   ASSERT_NE(vectorOriginalShape, vectorNewShape);
123381a65faSJacques Pienaar   ASSERT_EQ(vectorType.clone(vectorNewShape),
124381a65faSJacques Pienaar             VectorType::get(vectorNewShape, vectorOriginalType));
125381a65faSJacques Pienaar   // Update type.
126381a65faSJacques Pienaar   Type vectorNewType = f32;
127381a65faSJacques Pienaar   ASSERT_NE(vectorOriginalType, vectorNewType);
128381a65faSJacques Pienaar   ASSERT_EQ(vectorType.clone(vectorNewType),
129381a65faSJacques Pienaar             VectorType::get(vectorOriginalShape, vectorNewType));
130381a65faSJacques Pienaar   // Update both.
131381a65faSJacques Pienaar   ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType),
132381a65faSJacques Pienaar             VectorType::get(vectorNewShape, vectorNewType));
133381a65faSJacques Pienaar }
134381a65faSJacques Pienaar 
135b0b8e83eSBenjamin Maxwell TEST(ShapedTypeTest, VectorTypeBuilder) {
136b0b8e83eSBenjamin Maxwell   MLIRContext context;
137*f023da12SMatthias Springer   Type f32 = Float32Type::get(&context);
138b0b8e83eSBenjamin Maxwell 
139b0b8e83eSBenjamin Maxwell   SmallVector<int64_t> shape{2, 4, 8, 9, 1};
140b0b8e83eSBenjamin Maxwell   SmallVector<bool> scalableDims{true, false, true, false, false};
141b0b8e83eSBenjamin Maxwell   VectorType vectorType = VectorType::get(shape, f32, scalableDims);
142b0b8e83eSBenjamin Maxwell 
143b0b8e83eSBenjamin Maxwell   {
144b0b8e83eSBenjamin Maxwell     // Drop some dims.
145b0b8e83eSBenjamin Maxwell     VectorType dropFrontTwoDims =
146b0b8e83eSBenjamin Maxwell         VectorType::Builder(vectorType).dropDim(0).dropDim(0);
147b0b8e83eSBenjamin Maxwell     ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType());
148b0b8e83eSBenjamin Maxwell     ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
149b0b8e83eSBenjamin Maxwell     ASSERT_EQ(vectorType.getScalableDims().drop_front(2),
150b0b8e83eSBenjamin Maxwell               dropFrontTwoDims.getScalableDims());
151b0b8e83eSBenjamin Maxwell   }
152b0b8e83eSBenjamin Maxwell 
153b0b8e83eSBenjamin Maxwell   {
154b0b8e83eSBenjamin Maxwell     // Set some dims.
155b0b8e83eSBenjamin Maxwell     VectorType setTwoDims =
156b0b8e83eSBenjamin Maxwell         VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12);
157b0b8e83eSBenjamin Maxwell     ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1}));
158b0b8e83eSBenjamin Maxwell     ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType());
159b0b8e83eSBenjamin Maxwell     ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims());
160b0b8e83eSBenjamin Maxwell   }
161b0b8e83eSBenjamin Maxwell 
162b0b8e83eSBenjamin Maxwell   {
163b0b8e83eSBenjamin Maxwell     // Test for bug from:
164b0b8e83eSBenjamin Maxwell     // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
165b0b8e83eSBenjamin Maxwell     // Constructs a temporary builder, modifies it, copies it to `builder`.
166b0b8e83eSBenjamin Maxwell     // This used to lead to a use-after-free. Running under sanitizers will
167b0b8e83eSBenjamin Maxwell     // catch any issues.
168b0b8e83eSBenjamin Maxwell     VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
169b0b8e83eSBenjamin Maxwell     VectorType newVectorType = VectorType(builder);
170b0b8e83eSBenjamin Maxwell     ASSERT_EQ(newVectorType.getDimSize(0), 16);
171b0b8e83eSBenjamin Maxwell   }
172b0b8e83eSBenjamin Maxwell 
173b0b8e83eSBenjamin Maxwell   {
174b0b8e83eSBenjamin Maxwell     // Make builder from scratch (without scalable dims) -- this use to lead to
175b0b8e83eSBenjamin Maxwell     // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
176b0b8e83eSBenjamin Maxwell     // Running under sanitizers will catch any issues.
177b0b8e83eSBenjamin Maxwell     SmallVector<int64_t> shape{1, 2, 3, 4};
178b0b8e83eSBenjamin Maxwell     VectorType::Builder builder(shape, f32);
179b0b8e83eSBenjamin Maxwell     ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape));
180b0b8e83eSBenjamin Maxwell   }
181b0b8e83eSBenjamin Maxwell 
182b0b8e83eSBenjamin Maxwell   {
183b0b8e83eSBenjamin Maxwell     // Set vector shape (without scalable dims) -- this use to lead to
184b0b8e83eSBenjamin Maxwell     // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
185b0b8e83eSBenjamin Maxwell     // Running under sanitizers will catch any issues.
186b0b8e83eSBenjamin Maxwell     VectorType::Builder builder(vectorType);
187b0b8e83eSBenjamin Maxwell     SmallVector<int64_t> newShape{2, 2};
188b0b8e83eSBenjamin Maxwell     builder.setShape(newShape);
189b0b8e83eSBenjamin Maxwell     ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape));
190b0b8e83eSBenjamin Maxwell   }
191b0b8e83eSBenjamin Maxwell }
192b0b8e83eSBenjamin Maxwell 
193b0b8e83eSBenjamin Maxwell TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
194b0b8e83eSBenjamin Maxwell   MLIRContext context;
195*f023da12SMatthias Springer   Type f32 = Float32Type::get(&context);
196b0b8e83eSBenjamin Maxwell 
197b0b8e83eSBenjamin Maxwell   SmallVector<int64_t> shape{2, 4, 8, 16, 32};
198b0b8e83eSBenjamin Maxwell   RankedTensorType tensorType = RankedTensorType::get(shape, f32);
199b0b8e83eSBenjamin Maxwell 
200b0b8e83eSBenjamin Maxwell   {
201b0b8e83eSBenjamin Maxwell     // Drop some dims.
202b0b8e83eSBenjamin Maxwell     RankedTensorType dropFrontTwoDims =
203b0b8e83eSBenjamin Maxwell         RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0);
204b0b8e83eSBenjamin Maxwell     ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType());
205b0b8e83eSBenjamin Maxwell     ASSERT_EQ(dropFrontTwoDims.getShape(), ArrayRef<int64_t>({16, 32}));
206b0b8e83eSBenjamin Maxwell   }
207b0b8e83eSBenjamin Maxwell 
208b0b8e83eSBenjamin Maxwell   {
209b0b8e83eSBenjamin Maxwell     // Insert some dims.
210b0b8e83eSBenjamin Maxwell     RankedTensorType insertTwoDims =
211b0b8e83eSBenjamin Maxwell         RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3);
212b0b8e83eSBenjamin Maxwell     ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType());
213b0b8e83eSBenjamin Maxwell     ASSERT_EQ(insertTwoDims.getShape(),
214b0b8e83eSBenjamin Maxwell               ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32}));
215b0b8e83eSBenjamin Maxwell   }
216b0b8e83eSBenjamin Maxwell 
217b0b8e83eSBenjamin Maxwell   {
218b0b8e83eSBenjamin Maxwell     // Test for bug from:
219b0b8e83eSBenjamin Maxwell     // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
220b0b8e83eSBenjamin Maxwell     // Constructs a temporary builder, modifies it, copies it to `builder`.
221b0b8e83eSBenjamin Maxwell     // This used to lead to a use-after-free. Running under sanitizers will
222b0b8e83eSBenjamin Maxwell     // catch any issues.
223b0b8e83eSBenjamin Maxwell     RankedTensorType::Builder builder =
224b0b8e83eSBenjamin Maxwell         RankedTensorType::Builder(tensorType).dropDim(0);
225b0b8e83eSBenjamin Maxwell     RankedTensorType newTensorType = RankedTensorType(builder);
226b0b8e83eSBenjamin Maxwell     ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape());
227b0b8e83eSBenjamin Maxwell   }
228b0b8e83eSBenjamin Maxwell }
229b0b8e83eSBenjamin Maxwell 
230d2f42c73SJacques Pienaar /// Simple wrapper class to enable "isa querying" and simple accessing of
231d2f42c73SJacques Pienaar /// encoding.
232d2f42c73SJacques Pienaar class TensorWithString : public RankedTensorType {
233d2f42c73SJacques Pienaar public:
234d2f42c73SJacques Pienaar   using RankedTensorType::RankedTensorType;
235d2f42c73SJacques Pienaar 
236d2f42c73SJacques Pienaar   static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
237d2f42c73SJacques Pienaar                               StringRef name) {
238d2f42c73SJacques Pienaar     return mlir::cast<TensorWithString>(RankedTensorType::get(
239d2f42c73SJacques Pienaar         shape, elementType, StringAttr::get(elementType.getContext(), name)));
240d2f42c73SJacques Pienaar   }
241d2f42c73SJacques Pienaar 
242d2f42c73SJacques Pienaar   StringRef getName() const {
243d2f42c73SJacques Pienaar     if (Attribute enc = getEncoding())
244d2f42c73SJacques Pienaar       return mlir::cast<StringAttr>(enc).getValue();
245d2f42c73SJacques Pienaar     return {};
246d2f42c73SJacques Pienaar   }
247d2f42c73SJacques Pienaar 
248d2f42c73SJacques Pienaar   static bool classof(Type type) {
249d2f42c73SJacques Pienaar     if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
250d2f42c73SJacques Pienaar       return mlir::isa_and_present<StringAttr>(rt.getEncoding());
251d2f42c73SJacques Pienaar     return false;
252d2f42c73SJacques Pienaar   }
253d2f42c73SJacques Pienaar };
254d2f42c73SJacques Pienaar 
255d2f42c73SJacques Pienaar TEST(ShapedTypeTest, RankedTensorTypeView) {
256d2f42c73SJacques Pienaar   MLIRContext context;
257*f023da12SMatthias Springer   Type f32 = Float32Type::get(&context);
258d2f42c73SJacques Pienaar 
259d2f42c73SJacques Pienaar   Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);
260d2f42c73SJacques Pienaar 
261d2f42c73SJacques Pienaar   UnitAttr unitAttr = UnitAttr::get(&context);
262d2f42c73SJacques Pienaar   Type unitEncodingRankedTensorType =
263d2f42c73SJacques Pienaar       RankedTensorType::get({10, 20}, f32, unitAttr);
264d2f42c73SJacques Pienaar 
265d2f42c73SJacques Pienaar   StringAttr stringAttr = StringAttr::get(&context, "app");
266d2f42c73SJacques Pienaar   Type stringEncodingRankedTensorType =
267d2f42c73SJacques Pienaar       RankedTensorType::get({10, 20}, f32, stringAttr);
268d2f42c73SJacques Pienaar 
269d2f42c73SJacques Pienaar   EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
270d2f42c73SJacques Pienaar   EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
271d2f42c73SJacques Pienaar   ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));
272d2f42c73SJacques Pienaar 
273d2f42c73SJacques Pienaar   // Cast to TensorWithString view.
274d2f42c73SJacques Pienaar   auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
275d2f42c73SJacques Pienaar   ASSERT_TRUE(mlir::isa<TensorWithString>(view));
276d2f42c73SJacques Pienaar   EXPECT_EQ(view.getName(), "app");
277d2f42c73SJacques Pienaar   // Verify one could cast view type back to base type.
278d2f42c73SJacques Pienaar   ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
279d2f42c73SJacques Pienaar 
280d2f42c73SJacques Pienaar   Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
281d2f42c73SJacques Pienaar   ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
282d2f42c73SJacques Pienaar   ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
283d2f42c73SJacques Pienaar   view = mlir::cast<TensorWithString>(viewCreated);
284d2f42c73SJacques Pienaar   EXPECT_EQ(view.getName(), "bob");
285d2f42c73SJacques Pienaar }
286d2f42c73SJacques Pienaar 
287be0a7e9fSMehdi Amini } // namespace
288