xref: /llvm-project/llvm/unittests/Support/BalancedPartitioningTest.cpp (revision 30aa9fb4c1da33892a38f952fbdf6e7e45e5953a)
1 //===- BalancedPartitioningTest.cpp - BalancedPartitioning tests ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/Support/BalancedPartitioning.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Testing/Support/SupportHelpers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <algorithm>
#include <iterator>
#include <optional>
#include <random>
#include <vector>
13 
14 using testing::Each;
15 using testing::Field;
16 using testing::Not;
17 using testing::UnorderedElementsAre;
18 using testing::UnorderedElementsAreArray;
19 
20 namespace llvm {
21 
PrintTo(const BPFunctionNode & Node,std::ostream * OS)22 void PrintTo(const BPFunctionNode &Node, std::ostream *OS) {
23   raw_os_ostream ROS(*OS);
24   Node.dump(ROS);
25 }
26 
27 class BalancedPartitioningTest : public ::testing::Test {
28 protected:
29   BalancedPartitioningConfig Config;
30   BalancedPartitioning Bp;
BalancedPartitioningTest()31   BalancedPartitioningTest() : Bp(Config) {}
32 
33   static std::vector<BPFunctionNode::IDT>
getIds(std::vector<BPFunctionNode> Nodes)34   getIds(std::vector<BPFunctionNode> Nodes) {
35     std::vector<BPFunctionNode::IDT> Ids;
36     for (auto &N : Nodes)
37       Ids.push_back(N.Id);
38     return Ids;
39   }
40 };
41 
TEST_F(BalancedPartitioningTest, Basic) {
  // Nodes 0 and 1 share utilities {1, 2}; nodes 2 and 3 share {3, 4};
  // node 4 touches only utility 4.
  std::vector<BPFunctionNode> Nodes = {
      BPFunctionNode(0, {1, 2}), BPFunctionNode(2, {3, 4}),
      BPFunctionNode(1, {1, 2}), BPFunctionNode(3, {3, 4}),
      BPFunctionNode(4, {4}),
  };

  Bp.run(Nodes);

  // Matches a node with the given Id whose assigned Bucket equals Bucket.
  auto NodeIs = [](BPFunctionNode::IDT Id, std::optional<uint32_t> Bucket) {
    // AllOf is not brought in by a using-declaration like the other matchers
    // in this file; qualify it instead of relying on argument-dependent lookup.
    return testing::AllOf(Field("Id", &BPFunctionNode::Id, Id),
                          Field("Bucket", &BPFunctionNode::Bucket, Bucket));
  };

  // Each node is expected to land in the bucket equal to its Id.
  EXPECT_THAT(Nodes,
              UnorderedElementsAre(NodeIs(0, 0), NodeIs(1, 1), NodeIs(2, 2),
                                   NodeIs(3, 3), NodeIs(4, 4)));
}
60 
TEST_F(BalancedPartitioningTest, Large) {
  const int ProblemSize = 1000;

  // The universe of utility nodes: 0 .. ProblemSize-1, each value unique.
  std::vector<BPFunctionNode::UtilityNodeT> AllUNs;
  AllUNs.reserve(ProblemSize);
  for (int i = 0; i < ProblemSize; i++)
    AllUNs.emplace_back(i);

  // Default-constructed engine (fixed default seed), so the generated problem
  // is identical on every run.
  std::mt19937 RNG;
  std::vector<BPFunctionNode> Nodes;
  Nodes.reserve(ProblemSize);
  for (int i = 0; i < ProblemSize; i++) {
    // Bound by the int ProblemSize rather than AllUNs.size() - 1, which
    // narrowed a size_t to int; the value (and thus the RNG stream) is the same.
    int SampleSize =
        std::uniform_int_distribution<int>(0, ProblemSize - 1)(RNG);
    std::vector<BPFunctionNode::UtilityNodeT> UNs;
    UNs.reserve(SampleSize);
    // std::sample picks without replacement from the unique AllUNs, so each
    // node's utility list contains no duplicates.
    std::sample(AllUNs.begin(), AllUNs.end(), std::back_inserter(UNs),
                SampleSize, RNG);
    Nodes.emplace_back(i, UNs);
  }

  auto OrigIds = getIds(Nodes);

  Bp.run(Nodes);

  // Every node must receive a bucket; the set of Ids must be preserved even
  // though the checks allow run() to reorder Nodes.
  EXPECT_THAT(
      Nodes, Each(Not(Field("Bucket", &BPFunctionNode::Bucket, std::nullopt))));
  EXPECT_THAT(getIds(Nodes), UnorderedElementsAreArray(OrigIds));
}
86 
TEST_F(BalancedPartitioningTest, MoveGain) {
  // Three utility signatures, indexed 0, 1, 2. The expectations below are
  // consistent with moveGain summing, over a node's utilities, the third
  // field when the direction flag is true and the fourth field when false.
  BalancedPartitioning::SignaturesT Sigs = {
      {10, 10, 10.f, 0.f, true},
      {10, 10, 0.f, 10.f, true},
      {10, 10, 0.f, 20.f, true},
  };

  // No utilities: nothing to gain.
  EXPECT_FLOAT_EQ(Bp.moveGain(BPFunctionNode(0, {}), true, Sigs), 0.f);
  // Flag true over utilities {0, 1}: 10 + 0.
  EXPECT_FLOAT_EQ(Bp.moveGain(BPFunctionNode(0, {0, 1}), true, Sigs), 10.f);
  // Flag false over utilities {1, 2}: 10 + 20.
  EXPECT_FLOAT_EQ(Bp.moveGain(BPFunctionNode(0, {1, 2}), false, Sigs), 30.f);
}
99 
100 } // end namespace llvm
101