xref: /llvm-project/llvm/unittests/Support/BalancedPartitioningTest.cpp (revision 30aa9fb4c1da33892a38f952fbdf6e7e45e5953a)
11117b9a2SEllis Hoag //===- BalancedPartitioningTest.cpp - BalancedPartitioning tests ----------===//
21117b9a2SEllis Hoag //
31117b9a2SEllis Hoag // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41117b9a2SEllis Hoag // See https://llvm.org/LICENSE.txt for license information.
51117b9a2SEllis Hoag // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61117b9a2SEllis Hoag //
71117b9a2SEllis Hoag //===----------------------------------------------------------------------===//
81117b9a2SEllis Hoag 
91117b9a2SEllis Hoag #include "llvm/Support/BalancedPartitioning.h"
101117b9a2SEllis Hoag #include "llvm/Testing/Support/SupportHelpers.h"
111117b9a2SEllis Hoag #include "gmock/gmock.h"
121117b9a2SEllis Hoag #include "gtest/gtest.h"
131117b9a2SEllis Hoag 
141117b9a2SEllis Hoag using testing::Each;
151117b9a2SEllis Hoag using testing::Field;
161117b9a2SEllis Hoag using testing::Not;
171117b9a2SEllis Hoag using testing::UnorderedElementsAre;
181117b9a2SEllis Hoag using testing::UnorderedElementsAreArray;
191117b9a2SEllis Hoag 
201117b9a2SEllis Hoag namespace llvm {
211117b9a2SEllis Hoag 
PrintTo(const BPFunctionNode & Node,std::ostream * OS)221117b9a2SEllis Hoag void PrintTo(const BPFunctionNode &Node, std::ostream *OS) {
231117b9a2SEllis Hoag   raw_os_ostream ROS(*OS);
241117b9a2SEllis Hoag   Node.dump(ROS);
251117b9a2SEllis Hoag }
261117b9a2SEllis Hoag 
271117b9a2SEllis Hoag class BalancedPartitioningTest : public ::testing::Test {
281117b9a2SEllis Hoag protected:
291117b9a2SEllis Hoag   BalancedPartitioningConfig Config;
301117b9a2SEllis Hoag   BalancedPartitioning Bp;
BalancedPartitioningTest()311117b9a2SEllis Hoag   BalancedPartitioningTest() : Bp(Config) {}
321117b9a2SEllis Hoag 
331117b9a2SEllis Hoag   static std::vector<BPFunctionNode::IDT>
getIds(std::vector<BPFunctionNode> Nodes)341117b9a2SEllis Hoag   getIds(std::vector<BPFunctionNode> Nodes) {
351117b9a2SEllis Hoag     std::vector<BPFunctionNode::IDT> Ids;
361117b9a2SEllis Hoag     for (auto &N : Nodes)
371117b9a2SEllis Hoag       Ids.push_back(N.Id);
381117b9a2SEllis Hoag     return Ids;
391117b9a2SEllis Hoag   }
401117b9a2SEllis Hoag };
411117b9a2SEllis Hoag 
TEST_F(BalancedPartitioningTest, Basic) {
  // Nodes 0 and 1 share utility nodes {1, 2}; nodes 2 and 3 share {3, 4};
  // node 4 uses only utility node 4.
  std::vector<BPFunctionNode> Nodes = {
      BPFunctionNode(0, {1, 2}), BPFunctionNode(2, {3, 4}),
      BPFunctionNode(1, {1, 2}), BPFunctionNode(3, {3, 4}),
      BPFunctionNode(4, {4}),
  };

  Bp.run(Nodes);

  // Matcher for a node with the given id that was assigned the given bucket.
  auto HasIdAndBucket = [](BPFunctionNode::IDT Id,
                           std::optional<uint32_t> Bucket) {
    auto IdMatches = Field("Id", &BPFunctionNode::Id, Id);
    auto BucketMatches = Field("Bucket", &BPFunctionNode::Bucket, Bucket);
    return testing::AllOf(IdMatches, BucketMatches);
  };

  // After run(), each node is expected to sit in the bucket whose index
  // equals its id. The order of Nodes itself is not checked.
  EXPECT_THAT(Nodes, UnorderedElementsAre(
                         HasIdAndBucket(0, 0), HasIdAndBucket(1, 1),
                         HasIdAndBucket(2, 2), HasIdAndBucket(3, 3),
                         HasIdAndBucket(4, 4)));
}
601117b9a2SEllis Hoag 
TEST_F(BalancedPartitioningTest, Large) {
  const int ProblemSize = 1000;

  // Pool of utility nodes 0..ProblemSize-1 that each function node samples
  // from.
  std::vector<BPFunctionNode::UtilityNodeT> AllUNs;
  for (int UN = 0; UN < ProblemSize; ++UN)
    AllUNs.emplace_back(UN);

  // Default-seeded engine, so the generated problem is the same on every run.
  std::mt19937 RNG;
  std::vector<BPFunctionNode> Nodes;
  for (int Id = 0; Id < ProblemSize; ++Id) {
    // Draw the subset size first, then the subset itself, from the same
    // engine.
    std::uniform_int_distribution<int> SizeDist(0, AllUNs.size() - 1);
    const int SampleSize = SizeDist(RNG);
    std::vector<BPFunctionNode::UtilityNodeT> Sampled;
    std::sample(AllUNs.begin(), AllUNs.end(), std::back_inserter(Sampled),
                SampleSize, RNG);
    Nodes.emplace_back(Id, Sampled);
  }

  const auto IdsBefore = getIds(Nodes);

  Bp.run(Nodes);

  // Every node must have been given a bucket...
  EXPECT_THAT(
      Nodes, Each(Not(Field("Bucket", &BPFunctionNode::Bucket, std::nullopt))));
  // ...and the set of node ids must be unchanged (only the order may differ).
  EXPECT_THAT(getIds(Nodes), UnorderedElementsAreArray(IdsBefore));
}
861117b9a2SEllis Hoag 
TEST_F(BalancedPartitioningTest, MoveGain) {
  // Signatures for utility nodes 0, 1 and 2 (see the trailing comments).
  BalancedPartitioning::SignaturesT Signatures = {
      {10, 10, 10.f, 0.f, true}, // 0
      {10, 10, 0.f, 10.f, true}, // 1
      {10, 10, 0.f, 20.f, true}, // 2
  };

  const BPFunctionNode NoUtilities(0, {});
  const BPFunctionNode Uses01(0, {0, 1});
  const BPFunctionNode Uses12(0, {1, 2});

  // A node with no utility nodes yields zero gain.
  EXPECT_FLOAT_EQ(Bp.moveGain(NoUtilities, true, Signatures), 0.f);
  EXPECT_FLOAT_EQ(Bp.moveGain(Uses01, true, Signatures), 10.f);
  EXPECT_FLOAT_EQ(Bp.moveGain(Uses12, false, Signatures), 30.f);
}
991117b9a2SEllis Hoag 
1001117b9a2SEllis Hoag } // end namespace llvm
101