xref: /llvm-project/mlir/unittests/Dialect/Utils/IndexingUtilsTest.cpp (revision 831041be797b099b4e3805db368bacb1d1abab5d)
1 //===- IndexingUtilsTest.cpp - IndexingUtils unit tests -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/IndexingUtils.h"
10 #include "llvm/ADT/STLExtras.h"
11 #include "gtest/gtest.h"
12 
13 using namespace mlir;
14 
TEST(StaticTileOffsetRange, checkIteratorCanonicalOrder) {
  // Tile <4x8> by <2x4> with canonical row-major order (loop order [0, 1]):
  // dimension 1 is the innermost loop, so its offset advances fastest.
  std::vector<SmallVector<int64_t>> expected = {{0, 0}, {0, 4}, {2, 0}, {2, 4}};
  size_t numTiles = 0;
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({4, 8}, {2, 4}, {0, 1}))) {
    // Bail out before indexing past the end of `expected` if the range
    // yields more tiles than it should.
    ASSERT_LT(idx, expected.size());
    EXPECT_EQ(tileOffset, expected[idx]);
    ++numTiles;
  }
  // A range that yields too few (or no) tiles must fail the test as well.
  EXPECT_EQ(numTiles, expected.size());

  // Check the constructor for default order and test use with zip iterator.
  for (auto [tileOffset, tileOffsetDefault] :
       llvm::zip(StaticTileOffsetRange({4, 8}, {2, 4}, {0, 1}),
                 StaticTileOffsetRange({4, 8}, {2, 4})))
    EXPECT_EQ(tileOffset, tileOffsetDefault);

  // llvm::zip stops at the end of the shorter range, so separately check that
  // the default-order range yields the same number of tiles.
  size_t numDefaultOrderTiles = 0;
  for (const auto &tileOffset : StaticTileOffsetRange({4, 8}, {2, 4})) {
    (void)tileOffset;
    ++numDefaultOrderTiles;
  }
  EXPECT_EQ(numDefaultOrderTiles, expected.size());
}
27 }
28 
TEST(StaticTileOffsetRange, checkIteratorRowMajorOrder) {
  // Tile <4x8> by <2x4> with loop order [1, 0]: dimension 0 is the innermost
  // loop, so its offset advances fastest (a column-major walk over the tile
  // grid, despite the test name).
  std::vector<SmallVector<int64_t>> expected = {{0, 0}, {2, 0}, {0, 4}, {2, 4}};
  size_t numTiles = 0;
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({4, 8}, {2, 4}, {1, 0}))) {
    // Bail out before indexing past the end of `expected` if the range
    // yields more tiles than it should.
    ASSERT_LT(idx, expected.size());
    EXPECT_EQ(tileOffset, expected[idx]);
    ++numTiles;
  }
  // A range that yields too few (or no) tiles must fail the test as well.
  EXPECT_EQ(numTiles, expected.size());
}
35 }
36 
TEST(StaticTileOffsetRange, checkLeadingOneFill) {
  // Tile <4x8> by <4>. A tile shape of lower rank is right-aligned to the
  // shape and padded with leading ones, so it behaves like <1x4> and yields a
  // 4x2 grid of tiles.
  int64_t numTiles = 0;
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({4, 8}, {4}))) {
    SmallVector<int64_t> expected = {static_cast<int64_t>(idx) / 2,
                                     static_cast<int64_t>(idx) % 2 * 4};
    EXPECT_EQ(tileOffset, expected);
    ++numTiles;
  }
  // Without this check an empty range would pass the loop above vacuously.
  EXPECT_EQ(numTiles, 8);

  // Tile <1x4x8> by <4> (effectively <1x1x4>) with loop order [2, 1, 0].
  // Dimension 0 has a single tile, so dimension 1 advances fastest and
  // dimension 2 slowest.
  numTiles = 0;
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({1, 4, 8}, {4}, {2, 1, 0}))) {
    SmallVector<int64_t> expected = {0, static_cast<int64_t>(idx) % 4,
                                     (static_cast<int64_t>(idx) / 4) * 4};
    EXPECT_EQ(tileOffset, expected);
    ++numTiles;
  }
  EXPECT_EQ(numTiles, 8);
}
51 }
52 
TEST(StaticTileOffsetRange, checkIterator3DPermutation) {
  // Tile <8x4x2> by <4x2x1> with permutation [1, 0, 2]: a 2x2x2 tile grid
  // where dimension 2 advances fastest, then dimension 0, then dimension 1.
  int64_t numTiles = 0;
  for (auto [idx, tileOffset] : llvm::enumerate(
           StaticTileOffsetRange({8, 4, 2}, {4, 2, 1}, {1, 0, 2}))) {
    SmallVector<int64_t> expected = {((static_cast<int64_t>(idx) / 2) % 2) * 4,
                                     ((static_cast<int64_t>(idx) / 4) % 2) * 2,
                                     static_cast<int64_t>(idx) % 2};
    EXPECT_EQ(tileOffset, expected);
    ++numTiles;
  }
  // Without this check an empty range would pass the loop above vacuously.
  EXPECT_EQ(numTiles, 8);

  // Tile <10x20x30> by <5x10x15> with permutation [2, 0, 1]: a 2x2x2 tile
  // grid where dimension 1 advances fastest, then dimension 0, then
  // dimension 2.
  numTiles = 0;
  for (auto [idx, tileOffset] : llvm::enumerate(
           StaticTileOffsetRange({10, 20, 30}, {5, 10, 15}, {2, 0, 1}))) {
    SmallVector<int64_t> expected = {((static_cast<int64_t>(idx) / 2) % 2) * 5,
                                     (static_cast<int64_t>(idx) % 2) * 10,
                                     (static_cast<int64_t>(idx) / 4) % 2 * 15};
    EXPECT_EQ(tileOffset, expected);
    ++numTiles;
  }
  EXPECT_EQ(numTiles, 8);
}
71 }
72