11117b9a2SEllis Hoag //===- BalancedPartitioning.cpp -------------------------------------------===// 21117b9a2SEllis Hoag // 31117b9a2SEllis Hoag // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 41117b9a2SEllis Hoag // See https://llvm.org/LICENSE.txt for license information. 51117b9a2SEllis Hoag // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 61117b9a2SEllis Hoag // 71117b9a2SEllis Hoag //===----------------------------------------------------------------------===// 81117b9a2SEllis Hoag // 91117b9a2SEllis Hoag // This file implements BalancedPartitioning, a recursive balanced graph 101117b9a2SEllis Hoag // partitioning algorithm. 111117b9a2SEllis Hoag // 121117b9a2SEllis Hoag //===----------------------------------------------------------------------===// 131117b9a2SEllis Hoag 141117b9a2SEllis Hoag #include "llvm/Support/BalancedPartitioning.h" 1589e6a288SDaniil Fukalov #include "llvm/Config/llvm-config.h" // for LLVM_ENABLE_THREADS 161117b9a2SEllis Hoag #include "llvm/Support/Debug.h" 171117b9a2SEllis Hoag #include "llvm/Support/Format.h" 181117b9a2SEllis Hoag #include "llvm/Support/FormatVariadic.h" 1953a7db4fSKamlesh Kumar #include "llvm/Support/ThreadPool.h" 201117b9a2SEllis Hoag 211117b9a2SEllis Hoag using namespace llvm; 221117b9a2SEllis Hoag #define DEBUG_TYPE "balanced-partitioning" 231117b9a2SEllis Hoag 241117b9a2SEllis Hoag void BPFunctionNode::dump(raw_ostream &OS) const { 2530aa9fb4Sspupyrev OS << formatv("{{ID={0} Utilities={{{1:$[,]}} Bucket={2}}", Id, 2630aa9fb4Sspupyrev make_range(UtilityNodes.begin(), UtilityNodes.end()), Bucket); 271117b9a2SEllis Hoag } 281117b9a2SEllis Hoag 291117b9a2SEllis Hoag template <typename Func> 301117b9a2SEllis Hoag void BalancedPartitioning::BPThreadPool::async(Func &&F) { 31c1d935ecSEllis Hoag #if LLVM_ENABLE_THREADS 321117b9a2SEllis Hoag // This new thread could spawn more threads, so mark it as active 331117b9a2SEllis Hoag ++NumActiveThreads; 34*1f9f68a1SFangrui Song TheThreadPool.async([this, F]() { 351117b9a2SEllis Hoag // Run the task 361117b9a2SEllis Hoag F(); 371117b9a2SEllis Hoag 381117b9a2SEllis Hoag // This thread will no longer spawn new threads, so mark it as inactive 391117b9a2SEllis Hoag if (--NumActiveThreads == 0) { 401117b9a2SEllis Hoag // There are no more active threads, so mark as finished and notify 411117b9a2SEllis Hoag { 421117b9a2SEllis Hoag std::unique_lock<std::mutex> lock(mtx); 431117b9a2SEllis Hoag assert(!IsFinishedSpawning); 441117b9a2SEllis Hoag IsFinishedSpawning = true; 451117b9a2SEllis Hoag } 461117b9a2SEllis Hoag cv.notify_one(); 471117b9a2SEllis Hoag } 481117b9a2SEllis Hoag }); 49c1d935ecSEllis Hoag #else 50c1d935ecSEllis Hoag llvm_unreachable("threads are disabled"); 51c1d935ecSEllis Hoag #endif 521117b9a2SEllis Hoag } 531117b9a2SEllis Hoag 541117b9a2SEllis Hoag void BalancedPartitioning::BPThreadPool::wait() { 55c1d935ecSEllis Hoag #if LLVM_ENABLE_THREADS 561117b9a2SEllis Hoag // TODO: We could remove the mutex and condition variable and use 571117b9a2SEllis Hoag // std::atomic::wait() instead, but that isn't available until C++20 581117b9a2SEllis Hoag { 591117b9a2SEllis Hoag std::unique_lock<std::mutex> lock(mtx); 601117b9a2SEllis Hoag cv.wait(lock, [&]() { return IsFinishedSpawning; }); 611117b9a2SEllis Hoag assert(IsFinishedSpawning && NumActiveThreads == 0); 621117b9a2SEllis Hoag } 631117b9a2SEllis Hoag // Now we can call ThreadPool::wait() since all tasks have been submitted 641117b9a2SEllis Hoag TheThreadPool.wait(); 65c1d935ecSEllis Hoag #else 66c1d935ecSEllis Hoag llvm_unreachable("threads are disabled"); 67c1d935ecSEllis Hoag #endif 681117b9a2SEllis Hoag } 691117b9a2SEllis Hoag 701117b9a2SEllis Hoag BalancedPartitioning::BalancedPartitioning( 711117b9a2SEllis Hoag const BalancedPartitioningConfig &Config) 721117b9a2SEllis Hoag : Config(Config) { 731117b9a2SEllis Hoag // Pre-computing log2 values 741117b9a2SEllis Hoag Log2Cache[0] = 0.0; 751117b9a2SEllis Hoag for (unsigned I = 1; I < LOG_CACHE_SIZE; I++) 761117b9a2SEllis Hoag Log2Cache[I] = std::log2(I); 771117b9a2SEllis Hoag } 781117b9a2SEllis Hoag 791117b9a2SEllis Hoag void BalancedPartitioning::run(std::vector<BPFunctionNode> &Nodes) const { 801117b9a2SEllis Hoag LLVM_DEBUG( 811117b9a2SEllis Hoag dbgs() << format( 821117b9a2SEllis Hoag "Partitioning %d nodes using depth %d and %d iterations per split\n", 831117b9a2SEllis Hoag Nodes.size(), Config.SplitDepth, Config.IterationsPerSplit)); 841117b9a2SEllis Hoag std::optional<BPThreadPool> TP; 85c1d935ecSEllis Hoag #if LLVM_ENABLE_THREADS 86716042a6SMehdi Amini DefaultThreadPool TheThreadPool; 871117b9a2SEllis Hoag if (Config.TaskSplitDepth > 1) 8853a7db4fSKamlesh Kumar TP.emplace(TheThreadPool); 89c1d935ecSEllis Hoag #endif 901117b9a2SEllis Hoag 911117b9a2SEllis Hoag // Record the input order 921117b9a2SEllis Hoag for (unsigned I = 0; I < Nodes.size(); I++) 931117b9a2SEllis Hoag Nodes[I].InputOrderIndex = I; 941117b9a2SEllis Hoag 951117b9a2SEllis Hoag auto NodesRange = llvm::make_range(Nodes.begin(), Nodes.end()); 96*1f9f68a1SFangrui Song auto BisectTask = [this, NodesRange, &TP]() { 971117b9a2SEllis Hoag bisect(NodesRange, /*RecDepth=*/0, /*RootBucket=*/1, /*Offset=*/0, TP); 981117b9a2SEllis Hoag }; 991117b9a2SEllis Hoag if (TP) { 1001117b9a2SEllis Hoag TP->async(std::move(BisectTask)); 1011117b9a2SEllis Hoag TP->wait(); 1021117b9a2SEllis Hoag } else { 1031117b9a2SEllis Hoag BisectTask(); 1041117b9a2SEllis Hoag } 1051117b9a2SEllis Hoag 1061117b9a2SEllis Hoag llvm::stable_sort(NodesRange, [](const auto &L, const auto &R) { 1071117b9a2SEllis Hoag return L.Bucket < R.Bucket; 1081117b9a2SEllis Hoag }); 1091117b9a2SEllis Hoag 1101117b9a2SEllis Hoag LLVM_DEBUG(dbgs() << "Balanced partitioning completed\n"); 1111117b9a2SEllis Hoag } 1121117b9a2SEllis Hoag 1131117b9a2SEllis Hoag void BalancedPartitioning::bisect(const FunctionNodeRange Nodes, 1141117b9a2SEllis Hoag unsigned RecDepth, unsigned RootBucket, 1151117b9a2SEllis Hoag unsigned Offset, 1161117b9a2SEllis Hoag std::optional<BPThreadPool> &TP) const { 1171117b9a2SEllis Hoag unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end()); 1181117b9a2SEllis Hoag if (NumNodes <= 1 || RecDepth >= Config.SplitDepth) { 1191117b9a2SEllis Hoag // We've reach the lowest level of the recursion tree. Fall back to the 1201117b9a2SEllis Hoag // original order and assign to buckets. 1210c6dc805SFangrui Song llvm::sort(Nodes, [](const auto &L, const auto &R) { 1221117b9a2SEllis Hoag return L.InputOrderIndex < R.InputOrderIndex; 1231117b9a2SEllis Hoag }); 1241117b9a2SEllis Hoag for (auto &N : Nodes) 1251117b9a2SEllis Hoag N.Bucket = Offset++; 1261117b9a2SEllis Hoag return; 1271117b9a2SEllis Hoag } 1281117b9a2SEllis Hoag 1291117b9a2SEllis Hoag LLVM_DEBUG(dbgs() << format("Bisect with %d nodes and root bucket %d\n", 1301117b9a2SEllis Hoag NumNodes, RootBucket)); 1311117b9a2SEllis Hoag 1321117b9a2SEllis Hoag std::mt19937 RNG(RootBucket); 1331117b9a2SEllis Hoag 1341117b9a2SEllis Hoag unsigned LeftBucket = 2 * RootBucket; 1351117b9a2SEllis Hoag unsigned RightBucket = 2 * RootBucket + 1; 1361117b9a2SEllis Hoag 1371117b9a2SEllis Hoag // Split into two and assign to the left and right buckets 1381117b9a2SEllis Hoag split(Nodes, LeftBucket); 1391117b9a2SEllis Hoag 1405d0d9eb5SEllis Hoag runIterations(Nodes, LeftBucket, RightBucket, RNG); 1411117b9a2SEllis Hoag 1421117b9a2SEllis Hoag // Split nodes wrt the resulting buckets 1431117b9a2SEllis Hoag auto NodesMid = 1441117b9a2SEllis Hoag llvm::partition(Nodes, [&](auto &N) { return N.Bucket == LeftBucket; }); 1451117b9a2SEllis Hoag unsigned MidOffset = Offset + std::distance(Nodes.begin(), NodesMid); 1461117b9a2SEllis Hoag 1471117b9a2SEllis Hoag auto LeftNodes = llvm::make_range(Nodes.begin(), NodesMid); 1481117b9a2SEllis Hoag auto RightNodes = llvm::make_range(NodesMid, Nodes.end()); 1491117b9a2SEllis Hoag 150*1f9f68a1SFangrui Song auto LeftRecTask = [this, LeftNodes, RecDepth, LeftBucket, Offset, &TP]() { 1511117b9a2SEllis Hoag bisect(LeftNodes, RecDepth + 1, LeftBucket, Offset, TP); 1521117b9a2SEllis Hoag }; 153*1f9f68a1SFangrui Song auto RightRecTask = [this, RightNodes, RecDepth, RightBucket, MidOffset, 154*1f9f68a1SFangrui Song &TP]() { 1551117b9a2SEllis Hoag bisect(RightNodes, RecDepth + 1, RightBucket, MidOffset, TP); 1561117b9a2SEllis Hoag }; 1571117b9a2SEllis Hoag 1581117b9a2SEllis Hoag if (TP && RecDepth < Config.TaskSplitDepth && NumNodes >= 4) { 1591117b9a2SEllis Hoag TP->async(std::move(LeftRecTask)); 1601117b9a2SEllis Hoag TP->async(std::move(RightRecTask)); 1611117b9a2SEllis Hoag } else { 1621117b9a2SEllis Hoag LeftRecTask(); 1631117b9a2SEllis Hoag RightRecTask(); 1641117b9a2SEllis Hoag } 1651117b9a2SEllis Hoag } 1661117b9a2SEllis Hoag 1671117b9a2SEllis Hoag void BalancedPartitioning::runIterations(const FunctionNodeRange Nodes, 1685d0d9eb5SEllis Hoag unsigned LeftBucket, 1691117b9a2SEllis Hoag unsigned RightBucket, 1701117b9a2SEllis Hoag std::mt19937 &RNG) const { 17130aa9fb4Sspupyrev unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end()); 17230aa9fb4Sspupyrev DenseMap<BPFunctionNode::UtilityNodeT, unsigned> UtilityNodeIndex; 1731117b9a2SEllis Hoag for (auto &N : Nodes) 1741117b9a2SEllis Hoag for (auto &UN : N.UtilityNodes) 17530aa9fb4Sspupyrev ++UtilityNodeIndex[UN]; 1761117b9a2SEllis Hoag // Remove utility nodes if they have just one edge or are connected to all 17730aa9fb4Sspupyrev // functions 1781117b9a2SEllis Hoag for (auto &N : Nodes) 1791117b9a2SEllis Hoag llvm::erase_if(N.UtilityNodes, [&](auto &UN) { 18030aa9fb4Sspupyrev return UtilityNodeIndex[UN] == 1 || UtilityNodeIndex[UN] == NumNodes; 1811117b9a2SEllis Hoag }); 1821117b9a2SEllis Hoag 18330aa9fb4Sspupyrev // Renumber utility nodes so they can be used to index into Signatures 1840c6dc805SFangrui Song UtilityNodeIndex.clear(); 1851117b9a2SEllis Hoag for (auto &N : Nodes) 1861117b9a2SEllis Hoag for (auto &UN : N.UtilityNodes) 18730aa9fb4Sspupyrev UN = UtilityNodeIndex.insert({UN, UtilityNodeIndex.size()}).first->second; 1881117b9a2SEllis Hoag 18930aa9fb4Sspupyrev // Initialize signatures 1901117b9a2SEllis Hoag SignaturesT Signatures(/*Size=*/UtilityNodeIndex.size()); 1911117b9a2SEllis Hoag for (auto &N : Nodes) { 1921117b9a2SEllis Hoag for (auto &UN : N.UtilityNodes) { 19330aa9fb4Sspupyrev assert(UN < Signatures.size()); 19430aa9fb4Sspupyrev if (N.Bucket == LeftBucket) { 19530aa9fb4Sspupyrev Signatures[UN].LeftCount++; 19630aa9fb4Sspupyrev } else { 19730aa9fb4Sspupyrev Signatures[UN].RightCount++; 19830aa9fb4Sspupyrev } 1991117b9a2SEllis Hoag } 2001117b9a2SEllis Hoag } 2011117b9a2SEllis Hoag 2021117b9a2SEllis Hoag for (unsigned I = 0; I < Config.IterationsPerSplit; I++) { 2031117b9a2SEllis Hoag unsigned NumMovedNodes = 2041117b9a2SEllis Hoag runIteration(Nodes, LeftBucket, RightBucket, Signatures, RNG); 2051117b9a2SEllis Hoag if (NumMovedNodes == 0) 2061117b9a2SEllis Hoag break; 2071117b9a2SEllis Hoag } 2081117b9a2SEllis Hoag } 2091117b9a2SEllis Hoag 2101117b9a2SEllis Hoag unsigned BalancedPartitioning::runIteration(const FunctionNodeRange Nodes, 2111117b9a2SEllis Hoag unsigned LeftBucket, 2121117b9a2SEllis Hoag unsigned RightBucket, 2131117b9a2SEllis Hoag SignaturesT &Signatures, 2141117b9a2SEllis Hoag std::mt19937 &RNG) const { 2151117b9a2SEllis Hoag // Init signature cost caches 2161117b9a2SEllis Hoag for (auto &Signature : Signatures) { 2171117b9a2SEllis Hoag if (Signature.CachedGainIsValid) 2181117b9a2SEllis Hoag continue; 2191117b9a2SEllis Hoag unsigned L = Signature.LeftCount; 2201117b9a2SEllis Hoag unsigned R = Signature.RightCount; 2211117b9a2SEllis Hoag assert((L > 0 || R > 0) && "incorrect signature"); 2221117b9a2SEllis Hoag float Cost = logCost(L, R); 223266ffd7aSEllis Hoag Signature.CachedGainLR = 0.f; 224266ffd7aSEllis Hoag Signature.CachedGainRL = 0.f; 2251117b9a2SEllis Hoag if (L > 0) 22630aa9fb4Sspupyrev Signature.CachedGainLR = Cost - logCost(L - 1, R + 1); 2271117b9a2SEllis Hoag if (R > 0) 22830aa9fb4Sspupyrev Signature.CachedGainRL = Cost - logCost(L + 1, R - 1); 2291117b9a2SEllis Hoag Signature.CachedGainIsValid = true; 2301117b9a2SEllis Hoag } 2311117b9a2SEllis Hoag 2321117b9a2SEllis Hoag // Compute move gains 2331117b9a2SEllis Hoag typedef std::pair<float, BPFunctionNode *> GainPair; 2341117b9a2SEllis Hoag std::vector<GainPair> Gains; 2351117b9a2SEllis Hoag for (auto &N : Nodes) { 2361117b9a2SEllis Hoag bool FromLeftToRight = (N.Bucket == LeftBucket); 2371117b9a2SEllis Hoag float Gain = moveGain(N, FromLeftToRight, Signatures); 2381117b9a2SEllis Hoag Gains.push_back(std::make_pair(Gain, &N)); 2391117b9a2SEllis Hoag } 2401117b9a2SEllis Hoag 2411117b9a2SEllis Hoag // Collect left and right gains 2421117b9a2SEllis Hoag auto LeftEnd = llvm::partition( 2431117b9a2SEllis Hoag Gains, [&](const auto &GP) { return GP.second->Bucket == LeftBucket; }); 2441117b9a2SEllis Hoag auto LeftRange = llvm::make_range(Gains.begin(), LeftEnd); 2451117b9a2SEllis Hoag auto RightRange = llvm::make_range(LeftEnd, Gains.end()); 2461117b9a2SEllis Hoag 2471117b9a2SEllis Hoag // Sort gains in descending order 2481117b9a2SEllis Hoag auto LargerGain = [](const auto &L, const auto &R) { 2491117b9a2SEllis Hoag return L.first > R.first; 2501117b9a2SEllis Hoag }; 2511117b9a2SEllis Hoag llvm::stable_sort(LeftRange, LargerGain); 2521117b9a2SEllis Hoag llvm::stable_sort(RightRange, LargerGain); 2531117b9a2SEllis Hoag 2541117b9a2SEllis Hoag unsigned NumMovedDataVertices = 0; 2551117b9a2SEllis Hoag for (auto [LeftPair, RightPair] : llvm::zip(LeftRange, RightRange)) { 2561117b9a2SEllis Hoag auto &[LeftGain, LeftNode] = LeftPair; 2571117b9a2SEllis Hoag auto &[RightGain, RightNode] = RightPair; 2581117b9a2SEllis Hoag // Stop when the gain is no longer beneficial 259266ffd7aSEllis Hoag if (LeftGain + RightGain <= 0.f) 2601117b9a2SEllis Hoag break; 2611117b9a2SEllis Hoag // Try to exchange the nodes between buckets 2621117b9a2SEllis Hoag if (moveFunctionNode(*LeftNode, LeftBucket, RightBucket, Signatures, RNG)) 2631117b9a2SEllis Hoag ++NumMovedDataVertices; 2641117b9a2SEllis Hoag if (moveFunctionNode(*RightNode, LeftBucket, RightBucket, Signatures, RNG)) 2651117b9a2SEllis Hoag ++NumMovedDataVertices; 2661117b9a2SEllis Hoag } 2671117b9a2SEllis Hoag return NumMovedDataVertices; 2681117b9a2SEllis Hoag } 2691117b9a2SEllis Hoag 2701117b9a2SEllis Hoag bool BalancedPartitioning::moveFunctionNode(BPFunctionNode &N, 2711117b9a2SEllis Hoag unsigned LeftBucket, 2721117b9a2SEllis Hoag unsigned RightBucket, 2731117b9a2SEllis Hoag SignaturesT &Signatures, 2741117b9a2SEllis Hoag std::mt19937 &RNG) const { 2751117b9a2SEllis Hoag // Sometimes we skip the move. This helps to escape local optima 276266ffd7aSEllis Hoag if (std::uniform_real_distribution<float>(0.f, 1.f)(RNG) <= 2771117b9a2SEllis Hoag Config.SkipProbability) 2781117b9a2SEllis Hoag return false; 2791117b9a2SEllis Hoag 2801117b9a2SEllis Hoag bool FromLeftToRight = (N.Bucket == LeftBucket); 2811117b9a2SEllis Hoag // Update the current bucket 2821117b9a2SEllis Hoag N.Bucket = (FromLeftToRight ? RightBucket : LeftBucket); 2831117b9a2SEllis Hoag 2841117b9a2SEllis Hoag // Update signatures and invalidate gain cache 2851117b9a2SEllis Hoag if (FromLeftToRight) { 2861117b9a2SEllis Hoag for (auto &UN : N.UtilityNodes) { 28730aa9fb4Sspupyrev auto &Signature = Signatures[UN]; 2881117b9a2SEllis Hoag Signature.LeftCount--; 2891117b9a2SEllis Hoag Signature.RightCount++; 2901117b9a2SEllis Hoag Signature.CachedGainIsValid = false; 2911117b9a2SEllis Hoag } 2921117b9a2SEllis Hoag } else { 2931117b9a2SEllis Hoag for (auto &UN : N.UtilityNodes) { 29430aa9fb4Sspupyrev auto &Signature = Signatures[UN]; 2951117b9a2SEllis Hoag Signature.LeftCount++; 2961117b9a2SEllis Hoag Signature.RightCount--; 2971117b9a2SEllis Hoag Signature.CachedGainIsValid = false; 2981117b9a2SEllis Hoag } 2991117b9a2SEllis Hoag } 3001117b9a2SEllis Hoag return true; 3011117b9a2SEllis Hoag } 3021117b9a2SEllis Hoag 3031117b9a2SEllis Hoag void BalancedPartitioning::split(const FunctionNodeRange Nodes, 3041117b9a2SEllis Hoag unsigned StartBucket) const { 3051117b9a2SEllis Hoag unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end()); 3061117b9a2SEllis Hoag auto NodesMid = Nodes.begin() + (NumNodes + 1) / 2; 3071117b9a2SEllis Hoag 3081117b9a2SEllis Hoag std::nth_element(Nodes.begin(), NodesMid, Nodes.end(), [](auto &L, auto &R) { 3091117b9a2SEllis Hoag return L.InputOrderIndex < R.InputOrderIndex; 3101117b9a2SEllis Hoag }); 3111117b9a2SEllis Hoag 3121117b9a2SEllis Hoag for (auto &N : llvm::make_range(Nodes.begin(), NodesMid)) 3131117b9a2SEllis Hoag N.Bucket = StartBucket; 3141117b9a2SEllis Hoag for (auto &N : llvm::make_range(NodesMid, Nodes.end())) 3151117b9a2SEllis Hoag N.Bucket = StartBucket + 1; 3161117b9a2SEllis Hoag } 3171117b9a2SEllis Hoag 3181117b9a2SEllis Hoag float BalancedPartitioning::moveGain(const BPFunctionNode &N, 3191117b9a2SEllis Hoag bool FromLeftToRight, 3201117b9a2SEllis Hoag const SignaturesT &Signatures) { 321266ffd7aSEllis Hoag float Gain = 0.f; 3221117b9a2SEllis Hoag for (auto &UN : N.UtilityNodes) 32330aa9fb4Sspupyrev Gain += (FromLeftToRight ? Signatures[UN].CachedGainLR 32430aa9fb4Sspupyrev : Signatures[UN].CachedGainRL); 3251117b9a2SEllis Hoag return Gain; 3261117b9a2SEllis Hoag } 3271117b9a2SEllis Hoag 3281117b9a2SEllis Hoag float BalancedPartitioning::logCost(unsigned X, unsigned Y) const { 3291117b9a2SEllis Hoag return -(X * log2Cached(X + 1) + Y * log2Cached(Y + 1)); 3301117b9a2SEllis Hoag } 3311117b9a2SEllis Hoag 3321117b9a2SEllis Hoag float BalancedPartitioning::log2Cached(unsigned i) const { 3331117b9a2SEllis Hoag return (i < LOG_CACHE_SIZE) ? Log2Cache[i] : std::log2(i); 3341117b9a2SEllis Hoag } 335