//===- BalancedPartitioning.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements BalancedPartitioning, a recursive balanced graph
// partitioning algorithm.
//
//===----------------------------------------------------------------------===//
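//
// A minimal usage sketch (hypothetical driver code, not from this file):
//
//   BalancedPartitioningConfig Config;
//   BalancedPartitioning BP(Config);
//   std::vector<BPFunctionNode> Nodes = ...; // each node carries the list of
//                                            // utility nodes it touches
//   BP.run(Nodes); // assigns a Bucket to every node and orders Nodes by it
//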

#include "llvm/Support/BalancedPartitioning.h"
#include "llvm/Config/llvm-config.h" // for LLVM_ENABLE_THREADS
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ThreadPool.h"

using namespace llvm;
#define DEBUG_TYPE "balanced-partitioning"

void BPFunctionNode::dump(raw_ostream &OS) const {
  OS << formatv("{{ID={0} Utilities={{{1:$[,]}} Bucket={2}}", Id,
                make_range(UtilityNodes.begin(), UtilityNodes.end()), Bucket);
}

template <typename Func>
void BalancedPartitioning::BPThreadPool::async(Func &&F) {
#if LLVM_ENABLE_THREADS
  // This new thread could spawn more threads, so mark it as active
  ++NumActiveThreads;
  TheThreadPool.async([this, F]() {
    // Run the task
    F();

    // This thread will no longer spawn new threads, so mark it as inactive
    if (--NumActiveThreads == 0) {
      // There are no more active threads, so mark as finished and notify
      {
        std::unique_lock<std::mutex> lock(mtx);
        assert(!IsFinishedSpawning);
        IsFinishedSpawning = true;
      }
      cv.notify_one();
    }
  });
#else
  llvm_unreachable("threads are disabled");
#endif
}

void BalancedPartitioning::BPThreadPool::wait() {
#if LLVM_ENABLE_THREADS
  // TODO: We could remove the mutex and condition variable and use
  // std::atomic::wait() instead, but that isn't available until C++20
  {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]() { return IsFinishedSpawning; });
    assert(IsFinishedSpawning && NumActiveThreads == 0);
  }
  // Now we can call ThreadPool::wait() since all tasks have been submitted
  TheThreadPool.wait();
#else
  llvm_unreachable("threads are disabled");
#endif
}

BalancedPartitioning::BalancedPartitioning(
    const BalancedPartitioningConfig &Config)
    : Config(Config) {
  // Pre-computing log2 values
  Log2Cache[0] = 0.0;
  for (unsigned I = 1; I < LOG_CACHE_SIZE; I++)
    Log2Cache[I] = std::log2(I);
}

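// Entry point: assign every node a Bucket via recursive bisection, then
// reorder the nodes by their buckets. The thread pool is only engaged when
// the configuration allows parallel recursion (TaskSplitDepth > 1).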
void BalancedPartitioning::run(std::vector<BPFunctionNode> &Nodes) const {
  LLVM_DEBUG(
      dbgs() << format(
          "Partitioning %d nodes using depth %d and %d iterations per split\n",
          Nodes.size(), Config.SplitDepth, Config.IterationsPerSplit));
  std::optional<BPThreadPool> TP;
#if LLVM_ENABLE_THREADS
  DefaultThreadPool TheThreadPool;
  if (Config.TaskSplitDepth > 1)
    TP.emplace(TheThreadPool);
#endif

  // Record the input order
  for (unsigned I = 0; I < Nodes.size(); I++)
    Nodes[I].InputOrderIndex = I;

  auto NodesRange = llvm::make_range(Nodes.begin(), Nodes.end());
  auto BisectTask = [this, NodesRange, &TP]() {
    bisect(NodesRange, /*RecDepth=*/0, /*RootBucket=*/1, /*Offset=*/0, TP);
  };
  if (TP) {
    TP->async(std::move(BisectTask));
    TP->wait();
  } else {
    BisectTask();
  }

  llvm::stable_sort(NodesRange, [](const auto &L, const auto &R) {
    return L.Bucket < R.Bucket;
  });

  LLVM_DEBUG(dbgs() << "Balanced partitioning completed\n");
}

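// One level of the recursion: seed an initial split of Nodes, refine it with
// local search, and recurse on the two halves. Buckets form an implicit
// binary tree rooted at RootBucket, so the halves use buckets 2 * RootBucket
// and 2 * RootBucket + 1.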
void BalancedPartitioning::bisect(const FunctionNodeRange Nodes,
                                  unsigned RecDepth, unsigned RootBucket,
                                  unsigned Offset,
                                  std::optional<BPThreadPool> &TP) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  if (NumNodes <= 1 || RecDepth >= Config.SplitDepth) {
    // We've reached the lowest level of the recursion tree. Fall back to the
    // original order and assign to buckets.
    llvm::sort(Nodes, [](const auto &L, const auto &R) {
      return L.InputOrderIndex < R.InputOrderIndex;
    });
    for (auto &N : Nodes)
      N.Bucket = Offset++;
    return;
  }

  LLVM_DEBUG(dbgs() << format("Bisect with %d nodes and root bucket %d\n",
                              NumNodes, RootBucket));

  std::mt19937 RNG(RootBucket);

  unsigned LeftBucket = 2 * RootBucket;
  unsigned RightBucket = 2 * RootBucket + 1;

  // Split into two and assign to the left and right buckets
  split(Nodes, LeftBucket);

  runIterations(Nodes, LeftBucket, RightBucket, RNG);

  // Split nodes wrt the resulting buckets
  auto NodesMid =
      llvm::partition(Nodes, [&](auto &N) { return N.Bucket == LeftBucket; });
  unsigned MidOffset = Offset + std::distance(Nodes.begin(), NodesMid);

  auto LeftNodes = llvm::make_range(Nodes.begin(), NodesMid);
  auto RightNodes = llvm::make_range(NodesMid, Nodes.end());

  auto LeftRecTask = [this, LeftNodes, RecDepth, LeftBucket, Offset, &TP]() {
    bisect(LeftNodes, RecDepth + 1, LeftBucket, Offset, TP);
  };
  auto RightRecTask = [this, RightNodes, RecDepth, RightBucket, MidOffset,
                       &TP]() {
    bisect(RightNodes, RecDepth + 1, RightBucket, MidOffset, TP);
  };

  if (TP && RecDepth < Config.TaskSplitDepth && NumNodes >= 4) {
    TP->async(std::move(LeftRecTask));
    TP->async(std::move(RightRecTask));
  } else {
    LeftRecTask();
    RightRecTask();
  }
}

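// Local-search refinement for one bisection: drop utility nodes that cannot
// distinguish between splits (those incident to a single function or to all
// of them), renumber the rest into a dense range usable as indices into
// Signatures, count per-utility-node bucket membership, and iterate until no
// node moves or the iteration limit is reached.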
void BalancedPartitioning::runIterations(const FunctionNodeRange Nodes,
                                         unsigned LeftBucket,
                                         unsigned RightBucket,
                                         std::mt19937 &RNG) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  DenseMap<BPFunctionNode::UtilityNodeT, unsigned> UtilityNodeIndex;
  for (auto &N : Nodes)
    for (auto &UN : N.UtilityNodes)
      ++UtilityNodeIndex[UN];
  // Remove utility nodes if they have just one edge or are connected to all
  // functions
  for (auto &N : Nodes)
    llvm::erase_if(N.UtilityNodes, [&](auto &UN) {
      return UtilityNodeIndex[UN] == 1 || UtilityNodeIndex[UN] == NumNodes;
    });

  // Renumber utility nodes so they can be used to index into Signatures
  UtilityNodeIndex.clear();
  for (auto &N : Nodes)
    for (auto &UN : N.UtilityNodes)
      UN = UtilityNodeIndex.insert({UN, UtilityNodeIndex.size()}).first->second;

  // Initialize signatures
  SignaturesT Signatures(/*Size=*/UtilityNodeIndex.size());
  for (auto &N : Nodes) {
    for (auto &UN : N.UtilityNodes) {
      assert(UN < Signatures.size());
      if (N.Bucket == LeftBucket) {
        Signatures[UN].LeftCount++;
      } else {
        Signatures[UN].RightCount++;
      }
    }
  }

  for (unsigned I = 0; I < Config.IterationsPerSplit; I++) {
    unsigned NumMovedNodes =
        runIteration(Nodes, LeftBucket, RightBucket, Signatures, RNG);
    if (NumMovedNodes == 0)
      break;
  }
}

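// A single refinement pass: recompute the cached per-utility-node move gains,
// sort each bucket's nodes by descending gain, and greedily try to swap the
// top candidates pairwise while the combined gain of a swap stays positive.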
unsigned BalancedPartitioning::runIteration(const FunctionNodeRange Nodes,
                                            unsigned LeftBucket,
                                            unsigned RightBucket,
                                            SignaturesT &Signatures,
                                            std::mt19937 &RNG) const {
  // Init signature cost caches
  for (auto &Signature : Signatures) {
    if (Signature.CachedGainIsValid)
      continue;
    unsigned L = Signature.LeftCount;
    unsigned R = Signature.RightCount;
    assert((L > 0 || R > 0) && "incorrect signature");
    float Cost = logCost(L, R);
    Signature.CachedGainLR = 0.f;
    Signature.CachedGainRL = 0.f;
    if (L > 0)
      Signature.CachedGainLR = Cost - logCost(L - 1, R + 1);
    if (R > 0)
      Signature.CachedGainRL = Cost - logCost(L + 1, R - 1);
    Signature.CachedGainIsValid = true;
  }

  // Compute move gains
  typedef std::pair<float, BPFunctionNode *> GainPair;
  std::vector<GainPair> Gains;
  for (auto &N : Nodes) {
    bool FromLeftToRight = (N.Bucket == LeftBucket);
    float Gain = moveGain(N, FromLeftToRight, Signatures);
    Gains.push_back(std::make_pair(Gain, &N));
  }

  // Collect left and right gains
  auto LeftEnd = llvm::partition(
      Gains, [&](const auto &GP) { return GP.second->Bucket == LeftBucket; });
  auto LeftRange = llvm::make_range(Gains.begin(), LeftEnd);
  auto RightRange = llvm::make_range(LeftEnd, Gains.end());

  // Sort gains in descending order
  auto LargerGain = [](const auto &L, const auto &R) {
    return L.first > R.first;
  };
  llvm::stable_sort(LeftRange, LargerGain);
  llvm::stable_sort(RightRange, LargerGain);

  unsigned NumMovedDataVertices = 0;
  for (auto [LeftPair, RightPair] : llvm::zip(LeftRange, RightRange)) {
    auto &[LeftGain, LeftNode] = LeftPair;
    auto &[RightGain, RightNode] = RightPair;
    // Stop when the gain is no longer beneficial
    if (LeftGain + RightGain <= 0.f)
      break;
    // Try to exchange the nodes between buckets
    if (moveFunctionNode(*LeftNode, LeftBucket, RightBucket, Signatures, RNG))
      ++NumMovedDataVertices;
    if (moveFunctionNode(*RightNode, LeftBucket, RightBucket, Signatures, RNG))
      ++NumMovedDataVertices;
  }
  return NumMovedDataVertices;
}

bool BalancedPartitioning::moveFunctionNode(BPFunctionNode &N,
                                            unsigned LeftBucket,
                                            unsigned RightBucket,
                                            SignaturesT &Signatures,
                                            std::mt19937 &RNG) const {
  // Sometimes we skip the move. This helps to escape local optima
  if (std::uniform_real_distribution<float>(0.f, 1.f)(RNG) <=
      Config.SkipProbability)
    return false;

  bool FromLeftToRight = (N.Bucket == LeftBucket);
  // Update the current bucket
  N.Bucket = (FromLeftToRight ? RightBucket : LeftBucket);

  // Update signatures and invalidate gain cache
  if (FromLeftToRight) {
    for (auto &UN : N.UtilityNodes) {
      auto &Signature = Signatures[UN];
      Signature.LeftCount--;
      Signature.RightCount++;
      Signature.CachedGainIsValid = false;
    }
  } else {
    for (auto &UN : N.UtilityNodes) {
      auto &Signature = Signatures[UN];
      Signature.LeftCount++;
      Signature.RightCount--;
      Signature.CachedGainIsValid = false;
    }
  }
  return true;
}

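// Seed a bisection by splitting Nodes in half by input order (nth_element is
// linear on average) and assigning the halves to StartBucket and
// StartBucket + 1.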
void BalancedPartitioning::split(const FunctionNodeRange Nodes,
                                 unsigned StartBucket) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  auto NodesMid = Nodes.begin() + (NumNodes + 1) / 2;

  std::nth_element(Nodes.begin(), NodesMid, Nodes.end(), [](auto &L, auto &R) {
    return L.InputOrderIndex < R.InputOrderIndex;
  });

  for (auto &N : llvm::make_range(Nodes.begin(), NodesMid))
    N.Bucket = StartBucket;
  for (auto &N : llvm::make_range(NodesMid, Nodes.end()))
    N.Bucket = StartBucket + 1;
}

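// The gain of moving N to the other bucket is the sum of the cached gains of
// its utility nodes for that direction.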
float BalancedPartitioning::moveGain(const BPFunctionNode &N,
                                     bool FromLeftToRight,
                                     const SignaturesT &Signatures) {
  float Gain = 0.f;
  for (auto &UN : N.UtilityNodes)
    Gain += (FromLeftToRight ? Signatures[UN].CachedGainLR
                             : Signatures[UN].CachedGainRL);
  return Gain;
}

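// Cost of a utility node whose incident functions are split X/Y between the
// two buckets. For a fixed X + Y, -(X * log2(X + 1) + Y * log2(Y + 1)) is
// smallest when the split is lopsided, so moves that gather a utility node's
// functions into one bucket have positive gain; the cached gains in
// runIteration are differences of this value.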
float BalancedPartitioning::logCost(unsigned X, unsigned Y) const {
  return -(X * log2Cached(X + 1) + Y * log2Cached(Y + 1));
}

float BalancedPartitioning::log2Cached(unsigned i) const {
  return (i < LOG_CACHE_SIZE) ? Log2Cache[i] : std::log2(i);
}