1 //===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AffineMap.h" 10 #include "mlir/IR/BuiltinAttributes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/DialectInterface.h" 14 #include "llvm/ADT/SmallVector.h" 15 #include "gtest/gtest.h" 16 #include <cstdint> 17 18 using namespace mlir; 19 using namespace mlir::detail; 20 21 namespace { 22 TEST(ShapedTypeTest, CloneMemref) { 23 MLIRContext context; 24 25 Type i32 = IntegerType::get(&context, 32); 26 Type f32 = FloatType::getF32(&context); 27 Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7); 28 Type memrefOriginalType = i32; 29 llvm::SmallVector<int64_t> memrefOriginalShape({10, 20}); 30 AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context); 31 32 ShapedType memrefType = 33 MemRefType::Builder(memrefOriginalShape, memrefOriginalType) 34 .setMemorySpace(memSpace) 35 .setLayout(AffineMapAttr::get(map)); 36 // Update shape. 37 llvm::SmallVector<int64_t> memrefNewShape({30, 40}); 38 ASSERT_NE(memrefOriginalShape, memrefNewShape); 39 ASSERT_EQ(memrefType.clone(memrefNewShape), 40 (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType) 41 .setMemorySpace(memSpace) 42 .setLayout(AffineMapAttr::get(map))); 43 // Update type. 44 Type memrefNewType = f32; 45 ASSERT_NE(memrefOriginalType, memrefNewType); 46 ASSERT_EQ(memrefType.clone(memrefNewType), 47 (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType) 48 .setMemorySpace(memSpace) 49 .setLayout(AffineMapAttr::get(map))); 50 // Update both. 51 ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType), 52 (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) 53 .setMemorySpace(memSpace) 54 .setLayout(AffineMapAttr::get(map))); 55 56 // Test unranked memref cloning. 57 ShapedType unrankedTensorType = 58 UnrankedMemRefType::get(memrefOriginalType, memSpace); 59 ASSERT_EQ(unrankedTensorType.clone(memrefNewShape), 60 (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType) 61 .setMemorySpace(memSpace)); 62 ASSERT_EQ(unrankedTensorType.clone(memrefNewType), 63 UnrankedMemRefType::get(memrefNewType, memSpace)); 64 ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType), 65 (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) 66 .setMemorySpace(memSpace)); 67 } 68 69 TEST(ShapedTypeTest, CloneTensor) { 70 MLIRContext context; 71 72 Type i32 = IntegerType::get(&context, 32); 73 Type f32 = FloatType::getF32(&context); 74 75 Type tensorOriginalType = i32; 76 llvm::SmallVector<int64_t> tensorOriginalShape({10, 20}); 77 78 // Test ranked tensor cloning. 79 ShapedType tensorType = 80 RankedTensorType::get(tensorOriginalShape, tensorOriginalType); 81 // Update shape. 82 llvm::SmallVector<int64_t> tensorNewShape({30, 40}); 83 ASSERT_NE(tensorOriginalShape, tensorNewShape); 84 ASSERT_EQ(tensorType.clone(tensorNewShape), 85 RankedTensorType::get(tensorNewShape, tensorOriginalType)); 86 // Update type. 87 Type tensorNewType = f32; 88 ASSERT_NE(tensorOriginalType, tensorNewType); 89 ASSERT_EQ(tensorType.clone(tensorNewType), 90 RankedTensorType::get(tensorOriginalShape, tensorNewType)); 91 // Update both. 92 ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType), 93 RankedTensorType::get(tensorNewShape, tensorNewType)); 94 95 // Test unranked tensor cloning. 96 ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType); 97 ASSERT_EQ(unrankedTensorType.clone(tensorNewShape), 98 RankedTensorType::get(tensorNewShape, tensorOriginalType)); 99 ASSERT_EQ(unrankedTensorType.clone(tensorNewType), 100 UnrankedTensorType::get(tensorNewType)); 101 ASSERT_EQ(unrankedTensorType.clone(tensorNewShape), 102 RankedTensorType::get(tensorNewShape, tensorOriginalType)); 103 } 104 105 TEST(ShapedTypeTest, CloneVector) { 106 MLIRContext context; 107 108 Type i32 = IntegerType::get(&context, 32); 109 Type f32 = FloatType::getF32(&context); 110 111 Type vectorOriginalType = i32; 112 llvm::SmallVector<int64_t> vectorOriginalShape({10, 20}); 113 ShapedType vectorType = 114 VectorType::get(vectorOriginalShape, vectorOriginalType); 115 // Update shape. 116 llvm::SmallVector<int64_t> vectorNewShape({30, 40}); 117 ASSERT_NE(vectorOriginalShape, vectorNewShape); 118 ASSERT_EQ(vectorType.clone(vectorNewShape), 119 VectorType::get(vectorNewShape, vectorOriginalType)); 120 // Update type. 121 Type vectorNewType = f32; 122 ASSERT_NE(vectorOriginalType, vectorNewType); 123 ASSERT_EQ(vectorType.clone(vectorNewType), 124 VectorType::get(vectorOriginalShape, vectorNewType)); 125 // Update both. 126 ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType), 127 VectorType::get(vectorNewShape, vectorNewType)); 128 } 129 130 } // namespace 131