//===- IndexingUtilsTest.cpp - IndexingUtils unit tests -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "gtest/gtest.h"
#include <cstddef>
#include <vector>

using namespace mlir;

TEST(StaticTileOffsetRange, checkIteratorCanonicalOrder) {
  // Tile <4x8> by <2x4> with canonical row-major order (loop order [0, 1]):
  // dimension 1 varies fastest.
  std::vector<SmallVector<int64_t>> expected = {{0, 0}, {0, 4}, {2, 0}, {2, 4}};
  // Count the produced offsets so that a range yielding too few fails instead
  // of passing vacuously, and bound-check before indexing so that a range
  // yielding too many cannot read past the end of `expected`.
  size_t numOffsets = 0;
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({4, 8}, {2, 4}, {0, 1}))) {
    ASSERT_LT(idx, expected.size());
    EXPECT_EQ(tileOffset, expected[idx]);
    ++numOffsets;
  }
  EXPECT_EQ(numOffsets, expected.size());

  // Check the constructor for default order and test use with zip iterator.
  // `llvm::zip` stops at the end of the shorter range, so the default-order
  // range is also counted on its own to catch a length mismatch.
  size_t numZipped = 0;
  for (auto [tileOffset, tileOffsetDefault] :
       llvm::zip(StaticTileOffsetRange({4, 8}, {2, 4}, {0, 1}),
                 StaticTileOffsetRange({4, 8}, {2, 4}))) {
    EXPECT_EQ(tileOffset, tileOffsetDefault);
    ++numZipped;
  }
  EXPECT_EQ(numZipped, expected.size());

  size_t numDefault = 0;
  for (const auto &tileOffset : StaticTileOffsetRange({4, 8}, {2, 4})) {
    (void)tileOffset;
    ++numDefault;
  }
  EXPECT_EQ(numDefault, expected.size());
}

TEST(StaticTileOffsetRange, checkIteratorRowMajorOrder) {
  // Tile <4x8> by <2x4> with loop order [1, 0]. Dimension 0 varies fastest, so
  // despite this test's name the offsets are produced in column-major order.
  std::vector<SmallVector<int64_t>> expected = {{0, 0}, {2, 0}, {0, 4}, {2, 4}};
  // Count the produced offsets so that a range yielding too few fails instead
  // of passing vacuously, and bound-check before indexing so that a range
  // yielding too many cannot read past the end of `expected`.
  size_t numOffsets = 0;
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({4, 8}, {2, 4}, {1, 0}))) {
    ASSERT_LT(idx, expected.size());
    EXPECT_EQ(tileOffset, expected[idx]);
    ++numOffsets;
  }
  EXPECT_EQ(numOffsets, expected.size());
}

TEST(StaticTileOffsetRange, checkLeadingOneFill) {
  // Tile <4x8> by <4>. A smaller tile shape gets right-aligned to the shape,
  // so the tile covers <1x4> and there are 4 * 2 = 8 tiles in total.
  size_t numOffsets = 0;
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({4, 8}, {4}))) {
    SmallVector<int64_t> expected = {static_cast<int64_t>(idx) / 2,
                                     static_cast<int64_t>(idx) % 2 * 4};
    EXPECT_EQ(tileOffset, expected);
    ++numOffsets;
  }
  // Without this check an empty range would pass vacuously.
  EXPECT_EQ(numOffsets, 8u);

  // Tile <1x4x8> by <4> with loop order [2, 1, 0]. The tile covers <1x1x4>.
  // Dimension 0 holds a single tile, dimension 1 varies fastest and
  // dimension 2 slowest, giving 1 * 4 * 2 = 8 tiles.
  numOffsets = 0;
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({1, 4, 8}, {4}, {2, 1, 0}))) {
    SmallVector<int64_t> expected = {0, static_cast<int64_t>(idx) % 4,
                                     (static_cast<int64_t>(idx) / 4) * 4};
    EXPECT_EQ(tileOffset, expected);
    ++numOffsets;
  }
  EXPECT_EQ(numOffsets, 8u);
}

TEST(StaticTileOffsetRange, checkIterator3DPermutation) {
  // Tile <8x4x2> by <4x2x1> with loop order [1, 0, 2], listed outermost first.
  // Dimension 2 varies fastest and dimension 1 slowest, giving 2 * 2 * 2 = 8
  // tiles.
  size_t numOffsets = 0;
  for (auto [idx, tileOffset] : llvm::enumerate(
           StaticTileOffsetRange({8, 4, 2}, {4, 2, 1}, {1, 0, 2}))) {
    SmallVector<int64_t> expected = {((static_cast<int64_t>(idx) / 2) % 2) * 4,
                                     ((static_cast<int64_t>(idx) / 4) % 2) * 2,
                                     static_cast<int64_t>(idx) % 2};
    EXPECT_EQ(tileOffset, expected);
    ++numOffsets;
  }
  // Without this check an empty range would pass vacuously.
  EXPECT_EQ(numOffsets, 8u);

  // Tile <10x20x30> by <5x10x15> with loop order [2, 0, 1], listed outermost
  // first. Dimension 1 varies fastest and dimension 2 slowest, giving
  // 2 * 2 * 2 = 8 tiles.
  numOffsets = 0;
  for (auto [idx, tileOffset] : llvm::enumerate(
           StaticTileOffsetRange({10, 20, 30}, {5, 10, 15}, {2, 0, 1}))) {
    SmallVector<int64_t> expected = {((static_cast<int64_t>(idx) / 2) % 2) * 5,
                                     (static_cast<int64_t>(idx) % 2) * 10,
                                     (static_cast<int64_t>(idx) / 4) % 2 * 15};
    EXPECT_EQ(tileOffset, expected);
    ++numOffsets;
  }
  EXPECT_EQ(numOffsets, 8u);
}