//===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectInterface.h" #include "llvm/ADT/SmallVector.h" #include "gtest/gtest.h" #include using namespace mlir; using namespace mlir::detail; namespace { TEST(ShapedTypeTest, CloneMemref) { MLIRContext context; Type i32 = IntegerType::get(&context, 32); Type f32 = FloatType::getF32(&context); Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7); Type memrefOriginalType = i32; llvm::SmallVector memrefOriginalShape({10, 20}); AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context); ShapedType memrefType = (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType) .setMemorySpace(memSpace) .setLayout(AffineMapAttr::get(map)); // Update shape. llvm::SmallVector memrefNewShape({30, 40}); ASSERT_NE(memrefOriginalShape, memrefNewShape); ASSERT_EQ(memrefType.clone(memrefNewShape), (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType) .setMemorySpace(memSpace) .setLayout(AffineMapAttr::get(map))); // Update type. Type memrefNewType = f32; ASSERT_NE(memrefOriginalType, memrefNewType); ASSERT_EQ(memrefType.clone(memrefNewType), (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType) .setMemorySpace(memSpace) .setLayout(AffineMapAttr::get(map))); // Update both. ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType), (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) .setMemorySpace(memSpace) .setLayout(AffineMapAttr::get(map))); // Test unranked memref cloning. ShapedType unrankedTensorType = UnrankedMemRefType::get(memrefOriginalType, memSpace); ASSERT_EQ(unrankedTensorType.clone(memrefNewShape), (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType) .setMemorySpace(memSpace)); ASSERT_EQ(unrankedTensorType.clone(memrefNewType), UnrankedMemRefType::get(memrefNewType, memSpace)); ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType), (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) .setMemorySpace(memSpace)); } TEST(ShapedTypeTest, CloneTensor) { MLIRContext context; Type i32 = IntegerType::get(&context, 32); Type f32 = FloatType::getF32(&context); Type tensorOriginalType = i32; llvm::SmallVector tensorOriginalShape({10, 20}); // Test ranked tensor cloning. ShapedType tensorType = RankedTensorType::get(tensorOriginalShape, tensorOriginalType); // Update shape. llvm::SmallVector tensorNewShape({30, 40}); ASSERT_NE(tensorOriginalShape, tensorNewShape); ASSERT_EQ( tensorType.clone(tensorNewShape), (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); // Update type. Type tensorNewType = f32; ASSERT_NE(tensorOriginalType, tensorNewType); ASSERT_EQ( tensorType.clone(tensorNewType), (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType)); // Update both. ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType), (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType)); // Test unranked tensor cloning. ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType); ASSERT_EQ( unrankedTensorType.clone(tensorNewShape), (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); ASSERT_EQ(unrankedTensorType.clone(tensorNewType), (ShapedType)UnrankedTensorType::get(tensorNewType)); ASSERT_EQ( unrankedTensorType.clone(tensorNewShape), (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); } TEST(ShapedTypeTest, CloneVector) { MLIRContext context; Type i32 = IntegerType::get(&context, 32); Type f32 = FloatType::getF32(&context); Type vectorOriginalType = i32; llvm::SmallVector vectorOriginalShape({10, 20}); ShapedType vectorType = VectorType::get(vectorOriginalShape, vectorOriginalType); // Update shape. llvm::SmallVector vectorNewShape({30, 40}); ASSERT_NE(vectorOriginalShape, vectorNewShape); ASSERT_EQ(vectorType.clone(vectorNewShape), VectorType::get(vectorNewShape, vectorOriginalType)); // Update type. Type vectorNewType = f32; ASSERT_NE(vectorOriginalType, vectorNewType); ASSERT_EQ(vectorType.clone(vectorNewType), VectorType::get(vectorOriginalShape, vectorNewType)); // Update both. ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType), VectorType::get(vectorNewShape, vectorNewType)); } } // namespace