1 //===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AffineMap.h" 10 #include "mlir/IR/BuiltinAttributes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/DialectInterface.h" 14 #include "mlir/Support/LLVM.h" 15 #include "llvm/ADT/SmallVector.h" 16 #include "gtest/gtest.h" 17 #include <cstdint> 18 19 using namespace mlir; 20 using namespace mlir::detail; 21 22 namespace { 23 TEST(ShapedTypeTest, CloneMemref) { 24 MLIRContext context; 25 26 Type i32 = IntegerType::get(&context, 32); 27 Type f32 = Float32Type::get(&context); 28 Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7); 29 Type memrefOriginalType = i32; 30 llvm::SmallVector<int64_t> memrefOriginalShape({10, 20}); 31 AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context); 32 33 ShapedType memrefType = 34 (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType) 35 .setMemorySpace(memSpace) 36 .setLayout(AffineMapAttr::get(map)); 37 // Update shape. 38 llvm::SmallVector<int64_t> memrefNewShape({30, 40}); 39 ASSERT_NE(memrefOriginalShape, memrefNewShape); 40 ASSERT_EQ(memrefType.clone(memrefNewShape), 41 (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType) 42 .setMemorySpace(memSpace) 43 .setLayout(AffineMapAttr::get(map))); 44 // Update type. 45 Type memrefNewType = f32; 46 ASSERT_NE(memrefOriginalType, memrefNewType); 47 ASSERT_EQ(memrefType.clone(memrefNewType), 48 (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType) 49 .setMemorySpace(memSpace) 50 .setLayout(AffineMapAttr::get(map))); 51 // Update both. 52 ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType), 53 (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) 54 .setMemorySpace(memSpace) 55 .setLayout(AffineMapAttr::get(map))); 56 57 // Test unranked memref cloning. 58 ShapedType unrankedTensorType = 59 UnrankedMemRefType::get(memrefOriginalType, memSpace); 60 ASSERT_EQ(unrankedTensorType.clone(memrefNewShape), 61 (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType) 62 .setMemorySpace(memSpace)); 63 ASSERT_EQ(unrankedTensorType.clone(memrefNewType), 64 UnrankedMemRefType::get(memrefNewType, memSpace)); 65 ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType), 66 (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) 67 .setMemorySpace(memSpace)); 68 } 69 70 TEST(ShapedTypeTest, CloneTensor) { 71 MLIRContext context; 72 73 Type i32 = IntegerType::get(&context, 32); 74 Type f32 = Float32Type::get(&context); 75 76 Type tensorOriginalType = i32; 77 llvm::SmallVector<int64_t> tensorOriginalShape({10, 20}); 78 79 // Test ranked tensor cloning. 80 ShapedType tensorType = 81 RankedTensorType::get(tensorOriginalShape, tensorOriginalType); 82 // Update shape. 83 llvm::SmallVector<int64_t> tensorNewShape({30, 40}); 84 ASSERT_NE(tensorOriginalShape, tensorNewShape); 85 ASSERT_EQ( 86 tensorType.clone(tensorNewShape), 87 (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); 88 // Update type. 89 Type tensorNewType = f32; 90 ASSERT_NE(tensorOriginalType, tensorNewType); 91 ASSERT_EQ( 92 tensorType.clone(tensorNewType), 93 (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType)); 94 // Update both. 95 ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType), 96 (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType)); 97 98 // Test unranked tensor cloning. 99 ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType); 100 ASSERT_EQ( 101 unrankedTensorType.clone(tensorNewShape), 102 (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); 103 ASSERT_EQ(unrankedTensorType.clone(tensorNewType), 104 (ShapedType)UnrankedTensorType::get(tensorNewType)); 105 ASSERT_EQ( 106 unrankedTensorType.clone(tensorNewShape), 107 (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); 108 } 109 110 TEST(ShapedTypeTest, CloneVector) { 111 MLIRContext context; 112 113 Type i32 = IntegerType::get(&context, 32); 114 Type f32 = Float32Type::get(&context); 115 116 Type vectorOriginalType = i32; 117 llvm::SmallVector<int64_t> vectorOriginalShape({10, 20}); 118 ShapedType vectorType = 119 VectorType::get(vectorOriginalShape, vectorOriginalType); 120 // Update shape. 121 llvm::SmallVector<int64_t> vectorNewShape({30, 40}); 122 ASSERT_NE(vectorOriginalShape, vectorNewShape); 123 ASSERT_EQ(vectorType.clone(vectorNewShape), 124 VectorType::get(vectorNewShape, vectorOriginalType)); 125 // Update type. 126 Type vectorNewType = f32; 127 ASSERT_NE(vectorOriginalType, vectorNewType); 128 ASSERT_EQ(vectorType.clone(vectorNewType), 129 VectorType::get(vectorOriginalShape, vectorNewType)); 130 // Update both. 131 ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType), 132 VectorType::get(vectorNewShape, vectorNewType)); 133 } 134 135 TEST(ShapedTypeTest, VectorTypeBuilder) { 136 MLIRContext context; 137 Type f32 = Float32Type::get(&context); 138 139 SmallVector<int64_t> shape{2, 4, 8, 9, 1}; 140 SmallVector<bool> scalableDims{true, false, true, false, false}; 141 VectorType vectorType = VectorType::get(shape, f32, scalableDims); 142 143 { 144 // Drop some dims. 145 VectorType dropFrontTwoDims = 146 VectorType::Builder(vectorType).dropDim(0).dropDim(0); 147 ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType()); 148 ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape()); 149 ASSERT_EQ(vectorType.getScalableDims().drop_front(2), 150 dropFrontTwoDims.getScalableDims()); 151 } 152 153 { 154 // Set some dims. 155 VectorType setTwoDims = 156 VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12); 157 ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1})); 158 ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType()); 159 ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims()); 160 } 161 162 { 163 // Test for bug from: 164 // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a 165 // Constructs a temporary builder, modifies it, copies it to `builder`. 166 // This used to lead to a use-after-free. Running under sanitizers will 167 // catch any issues. 168 VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16); 169 VectorType newVectorType = VectorType(builder); 170 ASSERT_EQ(newVectorType.getDimSize(0), 16); 171 } 172 173 { 174 // Make builder from scratch (without scalable dims) -- this use to lead to 175 // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969. 176 // Running under sanitizers will catch any issues. 177 SmallVector<int64_t> shape{1, 2, 3, 4}; 178 VectorType::Builder builder(shape, f32); 179 ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape)); 180 } 181 182 { 183 // Set vector shape (without scalable dims) -- this use to lead to 184 // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969. 185 // Running under sanitizers will catch any issues. 186 VectorType::Builder builder(vectorType); 187 SmallVector<int64_t> newShape{2, 2}; 188 builder.setShape(newShape); 189 ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape)); 190 } 191 } 192 193 TEST(ShapedTypeTest, RankedTensorTypeBuilder) { 194 MLIRContext context; 195 Type f32 = Float32Type::get(&context); 196 197 SmallVector<int64_t> shape{2, 4, 8, 16, 32}; 198 RankedTensorType tensorType = RankedTensorType::get(shape, f32); 199 200 { 201 // Drop some dims. 202 RankedTensorType dropFrontTwoDims = 203 RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0); 204 ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType()); 205 ASSERT_EQ(dropFrontTwoDims.getShape(), ArrayRef<int64_t>({16, 32})); 206 } 207 208 { 209 // Insert some dims. 210 RankedTensorType insertTwoDims = 211 RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3); 212 ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType()); 213 ASSERT_EQ(insertTwoDims.getShape(), 214 ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32})); 215 } 216 217 { 218 // Test for bug from: 219 // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a 220 // Constructs a temporary builder, modifies it, copies it to `builder`. 221 // This used to lead to a use-after-free. Running under sanitizers will 222 // catch any issues. 223 RankedTensorType::Builder builder = 224 RankedTensorType::Builder(tensorType).dropDim(0); 225 RankedTensorType newTensorType = RankedTensorType(builder); 226 ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape()); 227 } 228 } 229 230 /// Simple wrapper class to enable "isa querying" and simple accessing of 231 /// encoding. 232 class TensorWithString : public RankedTensorType { 233 public: 234 using RankedTensorType::RankedTensorType; 235 236 static TensorWithString get(ArrayRef<int64_t> shape, Type elementType, 237 StringRef name) { 238 return mlir::cast<TensorWithString>(RankedTensorType::get( 239 shape, elementType, StringAttr::get(elementType.getContext(), name))); 240 } 241 242 StringRef getName() const { 243 if (Attribute enc = getEncoding()) 244 return mlir::cast<StringAttr>(enc).getValue(); 245 return {}; 246 } 247 248 static bool classof(Type type) { 249 if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type)) 250 return mlir::isa_and_present<StringAttr>(rt.getEncoding()); 251 return false; 252 } 253 }; 254 255 TEST(ShapedTypeTest, RankedTensorTypeView) { 256 MLIRContext context; 257 Type f32 = Float32Type::get(&context); 258 259 Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32); 260 261 UnitAttr unitAttr = UnitAttr::get(&context); 262 Type unitEncodingRankedTensorType = 263 RankedTensorType::get({10, 20}, f32, unitAttr); 264 265 StringAttr stringAttr = StringAttr::get(&context, "app"); 266 Type stringEncodingRankedTensorType = 267 RankedTensorType::get({10, 20}, f32, stringAttr); 268 269 EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType)); 270 EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType)); 271 ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType)); 272 273 // Cast to TensorWithString view. 274 auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType); 275 ASSERT_TRUE(mlir::isa<TensorWithString>(view)); 276 EXPECT_EQ(view.getName(), "app"); 277 // Verify one could cast view type back to base type. 278 ASSERT_TRUE(mlir::isa<RankedTensorType>(view)); 279 280 Type viewCreated = TensorWithString::get({10, 20}, f32, "bob"); 281 ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated)); 282 ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated)); 283 view = mlir::cast<TensorWithString>(viewCreated); 284 EXPECT_EQ(view.getName(), "bob"); 285 } 286 287 } // namespace 288