1*9744d396SJeff Niu //===- IndexOpsFoldersTest.cpp - unit tests for index op folders ----------===//
2*9744d396SJeff Niu //
3*9744d396SJeff Niu // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*9744d396SJeff Niu // See https://llvm.org/LICENSE.txt for license information.
5*9744d396SJeff Niu // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*9744d396SJeff Niu //
7*9744d396SJeff Niu //===----------------------------------------------------------------------===//
8*9744d396SJeff Niu
9*9744d396SJeff Niu #include "mlir/Dialect/Index/IR/IndexDialect.h"
10*9744d396SJeff Niu #include "mlir/Dialect/Index/IR/IndexOps.h"
11*9744d396SJeff Niu #include "mlir/IR/OwningOpRef.h"
12*9744d396SJeff Niu #include "gtest/gtest.h"
13*9744d396SJeff Niu
14*9744d396SJeff Niu using namespace mlir;
15*9744d396SJeff Niu
16*9744d396SJeff Niu namespace {
17*9744d396SJeff Niu /// Test fixture for testing operation folders.
18*9744d396SJeff Niu class IndexFolderTest : public testing::Test {
19*9744d396SJeff Niu public:
IndexFolderTest()20*9744d396SJeff Niu IndexFolderTest() { ctx.getOrLoadDialect<index::IndexDialect>(); }
21*9744d396SJeff Niu
22*9744d396SJeff Niu /// Instantiate an operation, invoke its folder, and return the attribute
23*9744d396SJeff Niu /// result.
24*9744d396SJeff Niu template <typename OpT>
25*9744d396SJeff Niu void foldOp(IntegerAttr &value, Type type, ArrayRef<Attribute> operands);
26*9744d396SJeff Niu
27*9744d396SJeff Niu protected:
28*9744d396SJeff Niu /// The MLIR context to use.
29*9744d396SJeff Niu MLIRContext ctx;
30*9744d396SJeff Niu /// A builder to use.
31*9744d396SJeff Niu OpBuilder b{&ctx};
32*9744d396SJeff Niu };
33*9744d396SJeff Niu } // namespace
34*9744d396SJeff Niu
35*9744d396SJeff Niu template <typename OpT>
foldOp(IntegerAttr & value,Type type,ArrayRef<Attribute> operands)36*9744d396SJeff Niu void IndexFolderTest::foldOp(IntegerAttr &value, Type type,
37*9744d396SJeff Niu ArrayRef<Attribute> operands) {
38*9744d396SJeff Niu // This function returns null so that `ASSERT_*` works within it.
39*9744d396SJeff Niu OperationState state(UnknownLoc::get(&ctx), OpT::getOperationName());
40*9744d396SJeff Niu state.addTypes(type);
41*9744d396SJeff Niu OwningOpRef<OpT> op = cast<OpT>(b.create(state));
42*9744d396SJeff Niu SmallVector<OpFoldResult> results;
43*9744d396SJeff Niu LogicalResult result = op->getOperation()->fold(operands, results);
44*9744d396SJeff Niu // Propagate the failure to the test.
45*9744d396SJeff Niu if (failed(result)) {
46*9744d396SJeff Niu value = nullptr;
47*9744d396SJeff Niu return;
48*9744d396SJeff Niu }
49*9744d396SJeff Niu ASSERT_EQ(results.size(), 1u);
50*9744d396SJeff Niu value = dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(results.front()));
51*9744d396SJeff Niu ASSERT_TRUE(value);
52*9744d396SJeff Niu }
53*9744d396SJeff Niu
TEST_F(IndexFolderTest, TestCastUOpFolder) {
  // Outcome of the most recent fold; null whenever folding failed.
  IntegerAttr folded;
  auto foldCastU = [&](Type resultType, Attribute operand) {
    foldOp<index::CastUOp>(folded, resultType, operand);
  };
  Type i64 = b.getIntegerType(64);

  // Narrowing to a width of at most 32 bits: only the low 16 bits of
  // 8000000000 (0x1DCD65000) survive, i.e. 0x5000.
  foldCastU(b.getIntegerType(16), b.getIndexAttr(8000000000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), 20480u);

  // Widening to 64 bits: a small value round-trips unchanged.
  foldCastU(i64, b.getIndexAttr(2000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), 2000u);

  // No fold: a 32-bit truncation followed by extension would yield a value
  // different from the 64-bit interpretation.
  foldCastU(i64, b.getIndexAttr(8000000000));
  EXPECT_FALSE(folded);

  // A 40-bit target discards the high bits either way, so both index widths
  // agree and the fold succeeds.
  foldCastU(b.getIntegerType(40), b.getIndexAttr(0x10000000010000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), 65536);

  // A 60-bit target keeps the high bits, so the two widths disagree and the
  // fold is refused.
  foldCastU(b.getIntegerType(60), b.getIndexAttr(0x10000000010000));
  EXPECT_FALSE(folded);
}
85*9744d396SJeff Niu
TEST_F(IndexFolderTest, TestCastSOpFolder) {
  // Outcome of the most recent fold; null whenever folding failed.
  IntegerAttr folded;
  auto foldCastS = [&](Type resultType, Attribute operand) {
    foldOp<index::CastSOp>(folded, resultType, operand);
  };

  // Only the widening cases are covered here; they are the ones where the
  // sign of the operand matters.

  // Widening to 64 bits: a negative value keeps its sign.
  foldCastS(b.getIntegerType(64), b.getIndexAttr(-2000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), -2000);

  // A 40-bit target discards the high bits either way, so both index widths
  // agree and the fold succeeds with a sign-preserved result.
  foldCastS(b.getIntegerType(40), b.getIndexAttr(-0x10000000010000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), -65536);
}
105