xref: /llvm-project/llvm/lib/Support/BalancedPartitioning.cpp (revision 5954b9dca21bb0c69b9e991b2ddb84c8b05ecba3)
1 //===- BalancedPartitioning.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements BalancedPartitioning, a recursive balanced graph
10 // partitioning algorithm.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Support/BalancedPartitioning.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/ThreadPool.h"
19 
20 using namespace llvm;
21 #define DEBUG_TYPE "balanced-partitioning"
22 
23 void BPFunctionNode::dump(raw_ostream &OS) const {
24   OS << "{ID=" << Id << " Utilities={";
25   for (auto &N : UtilityNodes)
26     OS << N.Id << " ,";
27   OS << "}";
28   if (Bucket.has_value())
29     OS << " Bucket=" << Bucket.value();
30   OS << "}";
31 }
32 
/// Submit a task to the pool, tracking it so that wait() can tell when the
/// whole recursive task tree has been spawned. Tasks submitted here may
/// themselves call async() to spawn further tasks; NumActiveThreads counts
/// tasks that might still spawn (presumably an atomic counter — declared in
/// the header, so confirm there).
template <typename Func>
void BalancedPartitioning::BPThreadPool::async(Func &&F) {
#if LLVM_ENABLE_THREADS
  // This new thread could spawn more threads, so mark it as active. The
  // increment happens BEFORE submission so the count can never transiently
  // drop to zero while work is still pending.
  ++NumActiveThreads;
  // Capture by value: F (and the enclosing state it carries) must outlive
  // this stack frame, since the task runs on a pool thread.
  TheThreadPool.async([=]() {
    // Run the task
    F();

    // This thread will no longer spawn new threads, so mark it as inactive
    if (--NumActiveThreads == 0) {
      // There are no more active threads, so mark as finished and notify.
      // The flag is written under the mutex so the waiter in wait() cannot
      // miss the notification.
      {
        std::unique_lock<std::mutex> lock(mtx);
        assert(!IsFinishedSpawning);
        IsFinishedSpawning = true;
      }
      cv.notify_one();
    }
  });
#else
  llvm_unreachable("threads are disabled");
#endif
}
57 
/// Block until every task spawned through async() — including tasks spawned
/// by other tasks — has completed.
void BalancedPartitioning::BPThreadPool::wait() {
#if LLVM_ENABLE_THREADS
  // TODO: We could remove the mutex and condition variable and use
  // std::atomic::wait() instead, but that isn't available until C++20
  {
    // First wait until no task can spawn any more tasks; only then is it
    // safe to call ThreadPool::wait(), which assumes the task set is fixed.
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]() { return IsFinishedSpawning; });
    assert(IsFinishedSpawning && NumActiveThreads == 0);
  }
  // Now we can call ThreadPool::wait() since all tasks have been submitted
  TheThreadPool.wait();
#else
  llvm_unreachable("threads are disabled");
#endif
}
73 
74 BalancedPartitioning::BalancedPartitioning(
75     const BalancedPartitioningConfig &Config)
76     : Config(Config) {
77   // Pre-computing log2 values
78   Log2Cache[0] = 0.0;
79   for (unsigned I = 1; I < LOG_CACHE_SIZE; I++)
80     Log2Cache[I] = std::log2(I);
81 }
82 
83 void BalancedPartitioning::run(std::vector<BPFunctionNode> &Nodes) const {
84   LLVM_DEBUG(
85       dbgs() << format(
86           "Partitioning %d nodes using depth %d and %d iterations per split\n",
87           Nodes.size(), Config.SplitDepth, Config.IterationsPerSplit));
88   std::optional<BPThreadPool> TP;
89 #if LLVM_ENABLE_THREADS
90   ThreadPool TheThreadPool;
91   if (Config.TaskSplitDepth > 1)
92     TP.emplace(TheThreadPool);
93 #endif
94 
95   // Record the input order
96   for (unsigned I = 0; I < Nodes.size(); I++)
97     Nodes[I].InputOrderIndex = I;
98 
99   auto NodesRange = llvm::make_range(Nodes.begin(), Nodes.end());
100   auto BisectTask = [=, &TP]() {
101     bisect(NodesRange, /*RecDepth=*/0, /*RootBucket=*/1, /*Offset=*/0, TP);
102   };
103   if (TP) {
104     TP->async(std::move(BisectTask));
105     TP->wait();
106   } else {
107     BisectTask();
108   }
109 
110   llvm::stable_sort(NodesRange, [](const auto &L, const auto &R) {
111     return L.Bucket < R.Bucket;
112   });
113 
114   LLVM_DEBUG(dbgs() << "Balanced partitioning completed\n");
115 }
116 
/// Recursively bisect Nodes into two halves, assigning the final buckets
/// [Offset, Offset + |Nodes|) at the leaves of the recursion. RootBucket
/// identifies this subproblem in the implicit binary recursion tree; its
/// children are 2*RootBucket and 2*RootBucket+1.
void BalancedPartitioning::bisect(const FunctionNodeRange Nodes,
                                  unsigned RecDepth, unsigned RootBucket,
                                  unsigned Offset,
                                  std::optional<BPThreadPool> &TP) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  if (NumNodes <= 1 || RecDepth >= Config.SplitDepth) {
    // We've reached the lowest level of the recursion tree. Fall back to the
    // original order and assign to buckets.
    llvm::sort(Nodes, [](const auto &L, const auto &R) {
      return L.InputOrderIndex < R.InputOrderIndex;
    });
    for (auto &N : Nodes)
      N.Bucket = Offset++;
    return;
  }

  LLVM_DEBUG(dbgs() << format("Bisect with %d nodes and root bucket %d\n",
                              NumNodes, RootBucket));

  // Seed with RootBucket so the result is deterministic and independent of
  // the order in which parallel subproblems execute.
  std::mt19937 RNG(RootBucket);

  unsigned LeftBucket = 2 * RootBucket;
  unsigned RightBucket = 2 * RootBucket + 1;

  // Split into two and assign to the left and right buckets
  split(Nodes, LeftBucket);

  // Local-search iterations that move nodes between the two buckets to
  // reduce the partitioning cost.
  runIterations(Nodes, RecDepth, LeftBucket, RightBucket, RNG);

  // Split nodes wrt the resulting buckets
  auto NodesMid =
      llvm::partition(Nodes, [&](auto &N) { return N.Bucket == LeftBucket; });
  unsigned MidOffset = Offset + std::distance(Nodes.begin(), NodesMid);

  auto LeftNodes = llvm::make_range(Nodes.begin(), NodesMid);
  auto RightNodes = llvm::make_range(NodesMid, Nodes.end());

  // Capture subranges by value: the tasks may run on pool threads after this
  // frame returns. The underlying node storage stays alive in run().
  auto LeftRecTask = [=, &TP]() {
    bisect(LeftNodes, RecDepth + 1, LeftBucket, Offset, TP);
  };
  auto RightRecTask = [=, &TP]() {
    bisect(RightNodes, RecDepth + 1, RightBucket, MidOffset, TP);
  };

  // Only parallelize near the top of the tree where subproblems are large
  // enough (>= 4 nodes) to be worth a task.
  if (TP && RecDepth < Config.TaskSplitDepth && NumNodes >= 4) {
    TP->async(std::move(LeftRecTask));
    TP->async(std::move(RightRecTask));
  } else {
    LeftRecTask();
    RightRecTask();
  }
}
169 
/// Prepare per-utility-node signatures for the current bisection and run up
/// to Config.IterationsPerSplit local-search iterations, stopping early once
/// an iteration moves no nodes.
void BalancedPartitioning::runIterations(const FunctionNodeRange Nodes,
                                         unsigned RecDepth, unsigned LeftBucket,
                                         unsigned RightBucket,
                                         std::mt19937 &RNG) const {
  // Count the degree of each utility node.
  DenseMap<uint32_t, unsigned> UtilityNodeIndex;
  for (auto &N : Nodes)
    for (auto &UN : N.UtilityNodes)
      ++UtilityNodeIndex[UN.Id];
  // Remove utility nodes if they have just one edge or are connected to all
  // functions: they cannot influence which side of the split is better.
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  for (auto &N : Nodes)
    llvm::erase_if(N.UtilityNodes, [&](auto &UN) {
      return UtilityNodeIndex[UN.Id] == 1 ||
             UtilityNodeIndex[UN.Id] == NumNodes;
    });

  // Renumber utility nodes so they can be used to index into Signatures.
  // Note: this rewrites UN.Id in place; the original ids are not needed
  // again within this subproblem.
  UtilityNodeIndex.clear();
  for (auto &N : Nodes)
    for (auto &UN : N.UtilityNodes)
      UN.Id = UtilityNodeIndex.insert({UN.Id, UtilityNodeIndex.size()})
                  .first->second;

  // Initialize signatures: per utility node, how many incident functions are
  // currently in the left vs right bucket.
  SignaturesT Signatures(/*Size=*/UtilityNodeIndex.size());
  for (auto &N : Nodes) {
    for (auto &UN : N.UtilityNodes) {
      assert(UN.Id < Signatures.size());
      if (N.Bucket == LeftBucket)
        Signatures[UN.Id].LeftCount++;
      else
        Signatures[UN.Id].RightCount++;
      // Identical utility nodes (having the same UN.Id) have the same weight
      // (unless there are hash collisions mapping utilities to the same Id);
      // thus, we get a new weight only when the signature is uninitialized.
      Signatures[UN.Id].Weight = UN.Weight;
    }
  }

  // Iterate until converged (no moves) or the iteration budget is spent.
  for (unsigned I = 0; I < Config.IterationsPerSplit; I++) {
    unsigned NumMovedNodes =
        runIteration(Nodes, LeftBucket, RightBucket, Signatures, RNG);
    if (NumMovedNodes == 0)
      break;
  }
}
218 
/// One local-search iteration: compute the cost gain of moving each node to
/// the opposite bucket, then greedily exchange the best left-node/right-node
/// pairs (keeping the split balanced). Returns the number of nodes moved.
unsigned BalancedPartitioning::runIteration(const FunctionNodeRange Nodes,
                                            unsigned LeftBucket,
                                            unsigned RightBucket,
                                            SignaturesT &Signatures,
                                            std::mt19937 &RNG) const {
  // Init signature cost caches: for each utility node, the weighted cost
  // delta of moving one incident function L->R (CachedGainLR) or R->L
  // (CachedGainRL). Only signatures invalidated by prior moves are redone.
  for (auto &Signature : Signatures) {
    if (Signature.CachedGainIsValid)
      continue;
    unsigned L = Signature.LeftCount;
    unsigned R = Signature.RightCount;
    assert((L > 0 || R > 0) && "incorrect signature");
    float Cost = logCost(L, R);
    Signature.CachedGainLR = 0.f;
    Signature.CachedGainRL = 0.f;
    if (L > 0)
      Signature.CachedGainLR =
          (Cost - logCost(L - 1, R + 1)) * Signature.Weight;
    if (R > 0)
      Signature.CachedGainRL =
          (Cost - logCost(L + 1, R - 1)) * Signature.Weight;
    Signature.CachedGainIsValid = true;
  }

  // Compute move gains
  typedef std::pair<float, BPFunctionNode *> GainPair;
  std::vector<GainPair> Gains;
  for (auto &N : Nodes) {
    bool FromLeftToRight = (N.Bucket == LeftBucket);
    float Gain = moveGain(N, FromLeftToRight, Signatures);
    Gains.push_back(std::make_pair(Gain, &N));
  }

  // Collect left and right gains
  auto LeftEnd = llvm::partition(
      Gains, [&](const auto &GP) { return GP.second->Bucket == LeftBucket; });
  auto LeftRange = llvm::make_range(Gains.begin(), LeftEnd);
  auto RightRange = llvm::make_range(LeftEnd, Gains.end());

  // Sort gains in descending order
  auto LargerGain = [](const auto &L, const auto &R) {
    return L.first > R.first;
  };
  llvm::stable_sort(LeftRange, LargerGain);
  llvm::stable_sort(RightRange, LargerGain);

  unsigned NumMovedDataVertices = 0;
  // Pair the k-th best left candidate with the k-th best right candidate so
  // each exchange keeps bucket sizes unchanged.
  for (auto [LeftPair, RightPair] : llvm::zip(LeftRange, RightRange)) {
    auto &[LeftGain, LeftNode] = LeftPair;
    auto &[RightGain, RightNode] = RightPair;
    // Stop when the gain is no longer beneficial. Gains are computed before
    // any move in this loop, so this is a heuristic cutoff, not exact.
    if (LeftGain + RightGain <= 0.f)
      break;
    // Try to exchange the nodes between buckets (each move may be randomly
    // skipped inside moveFunctionNode to escape local optima).
    if (moveFunctionNode(*LeftNode, LeftBucket, RightBucket, Signatures, RNG))
      ++NumMovedDataVertices;
    if (moveFunctionNode(*RightNode, LeftBucket, RightBucket, Signatures, RNG))
      ++NumMovedDataVertices;
  }
  return NumMovedDataVertices;
}
280 
281 bool BalancedPartitioning::moveFunctionNode(BPFunctionNode &N,
282                                             unsigned LeftBucket,
283                                             unsigned RightBucket,
284                                             SignaturesT &Signatures,
285                                             std::mt19937 &RNG) const {
286   // Sometimes we skip the move. This helps to escape local optima
287   if (std::uniform_real_distribution<float>(0.f, 1.f)(RNG) <=
288       Config.SkipProbability)
289     return false;
290 
291   bool FromLeftToRight = (N.Bucket == LeftBucket);
292   // Update the current bucket
293   N.Bucket = (FromLeftToRight ? RightBucket : LeftBucket);
294 
295   // Update signatures and invalidate gain cache
296   if (FromLeftToRight) {
297     for (auto &UN : N.UtilityNodes) {
298       auto &Signature = Signatures[UN.Id];
299       Signature.LeftCount--;
300       Signature.RightCount++;
301       Signature.CachedGainIsValid = false;
302     }
303   } else {
304     for (auto &UN : N.UtilityNodes) {
305       auto &Signature = Signatures[UN.Id];
306       Signature.LeftCount++;
307       Signature.RightCount--;
308       Signature.CachedGainIsValid = false;
309     }
310   }
311   return true;
312 }
313 
314 void BalancedPartitioning::split(const FunctionNodeRange Nodes,
315                                  unsigned StartBucket) const {
316   unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
317   auto NodesMid = Nodes.begin() + (NumNodes + 1) / 2;
318 
319   std::nth_element(Nodes.begin(), NodesMid, Nodes.end(), [](auto &L, auto &R) {
320     return L.InputOrderIndex < R.InputOrderIndex;
321   });
322 
323   for (auto &N : llvm::make_range(Nodes.begin(), NodesMid))
324     N.Bucket = StartBucket;
325   for (auto &N : llvm::make_range(NodesMid, Nodes.end()))
326     N.Bucket = StartBucket + 1;
327 }
328 
329 float BalancedPartitioning::moveGain(const BPFunctionNode &N,
330                                      bool FromLeftToRight,
331                                      const SignaturesT &Signatures) {
332   float Gain = 0.f;
333   for (auto &UN : N.UtilityNodes)
334     Gain += (FromLeftToRight ? Signatures[UN.Id].CachedGainLR
335                              : Signatures[UN.Id].CachedGainRL);
336   return Gain;
337 }
338 
339 float BalancedPartitioning::logCost(unsigned X, unsigned Y) const {
340   return -(X * log2Cached(X + 1) + Y * log2Cached(Y + 1));
341 }
342 
343 float BalancedPartitioning::log2Cached(unsigned i) const {
344   return (i < LOG_CACHE_SIZE) ? Log2Cache[i] : std::log2(i);
345 }
346