1381a65faSJacques Pienaar //===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===// 2381a65faSJacques Pienaar // 3381a65faSJacques Pienaar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4381a65faSJacques Pienaar // See https://llvm.org/LICENSE.txt for license information. 5381a65faSJacques Pienaar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6381a65faSJacques Pienaar // 7381a65faSJacques Pienaar //===----------------------------------------------------------------------===// 8381a65faSJacques Pienaar 9381a65faSJacques Pienaar #include "mlir/IR/AffineMap.h" 10f3bf5c05SVladislav Vinogradov #include "mlir/IR/BuiltinAttributes.h" 11381a65faSJacques Pienaar #include "mlir/IR/BuiltinTypes.h" 12381a65faSJacques Pienaar #include "mlir/IR/Dialect.h" 13381a65faSJacques Pienaar #include "mlir/IR/DialectInterface.h" 14d2f42c73SJacques Pienaar #include "mlir/Support/LLVM.h" 15381a65faSJacques Pienaar #include "llvm/ADT/SmallVector.h" 16381a65faSJacques Pienaar #include "gtest/gtest.h" 17381a65faSJacques Pienaar #include <cstdint> 18381a65faSJacques Pienaar 19381a65faSJacques Pienaar using namespace mlir; 20381a65faSJacques Pienaar using namespace mlir::detail; 21381a65faSJacques Pienaar 22381a65faSJacques Pienaar namespace { 23381a65faSJacques Pienaar TEST(ShapedTypeTest, CloneMemref) { 24381a65faSJacques Pienaar MLIRContext context; 25381a65faSJacques Pienaar 26381a65faSJacques Pienaar Type i32 = IntegerType::get(&context, 32); 27*f023da12SMatthias Springer Type f32 = Float32Type::get(&context); 28f3bf5c05SVladislav Vinogradov Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7); 29381a65faSJacques Pienaar Type memrefOriginalType = i32; 30381a65faSJacques Pienaar llvm::SmallVector<int64_t> memrefOriginalShape({10, 20}); 31381a65faSJacques Pienaar AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context); 32381a65faSJacques Pienaar 33381a65faSJacques Pienaar ShapedType memrefType = 34676bfb2aSRiver Riddle (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType) 35381a65faSJacques Pienaar .setMemorySpace(memSpace) 36e41ebbecSVladislav Vinogradov .setLayout(AffineMapAttr::get(map)); 37381a65faSJacques Pienaar // Update shape. 38381a65faSJacques Pienaar llvm::SmallVector<int64_t> memrefNewShape({30, 40}); 39381a65faSJacques Pienaar ASSERT_NE(memrefOriginalShape, memrefNewShape); 40381a65faSJacques Pienaar ASSERT_EQ(memrefType.clone(memrefNewShape), 41676bfb2aSRiver Riddle (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType) 42381a65faSJacques Pienaar .setMemorySpace(memSpace) 43e41ebbecSVladislav Vinogradov .setLayout(AffineMapAttr::get(map))); 44381a65faSJacques Pienaar // Update type. 45381a65faSJacques Pienaar Type memrefNewType = f32; 46381a65faSJacques Pienaar ASSERT_NE(memrefOriginalType, memrefNewType); 47381a65faSJacques Pienaar ASSERT_EQ(memrefType.clone(memrefNewType), 48381a65faSJacques Pienaar (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType) 49381a65faSJacques Pienaar .setMemorySpace(memSpace) 50e41ebbecSVladislav Vinogradov .setLayout(AffineMapAttr::get(map))); 51381a65faSJacques Pienaar // Update both. 52381a65faSJacques Pienaar ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType), 53381a65faSJacques Pienaar (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) 54381a65faSJacques Pienaar .setMemorySpace(memSpace) 55e41ebbecSVladislav Vinogradov .setLayout(AffineMapAttr::get(map))); 56381a65faSJacques Pienaar 57381a65faSJacques Pienaar // Test unranked memref cloning. 58381a65faSJacques Pienaar ShapedType unrankedTensorType = 59381a65faSJacques Pienaar UnrankedMemRefType::get(memrefOriginalType, memSpace); 60381a65faSJacques Pienaar ASSERT_EQ(unrankedTensorType.clone(memrefNewShape), 61381a65faSJacques Pienaar (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType) 62381a65faSJacques Pienaar .setMemorySpace(memSpace)); 63381a65faSJacques Pienaar ASSERT_EQ(unrankedTensorType.clone(memrefNewType), 64381a65faSJacques Pienaar UnrankedMemRefType::get(memrefNewType, memSpace)); 65381a65faSJacques Pienaar ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType), 66381a65faSJacques Pienaar (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) 67381a65faSJacques Pienaar .setMemorySpace(memSpace)); 68381a65faSJacques Pienaar } 69381a65faSJacques Pienaar 70381a65faSJacques Pienaar TEST(ShapedTypeTest, CloneTensor) { 71381a65faSJacques Pienaar MLIRContext context; 72381a65faSJacques Pienaar 73381a65faSJacques Pienaar Type i32 = IntegerType::get(&context, 32); 74*f023da12SMatthias Springer Type f32 = Float32Type::get(&context); 75381a65faSJacques Pienaar 76381a65faSJacques Pienaar Type tensorOriginalType = i32; 77381a65faSJacques Pienaar llvm::SmallVector<int64_t> tensorOriginalShape({10, 20}); 78381a65faSJacques Pienaar 79381a65faSJacques Pienaar // Test ranked tensor cloning. 80381a65faSJacques Pienaar ShapedType tensorType = 81381a65faSJacques Pienaar RankedTensorType::get(tensorOriginalShape, tensorOriginalType); 82381a65faSJacques Pienaar // Update shape. 83381a65faSJacques Pienaar llvm::SmallVector<int64_t> tensorNewShape({30, 40}); 84381a65faSJacques Pienaar ASSERT_NE(tensorOriginalShape, tensorNewShape); 85676bfb2aSRiver Riddle ASSERT_EQ( 86676bfb2aSRiver Riddle tensorType.clone(tensorNewShape), 87676bfb2aSRiver Riddle (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); 88381a65faSJacques Pienaar // Update type. 89381a65faSJacques Pienaar Type tensorNewType = f32; 90381a65faSJacques Pienaar ASSERT_NE(tensorOriginalType, tensorNewType); 91676bfb2aSRiver Riddle ASSERT_EQ( 92676bfb2aSRiver Riddle tensorType.clone(tensorNewType), 93676bfb2aSRiver Riddle (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType)); 94381a65faSJacques Pienaar // Update both. 95381a65faSJacques Pienaar ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType), 96676bfb2aSRiver Riddle (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType)); 97381a65faSJacques Pienaar 98381a65faSJacques Pienaar // Test unranked tensor cloning. 99381a65faSJacques Pienaar ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType); 100676bfb2aSRiver Riddle ASSERT_EQ( 101676bfb2aSRiver Riddle unrankedTensorType.clone(tensorNewShape), 102676bfb2aSRiver Riddle (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); 103381a65faSJacques Pienaar ASSERT_EQ(unrankedTensorType.clone(tensorNewType), 104676bfb2aSRiver Riddle (ShapedType)UnrankedTensorType::get(tensorNewType)); 105676bfb2aSRiver Riddle ASSERT_EQ( 106676bfb2aSRiver Riddle unrankedTensorType.clone(tensorNewShape), 107676bfb2aSRiver Riddle (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); 108381a65faSJacques Pienaar } 109381a65faSJacques Pienaar 110381a65faSJacques Pienaar TEST(ShapedTypeTest, CloneVector) { 111381a65faSJacques Pienaar MLIRContext context; 112381a65faSJacques Pienaar 113381a65faSJacques Pienaar Type i32 = IntegerType::get(&context, 32); 114*f023da12SMatthias Springer Type f32 = Float32Type::get(&context); 115381a65faSJacques Pienaar 116381a65faSJacques Pienaar Type vectorOriginalType = i32; 117381a65faSJacques Pienaar llvm::SmallVector<int64_t> vectorOriginalShape({10, 20}); 118381a65faSJacques Pienaar ShapedType vectorType = 119381a65faSJacques Pienaar VectorType::get(vectorOriginalShape, vectorOriginalType); 120381a65faSJacques Pienaar // Update shape. 121381a65faSJacques Pienaar llvm::SmallVector<int64_t> vectorNewShape({30, 40}); 122381a65faSJacques Pienaar ASSERT_NE(vectorOriginalShape, vectorNewShape); 123381a65faSJacques Pienaar ASSERT_EQ(vectorType.clone(vectorNewShape), 124381a65faSJacques Pienaar VectorType::get(vectorNewShape, vectorOriginalType)); 125381a65faSJacques Pienaar // Update type. 126381a65faSJacques Pienaar Type vectorNewType = f32; 127381a65faSJacques Pienaar ASSERT_NE(vectorOriginalType, vectorNewType); 128381a65faSJacques Pienaar ASSERT_EQ(vectorType.clone(vectorNewType), 129381a65faSJacques Pienaar VectorType::get(vectorOriginalShape, vectorNewType)); 130381a65faSJacques Pienaar // Update both. 131381a65faSJacques Pienaar ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType), 132381a65faSJacques Pienaar VectorType::get(vectorNewShape, vectorNewType)); 133381a65faSJacques Pienaar } 134381a65faSJacques Pienaar 135b0b8e83eSBenjamin Maxwell TEST(ShapedTypeTest, VectorTypeBuilder) { 136b0b8e83eSBenjamin Maxwell MLIRContext context; 137*f023da12SMatthias Springer Type f32 = Float32Type::get(&context); 138b0b8e83eSBenjamin Maxwell 139b0b8e83eSBenjamin Maxwell SmallVector<int64_t> shape{2, 4, 8, 9, 1}; 140b0b8e83eSBenjamin Maxwell SmallVector<bool> scalableDims{true, false, true, false, false}; 141b0b8e83eSBenjamin Maxwell VectorType vectorType = VectorType::get(shape, f32, scalableDims); 142b0b8e83eSBenjamin Maxwell 143b0b8e83eSBenjamin Maxwell { 144b0b8e83eSBenjamin Maxwell // Drop some dims. 145b0b8e83eSBenjamin Maxwell VectorType dropFrontTwoDims = 146b0b8e83eSBenjamin Maxwell VectorType::Builder(vectorType).dropDim(0).dropDim(0); 147b0b8e83eSBenjamin Maxwell ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType()); 148b0b8e83eSBenjamin Maxwell ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape()); 149b0b8e83eSBenjamin Maxwell ASSERT_EQ(vectorType.getScalableDims().drop_front(2), 150b0b8e83eSBenjamin Maxwell dropFrontTwoDims.getScalableDims()); 151b0b8e83eSBenjamin Maxwell } 152b0b8e83eSBenjamin Maxwell 153b0b8e83eSBenjamin Maxwell { 154b0b8e83eSBenjamin Maxwell // Set some dims. 155b0b8e83eSBenjamin Maxwell VectorType setTwoDims = 156b0b8e83eSBenjamin Maxwell VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12); 157b0b8e83eSBenjamin Maxwell ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1})); 158b0b8e83eSBenjamin Maxwell ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType()); 159b0b8e83eSBenjamin Maxwell ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims()); 160b0b8e83eSBenjamin Maxwell } 161b0b8e83eSBenjamin Maxwell 162b0b8e83eSBenjamin Maxwell { 163b0b8e83eSBenjamin Maxwell // Test for bug from: 164b0b8e83eSBenjamin Maxwell // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a 165b0b8e83eSBenjamin Maxwell // Constructs a temporary builder, modifies it, copies it to `builder`. 166b0b8e83eSBenjamin Maxwell // This used to lead to a use-after-free. Running under sanitizers will 167b0b8e83eSBenjamin Maxwell // catch any issues. 168b0b8e83eSBenjamin Maxwell VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16); 169b0b8e83eSBenjamin Maxwell VectorType newVectorType = VectorType(builder); 170b0b8e83eSBenjamin Maxwell ASSERT_EQ(newVectorType.getDimSize(0), 16); 171b0b8e83eSBenjamin Maxwell } 172b0b8e83eSBenjamin Maxwell 173b0b8e83eSBenjamin Maxwell { 174b0b8e83eSBenjamin Maxwell // Make builder from scratch (without scalable dims) -- this use to lead to 175b0b8e83eSBenjamin Maxwell // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969. 176b0b8e83eSBenjamin Maxwell // Running under sanitizers will catch any issues. 177b0b8e83eSBenjamin Maxwell SmallVector<int64_t> shape{1, 2, 3, 4}; 178b0b8e83eSBenjamin Maxwell VectorType::Builder builder(shape, f32); 179b0b8e83eSBenjamin Maxwell ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape)); 180b0b8e83eSBenjamin Maxwell } 181b0b8e83eSBenjamin Maxwell 182b0b8e83eSBenjamin Maxwell { 183b0b8e83eSBenjamin Maxwell // Set vector shape (without scalable dims) -- this use to lead to 184b0b8e83eSBenjamin Maxwell // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969. 185b0b8e83eSBenjamin Maxwell // Running under sanitizers will catch any issues. 186b0b8e83eSBenjamin Maxwell VectorType::Builder builder(vectorType); 187b0b8e83eSBenjamin Maxwell SmallVector<int64_t> newShape{2, 2}; 188b0b8e83eSBenjamin Maxwell builder.setShape(newShape); 189b0b8e83eSBenjamin Maxwell ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape)); 190b0b8e83eSBenjamin Maxwell } 191b0b8e83eSBenjamin Maxwell } 192b0b8e83eSBenjamin Maxwell 193b0b8e83eSBenjamin Maxwell TEST(ShapedTypeTest, RankedTensorTypeBuilder) { 194b0b8e83eSBenjamin Maxwell MLIRContext context; 195*f023da12SMatthias Springer Type f32 = Float32Type::get(&context); 196b0b8e83eSBenjamin Maxwell 197b0b8e83eSBenjamin Maxwell SmallVector<int64_t> shape{2, 4, 8, 16, 32}; 198b0b8e83eSBenjamin Maxwell RankedTensorType tensorType = RankedTensorType::get(shape, f32); 199b0b8e83eSBenjamin Maxwell 200b0b8e83eSBenjamin Maxwell { 201b0b8e83eSBenjamin Maxwell // Drop some dims. 202b0b8e83eSBenjamin Maxwell RankedTensorType dropFrontTwoDims = 203b0b8e83eSBenjamin Maxwell RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0); 204b0b8e83eSBenjamin Maxwell ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType()); 205b0b8e83eSBenjamin Maxwell ASSERT_EQ(dropFrontTwoDims.getShape(), ArrayRef<int64_t>({16, 32})); 206b0b8e83eSBenjamin Maxwell } 207b0b8e83eSBenjamin Maxwell 208b0b8e83eSBenjamin Maxwell { 209b0b8e83eSBenjamin Maxwell // Insert some dims. 210b0b8e83eSBenjamin Maxwell RankedTensorType insertTwoDims = 211b0b8e83eSBenjamin Maxwell RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3); 212b0b8e83eSBenjamin Maxwell ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType()); 213b0b8e83eSBenjamin Maxwell ASSERT_EQ(insertTwoDims.getShape(), 214b0b8e83eSBenjamin Maxwell ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32})); 215b0b8e83eSBenjamin Maxwell } 216b0b8e83eSBenjamin Maxwell 217b0b8e83eSBenjamin Maxwell { 218b0b8e83eSBenjamin Maxwell // Test for bug from: 219b0b8e83eSBenjamin Maxwell // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a 220b0b8e83eSBenjamin Maxwell // Constructs a temporary builder, modifies it, copies it to `builder`. 221b0b8e83eSBenjamin Maxwell // This used to lead to a use-after-free. Running under sanitizers will 222b0b8e83eSBenjamin Maxwell // catch any issues. 223b0b8e83eSBenjamin Maxwell RankedTensorType::Builder builder = 224b0b8e83eSBenjamin Maxwell RankedTensorType::Builder(tensorType).dropDim(0); 225b0b8e83eSBenjamin Maxwell RankedTensorType newTensorType = RankedTensorType(builder); 226b0b8e83eSBenjamin Maxwell ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape()); 227b0b8e83eSBenjamin Maxwell } 228b0b8e83eSBenjamin Maxwell } 229b0b8e83eSBenjamin Maxwell 230d2f42c73SJacques Pienaar /// Simple wrapper class to enable "isa querying" and simple accessing of 231d2f42c73SJacques Pienaar /// encoding. 232d2f42c73SJacques Pienaar class TensorWithString : public RankedTensorType { 233d2f42c73SJacques Pienaar public: 234d2f42c73SJacques Pienaar using RankedTensorType::RankedTensorType; 235d2f42c73SJacques Pienaar 236d2f42c73SJacques Pienaar static TensorWithString get(ArrayRef<int64_t> shape, Type elementType, 237d2f42c73SJacques Pienaar StringRef name) { 238d2f42c73SJacques Pienaar return mlir::cast<TensorWithString>(RankedTensorType::get( 239d2f42c73SJacques Pienaar shape, elementType, StringAttr::get(elementType.getContext(), name))); 240d2f42c73SJacques Pienaar } 241d2f42c73SJacques Pienaar 242d2f42c73SJacques Pienaar StringRef getName() const { 243d2f42c73SJacques Pienaar if (Attribute enc = getEncoding()) 244d2f42c73SJacques Pienaar return mlir::cast<StringAttr>(enc).getValue(); 245d2f42c73SJacques Pienaar return {}; 246d2f42c73SJacques Pienaar } 247d2f42c73SJacques Pienaar 248d2f42c73SJacques Pienaar static bool classof(Type type) { 249d2f42c73SJacques Pienaar if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type)) 250d2f42c73SJacques Pienaar return mlir::isa_and_present<StringAttr>(rt.getEncoding()); 251d2f42c73SJacques Pienaar return false; 252d2f42c73SJacques Pienaar } 253d2f42c73SJacques Pienaar }; 254d2f42c73SJacques Pienaar 255d2f42c73SJacques Pienaar TEST(ShapedTypeTest, RankedTensorTypeView) { 256d2f42c73SJacques Pienaar MLIRContext context; 257*f023da12SMatthias Springer Type f32 = Float32Type::get(&context); 258d2f42c73SJacques Pienaar 259d2f42c73SJacques Pienaar Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32); 260d2f42c73SJacques Pienaar 261d2f42c73SJacques Pienaar UnitAttr unitAttr = UnitAttr::get(&context); 262d2f42c73SJacques Pienaar Type unitEncodingRankedTensorType = 263d2f42c73SJacques Pienaar RankedTensorType::get({10, 20}, f32, unitAttr); 264d2f42c73SJacques Pienaar 265d2f42c73SJacques Pienaar StringAttr stringAttr = StringAttr::get(&context, "app"); 266d2f42c73SJacques Pienaar Type stringEncodingRankedTensorType = 267d2f42c73SJacques Pienaar RankedTensorType::get({10, 20}, f32, stringAttr); 268d2f42c73SJacques Pienaar 269d2f42c73SJacques Pienaar EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType)); 270d2f42c73SJacques Pienaar EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType)); 271d2f42c73SJacques Pienaar ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType)); 272d2f42c73SJacques Pienaar 273d2f42c73SJacques Pienaar // Cast to TensorWithString view. 274d2f42c73SJacques Pienaar auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType); 275d2f42c73SJacques Pienaar ASSERT_TRUE(mlir::isa<TensorWithString>(view)); 276d2f42c73SJacques Pienaar EXPECT_EQ(view.getName(), "app"); 277d2f42c73SJacques Pienaar // Verify one could cast view type back to base type. 278d2f42c73SJacques Pienaar ASSERT_TRUE(mlir::isa<RankedTensorType>(view)); 279d2f42c73SJacques Pienaar 280d2f42c73SJacques Pienaar Type viewCreated = TensorWithString::get({10, 20}, f32, "bob"); 281d2f42c73SJacques Pienaar ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated)); 282d2f42c73SJacques Pienaar ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated)); 283d2f42c73SJacques Pienaar view = mlir::cast<TensorWithString>(viewCreated); 284d2f42c73SJacques Pienaar EXPECT_EQ(view.getName(), "bob"); 285d2f42c73SJacques Pienaar } 286d2f42c73SJacques Pienaar 287be0a7e9fSMehdi Amini } // namespace 288