xref: /llvm-project/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (revision 01e40a8a3d40d7595d2cd95363c27d84b31e5cd2)
1 //===- TileTypeConversionTest.cpp - Tests ArmSME tile type conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
10 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
11 #include "mlir/Conversion/LLVMCommon/Pattern.h"
12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
14 
15 #include "gtest/gtest.h"
16 
17 using namespace mlir;
18 
19 class ArmSMETest : public ::testing::Test {
20 protected:
ArmSMETest()21   ArmSMETest() { context.getOrLoadDialect<mlir::arm_sme::ArmSMEDialect>(); }
22 
23   mlir::MLIRContext context;
24 };
25 
/// Checks that ArmSME tile type conversion is opt-in:
/// - a plain LLVMTypeConverter fails to convert an ArmSME tile type;
/// - a converter that has been passed to
///   populateArmSMEToLLVMConversionPatterns returns the tile type unchanged.
TEST_F(ArmSMETest, TestTileTypeConversion) {
  // Two converters built identically; only the second one is updated by
  // populateArmSMEToLLVMConversionPatterns below.
  LLVMTypeConverter llvmConverter(&context);
  LLVMTypeConverter llvmConverterWithArmSMEConversion(&context);

  // The pattern set is not inspected by this test. Populating it is what
  // distinguishes `llvmConverterWithArmSMEConversion` from `llvmConverter`.
  RewritePatternSet patterns(&context);
  populateArmSMEToLLVMConversionPatterns(llvmConverterWithArmSMEConversion,
                                         patterns);

  // vector<[4]x[4]xi32>: a 4x4 vector of i32 that is scalable in both
  // dimensions, i.e. the shape used for an ArmSME tile.
  Type i32 = IntegerType::get(&context, 32);
  auto smeTileType = VectorType::get({4, 4}, i32, {true, true});

  // An unmodified LLVMTypeConverter should fail to convert an ArmSME tile type.
  {
    SmallVector<Type> convertedType;
    ASSERT_TRUE(failed(llvmConverter.convertType(smeTileType, convertedType)));
  }

  // An updated LLVMTypeConverter should return the ArmSME tile vector type
  // unchanged, as the single result of the conversion.
  {
    SmallVector<Type> convertedType;
    ASSERT_TRUE(succeeded(llvmConverterWithArmSMEConversion.convertType(
        smeTileType, convertedType)));
    ASSERT_EQ(ArrayRef<Type>(convertedType), ArrayRef<Type>{smeTileType});
  }
}
52