xref: /llvm-project/mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp (revision 9744d396f8823eaeb15c856f2fa45d8dd4f1d60b)
//===- IndexOpsFoldersTest.cpp - unit tests for index op folders ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Index/IR/IndexDialect.h"
10 #include "mlir/Dialect/Index/IR/IndexOps.h"
11 #include "mlir/IR/OwningOpRef.h"
12 #include "gtest/gtest.h"
13 
14 using namespace mlir;
15 
16 namespace {
17 /// Test fixture for testing operation folders.
18 class IndexFolderTest : public testing::Test {
19 public:
IndexFolderTest()20   IndexFolderTest() { ctx.getOrLoadDialect<index::IndexDialect>(); }
21 
22   /// Instantiate an operation, invoke its folder, and return the attribute
23   /// result.
24   template <typename OpT>
25   void foldOp(IntegerAttr &value, Type type, ArrayRef<Attribute> operands);
26 
27 protected:
28   /// The MLIR context to use.
29   MLIRContext ctx;
30   /// A builder to use.
31   OpBuilder b{&ctx};
32 };
33 } // namespace
34 
35 template <typename OpT>
foldOp(IntegerAttr & value,Type type,ArrayRef<Attribute> operands)36 void IndexFolderTest::foldOp(IntegerAttr &value, Type type,
37                              ArrayRef<Attribute> operands) {
38   // This function returns null so that `ASSERT_*` works within it.
39   OperationState state(UnknownLoc::get(&ctx), OpT::getOperationName());
40   state.addTypes(type);
41   OwningOpRef<OpT> op = cast<OpT>(b.create(state));
42   SmallVector<OpFoldResult> results;
43   LogicalResult result = op->getOperation()->fold(operands, results);
44   // Propagate the failure to the test.
45   if (failed(result)) {
46     value = nullptr;
47     return;
48   }
49   ASSERT_EQ(results.size(), 1u);
50   value = dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(results.front()));
51   ASSERT_TRUE(value);
52 }
53 
TEST_F(IndexFolderTest, TestCastUOpFolder) {
  IntegerAttr folded;
  // Run the `CastUOp` folder on the constant `input` with result type `type`;
  // `folded` receives the folded attribute, or null when folding fails.
  auto foldCastU = [&](Type type, Attribute input) {
    foldOp<index::CastUOp>(folded, type, input);
  };

  // Narrowing to at most 32 bits: 8000000000 truncated to 16 bits is 0x5000.
  foldCastU(b.getIntegerType(16), b.getIndexAttr(8000000000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), 20480u);

  // Widening a small value to 64 bits folds without loss.
  foldCastU(b.getIntegerType(64), b.getIndexAttr(2000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), 2000u);

  // No fold: truncating 8000000000 to 32 bits and zero-extending it back
  // yields a different value than the full 64-bit interpretation.
  foldCastU(b.getIntegerType(64), b.getIndexAttr(8000000000));
  EXPECT_FALSE(folded);

  // Result width strictly between 32 and 64 bits. Bit 52 is dropped by the
  // truncation to 40 bits, so every interpretation agrees on 0x10000.
  foldCastU(b.getIntegerType(40), b.getIndexAttr(0x10000000010000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), 65536);

  // No fold: a 60-bit result keeps bit 52, which a 32-bit truncation loses.
  foldCastU(b.getIntegerType(60), b.getIndexAttr(0x10000000010000));
  EXPECT_FALSE(folded);
}
85 
TEST_F(IndexFolderTest, TestCastSOpFolder) {
  IntegerAttr folded;
  // Run the `CastSOp` folder on the constant `input` with result type `type`;
  // `folded` receives the folded attribute, or null when folding fails.
  auto foldCastS = [&](Type type, Attribute input) {
    foldOp<index::CastSOp>(folded, type, input);
  };

  // Only the extension cases are covered, to check that the sign is kept.

  // Widening a negative value to 64 bits preserves it exactly.
  foldCastS(b.getIntegerType(64), b.getIndexAttr(-2000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), -2000);

  // Result width between 32 and 64 bits: the high bits are truncated away,
  // and what remains sign-extends to -0x10000, so the fold succeeds.
  foldCastS(b.getIntegerType(40), b.getIndexAttr(-0x10000000010000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), -65536);
}
105