1 //===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AffineMap.h" 10 #include "mlir/IR/BuiltinAttributes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/DialectInterface.h" 14 #include "llvm/ADT/SmallVector.h" 15 #include "gtest/gtest.h" 16 #include <cstdint> 17 18 using namespace mlir; 19 using namespace mlir::detail; 20 21 namespace { 22 TEST(ShapedTypeTest, CloneMemref) { 23 MLIRContext context; 24 25 Type i32 = IntegerType::get(&context, 32); 26 Type f32 = FloatType::getF32(&context); 27 Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7); 28 Type memrefOriginalType = i32; 29 llvm::SmallVector<int64_t> memrefOriginalShape({10, 20}); 30 AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context); 31 32 ShapedType memrefType = 33 (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType) 34 .setMemorySpace(memSpace) 35 .setLayout(AffineMapAttr::get(map)); 36 // Update shape. 37 llvm::SmallVector<int64_t> memrefNewShape({30, 40}); 38 ASSERT_NE(memrefOriginalShape, memrefNewShape); 39 ASSERT_EQ(memrefType.clone(memrefNewShape), 40 (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType) 41 .setMemorySpace(memSpace) 42 .setLayout(AffineMapAttr::get(map))); 43 // Update type. 44 Type memrefNewType = f32; 45 ASSERT_NE(memrefOriginalType, memrefNewType); 46 ASSERT_EQ(memrefType.clone(memrefNewType), 47 (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType) 48 .setMemorySpace(memSpace) 49 .setLayout(AffineMapAttr::get(map))); 50 // Update both. 51 ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType), 52 (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) 53 .setMemorySpace(memSpace) 54 .setLayout(AffineMapAttr::get(map))); 55 56 // Test unranked memref cloning. 57 ShapedType unrankedTensorType = 58 UnrankedMemRefType::get(memrefOriginalType, memSpace); 59 ASSERT_EQ(unrankedTensorType.clone(memrefNewShape), 60 (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType) 61 .setMemorySpace(memSpace)); 62 ASSERT_EQ(unrankedTensorType.clone(memrefNewType), 63 UnrankedMemRefType::get(memrefNewType, memSpace)); 64 ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType), 65 (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) 66 .setMemorySpace(memSpace)); 67 } 68 69 TEST(ShapedTypeTest, CloneTensor) { 70 MLIRContext context; 71 72 Type i32 = IntegerType::get(&context, 32); 73 Type f32 = FloatType::getF32(&context); 74 75 Type tensorOriginalType = i32; 76 llvm::SmallVector<int64_t> tensorOriginalShape({10, 20}); 77 78 // Test ranked tensor cloning. 79 ShapedType tensorType = 80 RankedTensorType::get(tensorOriginalShape, tensorOriginalType); 81 // Update shape. 82 llvm::SmallVector<int64_t> tensorNewShape({30, 40}); 83 ASSERT_NE(tensorOriginalShape, tensorNewShape); 84 ASSERT_EQ( 85 tensorType.clone(tensorNewShape), 86 (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); 87 // Update type. 88 Type tensorNewType = f32; 89 ASSERT_NE(tensorOriginalType, tensorNewType); 90 ASSERT_EQ( 91 tensorType.clone(tensorNewType), 92 (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType)); 93 // Update both. 94 ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType), 95 (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType)); 96 97 // Test unranked tensor cloning. 98 ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType); 99 ASSERT_EQ( 100 unrankedTensorType.clone(tensorNewShape), 101 (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); 102 ASSERT_EQ(unrankedTensorType.clone(tensorNewType), 103 (ShapedType)UnrankedTensorType::get(tensorNewType)); 104 ASSERT_EQ( 105 unrankedTensorType.clone(tensorNewShape), 106 (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); 107 } 108 109 TEST(ShapedTypeTest, CloneVector) { 110 MLIRContext context; 111 112 Type i32 = IntegerType::get(&context, 32); 113 Type f32 = FloatType::getF32(&context); 114 115 Type vectorOriginalType = i32; 116 llvm::SmallVector<int64_t> vectorOriginalShape({10, 20}); 117 ShapedType vectorType = 118 VectorType::get(vectorOriginalShape, vectorOriginalType); 119 // Update shape. 120 llvm::SmallVector<int64_t> vectorNewShape({30, 40}); 121 ASSERT_NE(vectorOriginalShape, vectorNewShape); 122 ASSERT_EQ(vectorType.clone(vectorNewShape), 123 VectorType::get(vectorNewShape, vectorOriginalType)); 124 // Update type. 125 Type vectorNewType = f32; 126 ASSERT_NE(vectorOriginalType, vectorNewType); 127 ASSERT_EQ(vectorType.clone(vectorNewType), 128 VectorType::get(vectorOriginalShape, vectorNewType)); 129 // Update both. 130 ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType), 131 VectorType::get(vectorNewShape, vectorNewType)); 132 } 133 134 } // namespace 135