xref: /llvm-project/mlir/unittests/IR/ShapedTypeTest.cpp (revision 676bfb2a226e705d801016bb433b68a1e09a1e10)
1 //===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AffineMap.h"
10 #include "mlir/IR/BuiltinAttributes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "gtest/gtest.h"
16 #include <cstdint>
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
21 namespace {
22 TEST(ShapedTypeTest, CloneMemref) {
23   MLIRContext context;
24 
25   Type i32 = IntegerType::get(&context, 32);
26   Type f32 = FloatType::getF32(&context);
27   Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7);
28   Type memrefOriginalType = i32;
29   llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
30   AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);
31 
32   ShapedType memrefType =
33       (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
34           .setMemorySpace(memSpace)
35           .setLayout(AffineMapAttr::get(map));
36   // Update shape.
37   llvm::SmallVector<int64_t> memrefNewShape({30, 40});
38   ASSERT_NE(memrefOriginalShape, memrefNewShape);
39   ASSERT_EQ(memrefType.clone(memrefNewShape),
40             (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
41                 .setMemorySpace(memSpace)
42                 .setLayout(AffineMapAttr::get(map)));
43   // Update type.
44   Type memrefNewType = f32;
45   ASSERT_NE(memrefOriginalType, memrefNewType);
46   ASSERT_EQ(memrefType.clone(memrefNewType),
47             (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
48                 .setMemorySpace(memSpace)
49                 .setLayout(AffineMapAttr::get(map)));
50   // Update both.
51   ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
52             (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
53                 .setMemorySpace(memSpace)
54                 .setLayout(AffineMapAttr::get(map)));
55 
56   // Test unranked memref cloning.
57   ShapedType unrankedTensorType =
58       UnrankedMemRefType::get(memrefOriginalType, memSpace);
59   ASSERT_EQ(unrankedTensorType.clone(memrefNewShape),
60             (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
61                 .setMemorySpace(memSpace));
62   ASSERT_EQ(unrankedTensorType.clone(memrefNewType),
63             UnrankedMemRefType::get(memrefNewType, memSpace));
64   ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType),
65             (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
66                 .setMemorySpace(memSpace));
67 }
68 
69 TEST(ShapedTypeTest, CloneTensor) {
70   MLIRContext context;
71 
72   Type i32 = IntegerType::get(&context, 32);
73   Type f32 = FloatType::getF32(&context);
74 
75   Type tensorOriginalType = i32;
76   llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});
77 
78   // Test ranked tensor cloning.
79   ShapedType tensorType =
80       RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
81   // Update shape.
82   llvm::SmallVector<int64_t> tensorNewShape({30, 40});
83   ASSERT_NE(tensorOriginalShape, tensorNewShape);
84   ASSERT_EQ(
85       tensorType.clone(tensorNewShape),
86       (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
87   // Update type.
88   Type tensorNewType = f32;
89   ASSERT_NE(tensorOriginalType, tensorNewType);
90   ASSERT_EQ(
91       tensorType.clone(tensorNewType),
92       (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
93   // Update both.
94   ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
95             (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));
96 
97   // Test unranked tensor cloning.
98   ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
99   ASSERT_EQ(
100       unrankedTensorType.clone(tensorNewShape),
101       (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
102   ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
103             (ShapedType)UnrankedTensorType::get(tensorNewType));
104   ASSERT_EQ(
105       unrankedTensorType.clone(tensorNewShape),
106       (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
107 }
108 
109 TEST(ShapedTypeTest, CloneVector) {
110   MLIRContext context;
111 
112   Type i32 = IntegerType::get(&context, 32);
113   Type f32 = FloatType::getF32(&context);
114 
115   Type vectorOriginalType = i32;
116   llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
117   ShapedType vectorType =
118       VectorType::get(vectorOriginalShape, vectorOriginalType);
119   // Update shape.
120   llvm::SmallVector<int64_t> vectorNewShape({30, 40});
121   ASSERT_NE(vectorOriginalShape, vectorNewShape);
122   ASSERT_EQ(vectorType.clone(vectorNewShape),
123             VectorType::get(vectorNewShape, vectorOriginalType));
124   // Update type.
125   Type vectorNewType = f32;
126   ASSERT_NE(vectorOriginalType, vectorNewType);
127   ASSERT_EQ(vectorType.clone(vectorNewType),
128             VectorType::get(vectorOriginalShape, vectorNewType));
129   // Update both.
130   ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType),
131             VectorType::get(vectorNewShape, vectorNewType));
132 }
133 
134 } // namespace
135