xref: /llvm-project/mlir/unittests/IR/ShapedTypeTest.cpp (revision be0a7e9f27083ada6072fcc0711ffa5630daa5ec)
1 //===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AffineMap.h"
10 #include "mlir/IR/BuiltinAttributes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "gtest/gtest.h"
16 #include <cstdint>
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
21 namespace {
22 TEST(ShapedTypeTest, CloneMemref) {
23   MLIRContext context;
24 
25   Type i32 = IntegerType::get(&context, 32);
26   Type f32 = FloatType::getF32(&context);
27   Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7);
28   Type memrefOriginalType = i32;
29   llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
30   AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);
31 
32   ShapedType memrefType =
33       MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
34           .setMemorySpace(memSpace)
35           .setLayout(AffineMapAttr::get(map));
36   // Update shape.
37   llvm::SmallVector<int64_t> memrefNewShape({30, 40});
38   ASSERT_NE(memrefOriginalShape, memrefNewShape);
39   ASSERT_EQ(memrefType.clone(memrefNewShape),
40             (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
41                 .setMemorySpace(memSpace)
42                 .setLayout(AffineMapAttr::get(map)));
43   // Update type.
44   Type memrefNewType = f32;
45   ASSERT_NE(memrefOriginalType, memrefNewType);
46   ASSERT_EQ(memrefType.clone(memrefNewType),
47             (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
48                 .setMemorySpace(memSpace)
49                 .setLayout(AffineMapAttr::get(map)));
50   // Update both.
51   ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
52             (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
53                 .setMemorySpace(memSpace)
54                 .setLayout(AffineMapAttr::get(map)));
55 
56   // Test unranked memref cloning.
57   ShapedType unrankedTensorType =
58       UnrankedMemRefType::get(memrefOriginalType, memSpace);
59   ASSERT_EQ(unrankedTensorType.clone(memrefNewShape),
60             (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
61                 .setMemorySpace(memSpace));
62   ASSERT_EQ(unrankedTensorType.clone(memrefNewType),
63             UnrankedMemRefType::get(memrefNewType, memSpace));
64   ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType),
65             (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
66                 .setMemorySpace(memSpace));
67 }
68 
69 TEST(ShapedTypeTest, CloneTensor) {
70   MLIRContext context;
71 
72   Type i32 = IntegerType::get(&context, 32);
73   Type f32 = FloatType::getF32(&context);
74 
75   Type tensorOriginalType = i32;
76   llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});
77 
78   // Test ranked tensor cloning.
79   ShapedType tensorType =
80       RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
81   // Update shape.
82   llvm::SmallVector<int64_t> tensorNewShape({30, 40});
83   ASSERT_NE(tensorOriginalShape, tensorNewShape);
84   ASSERT_EQ(tensorType.clone(tensorNewShape),
85             RankedTensorType::get(tensorNewShape, tensorOriginalType));
86   // Update type.
87   Type tensorNewType = f32;
88   ASSERT_NE(tensorOriginalType, tensorNewType);
89   ASSERT_EQ(tensorType.clone(tensorNewType),
90             RankedTensorType::get(tensorOriginalShape, tensorNewType));
91   // Update both.
92   ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
93             RankedTensorType::get(tensorNewShape, tensorNewType));
94 
95   // Test unranked tensor cloning.
96   ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
97   ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
98             RankedTensorType::get(tensorNewShape, tensorOriginalType));
99   ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
100             UnrankedTensorType::get(tensorNewType));
101   ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
102             RankedTensorType::get(tensorNewShape, tensorOriginalType));
103 }
104 
105 TEST(ShapedTypeTest, CloneVector) {
106   MLIRContext context;
107 
108   Type i32 = IntegerType::get(&context, 32);
109   Type f32 = FloatType::getF32(&context);
110 
111   Type vectorOriginalType = i32;
112   llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
113   ShapedType vectorType =
114       VectorType::get(vectorOriginalShape, vectorOriginalType);
115   // Update shape.
116   llvm::SmallVector<int64_t> vectorNewShape({30, 40});
117   ASSERT_NE(vectorOriginalShape, vectorNewShape);
118   ASSERT_EQ(vectorType.clone(vectorNewShape),
119             VectorType::get(vectorNewShape, vectorOriginalType));
120   // Update type.
121   Type vectorNewType = f32;
122   ASSERT_NE(vectorOriginalType, vectorNewType);
123   ASSERT_EQ(vectorType.clone(vectorNewType),
124             VectorType::get(vectorOriginalShape, vectorNewType));
125   // Update both.
126   ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType),
127             VectorType::get(vectorNewShape, vectorNewType));
128 }
129 
130 } // namespace
131