xref: /llvm-project/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp (revision 83ad90d851f9e32a51d56193125ab596cc3636b6)
143fd244bSPierre van Houtryve //===- AMDGPUSplitModule.cpp ----------------------------------------------===//
243fd244bSPierre van Houtryve //
343fd244bSPierre van Houtryve // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
443fd244bSPierre van Houtryve // See https://llvm.org/LICENSE.txt for license information.
543fd244bSPierre van Houtryve // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
643fd244bSPierre van Houtryve //
743fd244bSPierre van Houtryve //===----------------------------------------------------------------------===//
843fd244bSPierre van Houtryve //
943fd244bSPierre van Houtryve /// \file Implements a module splitting algorithm designed to support the
109347b66cSPierre van Houtryve /// FullLTO --lto-partitions option for parallel codegen.
1143fd244bSPierre van Houtryve ///
129347b66cSPierre van Houtryve /// The role of this module splitting pass is the same as
139347b66cSPierre van Houtryve /// lib/Transforms/Utils/SplitModule.cpp: load-balance the module's functions
149347b66cSPierre van Houtryve /// across a set of N partitions to allow for parallel codegen.
1543fd244bSPierre van Houtryve ///
169347b66cSPierre van Houtryve /// The similarities mostly end here, as this pass achieves load-balancing in a
179347b66cSPierre van Houtryve /// more elaborate fashion which is targeted towards AMDGPU modules. It can take
189347b66cSPierre van Houtryve /// advantage of the structure of AMDGPU modules (which are mostly
199347b66cSPierre van Houtryve /// self-contained) to allow for more efficient splitting without affecting
/// codegen negatively, or causing inaccurate resource usage analysis.
219347b66cSPierre van Houtryve ///
229347b66cSPierre van Houtryve /// High-level pass overview:
239347b66cSPierre van Houtryve ///   - SplitGraph & associated classes
249347b66cSPierre van Houtryve ///      - Graph representation of the module and of the dependencies that
259347b66cSPierre van Houtryve ///      matter for splitting.
269347b66cSPierre van Houtryve ///   - RecursiveSearchSplitting
279347b66cSPierre van Houtryve ///     - Core splitting algorithm.
289347b66cSPierre van Houtryve ///   - SplitProposal
299347b66cSPierre van Houtryve ///     - Represents a suggested solution for splitting the input module. These
309347b66cSPierre van Houtryve ///     solutions can be scored to determine the best one when multiple
319347b66cSPierre van Houtryve ///     solutions are available.
329347b66cSPierre van Houtryve ///   - Driver/pass "run" function glues everything together.
3343fd244bSPierre van Houtryve 
#include "AMDGPUSplitModule.h"
#include "AMDGPUTargetMachine.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DOTGraphTraits.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Timer.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <cassert>
#include <cmath>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

6443fd244bSPierre van Houtryve 
659347b66cSPierre van Houtryve #ifndef NDEBUG
669347b66cSPierre van Houtryve #include "llvm/Support/LockFileManager.h"
679347b66cSPierre van Houtryve #endif
6843fd244bSPierre van Houtryve 
6943fd244bSPierre van Houtryve #define DEBUG_TYPE "amdgpu-split-module"
7043fd244bSPierre van Houtryve 
719347b66cSPierre van Houtryve namespace llvm {
7243fd244bSPierre van Houtryve namespace {
7343fd244bSPierre van Houtryve 
749347b66cSPierre van Houtryve static cl::opt<unsigned> MaxDepth(
759347b66cSPierre van Houtryve     "amdgpu-module-splitting-max-depth",
76c9b6e01bSPierre van Houtryve     cl::desc(
779347b66cSPierre van Houtryve         "maximum search depth. 0 forces a greedy approach. "
789347b66cSPierre van Houtryve         "warning: the algorithm is up to O(2^N), where N is the max depth."),
799347b66cSPierre van Houtryve     cl::init(8));
809347b66cSPierre van Houtryve 
819347b66cSPierre van Houtryve static cl::opt<float> LargeFnFactor(
829347b66cSPierre van Houtryve     "amdgpu-module-splitting-large-threshold", cl::init(2.0f), cl::Hidden,
839347b66cSPierre van Houtryve     cl::desc(
849347b66cSPierre van Houtryve         "when max depth is reached and we can no longer branch out, this "
859347b66cSPierre van Houtryve         "value determines if a function is worth merging into an already "
869347b66cSPierre van Houtryve         "existing partition to reduce code duplication. This is a factor "
879347b66cSPierre van Houtryve         "of the ideal partition size, e.g. 2.0 means we consider the "
889347b66cSPierre van Houtryve         "function for merging if its cost (including its callees) is 2x the "
899347b66cSPierre van Houtryve         "size of an ideal partition."));
9043fd244bSPierre van Houtryve 
911c025fb0SPierre van Houtryve static cl::opt<float> LargeFnOverlapForMerge(
929347b66cSPierre van Houtryve     "amdgpu-module-splitting-merge-threshold", cl::init(0.7f), cl::Hidden,
939347b66cSPierre van Houtryve     cl::desc("when a function is considered for merging into a partition that "
949347b66cSPierre van Houtryve              "already contains some of its callees, do the merge if at least "
959347b66cSPierre van Houtryve              "n% of the code it can reach is already present inside the "
969347b66cSPierre van Houtryve              "partition; e.g. 0.7 means only merge >70%"));
9743fd244bSPierre van Houtryve 
9843fd244bSPierre van Houtryve static cl::opt<bool> NoExternalizeGlobals(
9943fd244bSPierre van Houtryve     "amdgpu-module-splitting-no-externalize-globals", cl::Hidden,
10043fd244bSPierre van Houtryve     cl::desc("disables externalization of global variable with local linkage; "
10143fd244bSPierre van Houtryve              "may cause globals to be duplicated which increases binary size"));
10243fd244bSPierre van Houtryve 
103d656b206SPierre van Houtryve static cl::opt<bool> NoExternalizeOnAddrTaken(
104d656b206SPierre van Houtryve     "amdgpu-module-splitting-no-externalize-address-taken", cl::Hidden,
105d656b206SPierre van Houtryve     cl::desc(
106d656b206SPierre van Houtryve         "disables externalization of functions whose addresses are taken"));
107d656b206SPierre van Houtryve 
10843fd244bSPierre van Houtryve static cl::opt<std::string>
1099347b66cSPierre van Houtryve     ModuleDotCfgOutput("amdgpu-module-splitting-print-module-dotcfg",
1109347b66cSPierre van Houtryve                        cl::Hidden,
1119347b66cSPierre van Houtryve                        cl::desc("output file to write out the dotgraph "
1129347b66cSPierre van Houtryve                                 "representation of the input module"));
1139347b66cSPierre van Houtryve 
1149347b66cSPierre van Houtryve static cl::opt<std::string> PartitionSummariesOutput(
1159347b66cSPierre van Houtryve     "amdgpu-module-splitting-print-partition-summaries", cl::Hidden,
1169347b66cSPierre van Houtryve     cl::desc("output file to write out a summary of "
1179347b66cSPierre van Houtryve              "the partitions created for each module"));
1189347b66cSPierre van Houtryve 
1199347b66cSPierre van Houtryve #ifndef NDEBUG
1209347b66cSPierre van Houtryve static cl::opt<bool>
1219347b66cSPierre van Houtryve     UseLockFile("amdgpu-module-splitting-serial-execution", cl::Hidden,
1229347b66cSPierre van Houtryve                 cl::desc("use a lock file so only one process in the system "
1239347b66cSPierre van Houtryve                          "can run this pass at once. useful to avoid mangled "
1249347b66cSPierre van Houtryve                          "debug output in multithreaded environments."));
12543fd244bSPierre van Houtryve 
12643fd244bSPierre van Houtryve static cl::opt<bool>
1279347b66cSPierre van Houtryve     DebugProposalSearch("amdgpu-module-splitting-debug-proposal-search",
1289347b66cSPierre van Houtryve                         cl::Hidden,
1299347b66cSPierre van Houtryve                         cl::desc("print all proposals received and whether "
1309347b66cSPierre van Houtryve                                  "they were rejected or accepted"));
1319347b66cSPierre van Houtryve #endif
132c9b6e01bSPierre van Houtryve 
1339347b66cSPierre van Houtryve struct SplitModuleTimer : NamedRegionTimer {
1349347b66cSPierre van Houtryve   SplitModuleTimer(StringRef Name, StringRef Desc)
1359347b66cSPierre van Houtryve       : NamedRegionTimer(Name, Desc, DEBUG_TYPE, "AMDGPU Module Splitting",
1369347b66cSPierre van Houtryve                          TimePassesIsEnabled) {}
1376345604aSDanial Klimkin };
1386345604aSDanial Klimkin 
1399347b66cSPierre van Houtryve //===----------------------------------------------------------------------===//
1409347b66cSPierre van Houtryve // Utils
1419347b66cSPierre van Houtryve //===----------------------------------------------------------------------===//
1429347b66cSPierre van Houtryve 
1439347b66cSPierre van Houtryve using CostType = InstructionCost::CostType;
1449347b66cSPierre van Houtryve using FunctionsCostMap = DenseMap<const Function *, CostType>;
1459347b66cSPierre van Houtryve using GetTTIFn = function_ref<const TargetTransformInfo &(Function &)>;
1469347b66cSPierre van Houtryve static constexpr unsigned InvalidPID = -1;
1479347b66cSPierre van Houtryve 
1489347b66cSPierre van Houtryve /// \param Num numerator
1499347b66cSPierre van Houtryve /// \param Dem denominator
1509347b66cSPierre van Houtryve /// \returns a printable object to print (Num/Dem) using "%0.2f".
1519347b66cSPierre van Houtryve static auto formatRatioOf(CostType Num, CostType Dem) {
1523414993eSFraser Cormack   CostType DemOr1 = Dem ? Dem : 1;
1533414993eSFraser Cormack   return format("%0.2f", (static_cast<double>(Num) / DemOr1) * 100);
1546345604aSDanial Klimkin }
1556345604aSDanial Klimkin 
1569347b66cSPierre van Houtryve /// Checks whether a given function is non-copyable.
1579347b66cSPierre van Houtryve ///
1589347b66cSPierre van Houtryve /// Non-copyable functions cannot be cloned into multiple partitions, and only
1599347b66cSPierre van Houtryve /// one copy of the function can be present across all partitions.
1609347b66cSPierre van Houtryve ///
161*83ad90d8SSiu Chi Chan /// Kernel functions and external functions fall into this category. If we were
162*83ad90d8SSiu Chi Chan /// to clone them, we would end up with multiple symbol definitions and a very
163*83ad90d8SSiu Chi Chan /// unhappy linker.
1649347b66cSPierre van Houtryve static bool isNonCopyable(const Function &F) {
165*83ad90d8SSiu Chi Chan   return F.hasExternalLinkage() || !F.isDefinitionExact() ||
166*83ad90d8SSiu Chi Chan          AMDGPU::isEntryFunctionCC(F.getCallingConv());
1679347b66cSPierre van Houtryve }
1689347b66cSPierre van Houtryve 
1699347b66cSPierre van Houtryve /// If \p GV has local linkage, make it external + hidden.
1709347b66cSPierre van Houtryve static void externalize(GlobalValue &GV) {
1719347b66cSPierre van Houtryve   if (GV.hasLocalLinkage()) {
1729347b66cSPierre van Houtryve     GV.setLinkage(GlobalValue::ExternalLinkage);
1739347b66cSPierre van Houtryve     GV.setVisibility(GlobalValue::HiddenVisibility);
1749347b66cSPierre van Houtryve   }
1759347b66cSPierre van Houtryve 
1769347b66cSPierre van Houtryve   // Unnamed entities must be named consistently between modules. setName will
1779347b66cSPierre van Houtryve   // give a distinct name to each such entity.
1789347b66cSPierre van Houtryve   if (!GV.hasName())
1799347b66cSPierre van Houtryve     GV.setName("__llvmsplit_unnamed");
1809347b66cSPierre van Houtryve }
1819347b66cSPierre van Houtryve 
1829347b66cSPierre van Houtryve /// Cost analysis function. Calculates the cost of each function in \p M
1839347b66cSPierre van Houtryve ///
184d95b82c4SPierre van Houtryve /// \param GetTTI Abstract getter for TargetTransformInfo.
18543fd244bSPierre van Houtryve /// \param M Module to analyze.
18643fd244bSPierre van Houtryve /// \param CostMap[out] Resulting Function -> Cost map.
18743fd244bSPierre van Houtryve /// \return The module's total cost.
1889347b66cSPierre van Houtryve static CostType calculateFunctionCosts(GetTTIFn GetTTI, Module &M,
1899347b66cSPierre van Houtryve                                        FunctionsCostMap &CostMap) {
1909347b66cSPierre van Houtryve   SplitModuleTimer SMT("calculateFunctionCosts", "cost analysis");
1919347b66cSPierre van Houtryve 
1929347b66cSPierre van Houtryve   LLVM_DEBUG(dbgs() << "[cost analysis] calculating function costs\n");
19343fd244bSPierre van Houtryve   CostType ModuleCost = 0;
1949347b66cSPierre van Houtryve   [[maybe_unused]] CostType KernelCost = 0;
19543fd244bSPierre van Houtryve 
19643fd244bSPierre van Houtryve   for (auto &Fn : M) {
19743fd244bSPierre van Houtryve     if (Fn.isDeclaration())
19843fd244bSPierre van Houtryve       continue;
19943fd244bSPierre van Houtryve 
20043fd244bSPierre van Houtryve     CostType FnCost = 0;
201d95b82c4SPierre van Houtryve     const auto &TTI = GetTTI(Fn);
20243fd244bSPierre van Houtryve     for (const auto &BB : Fn) {
20343fd244bSPierre van Houtryve       for (const auto &I : BB) {
20443fd244bSPierre van Houtryve         auto Cost =
20543fd244bSPierre van Houtryve             TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
20643fd244bSPierre van Houtryve         assert(Cost != InstructionCost::getMax());
20743fd244bSPierre van Houtryve         // Assume expensive if we can't tell the cost of an instruction.
20843fd244bSPierre van Houtryve         CostType CostVal =
20943fd244bSPierre van Houtryve             Cost.getValue().value_or(TargetTransformInfo::TCC_Expensive);
21043fd244bSPierre van Houtryve         assert((FnCost + CostVal) >= FnCost && "Overflow!");
21143fd244bSPierre van Houtryve         FnCost += CostVal;
21243fd244bSPierre van Houtryve       }
21343fd244bSPierre van Houtryve     }
21443fd244bSPierre van Houtryve 
21543fd244bSPierre van Houtryve     assert(FnCost != 0);
21643fd244bSPierre van Houtryve 
21743fd244bSPierre van Houtryve     CostMap[&Fn] = FnCost;
21843fd244bSPierre van Houtryve     assert((ModuleCost + FnCost) >= ModuleCost && "Overflow!");
21943fd244bSPierre van Houtryve     ModuleCost += FnCost;
22043fd244bSPierre van Houtryve 
2219347b66cSPierre van Houtryve     if (AMDGPU::isEntryFunctionCC(Fn.getCallingConv()))
22243fd244bSPierre van Houtryve       KernelCost += FnCost;
22343fd244bSPierre van Houtryve   }
22443fd244bSPierre van Houtryve 
2259347b66cSPierre van Houtryve   if (CostMap.empty())
2269347b66cSPierre van Houtryve     return 0;
2279347b66cSPierre van Houtryve 
2289347b66cSPierre van Houtryve   assert(ModuleCost);
2299347b66cSPierre van Houtryve   LLVM_DEBUG({
2309347b66cSPierre van Houtryve     const CostType FnCost = ModuleCost - KernelCost;
2319347b66cSPierre van Houtryve     dbgs() << " - total module cost is " << ModuleCost << ". kernels cost "
2329347b66cSPierre van Houtryve            << "" << KernelCost << " ("
2339347b66cSPierre van Houtryve            << format("%0.2f", (float(KernelCost) / ModuleCost) * 100)
2349347b66cSPierre van Houtryve            << "% of the module), functions cost " << FnCost << " ("
2359347b66cSPierre van Houtryve            << format("%0.2f", (float(FnCost) / ModuleCost) * 100)
2369347b66cSPierre van Houtryve            << "% of the module)\n";
2379347b66cSPierre van Houtryve   });
23843fd244bSPierre van Houtryve 
23943fd244bSPierre van Houtryve   return ModuleCost;
24043fd244bSPierre van Houtryve }
24143fd244bSPierre van Houtryve 
2429347b66cSPierre van Houtryve /// \return true if \p F can be indirectly called
24343fd244bSPierre van Houtryve static bool canBeIndirectlyCalled(const Function &F) {
2449347b66cSPierre van Houtryve   if (F.isDeclaration() || AMDGPU::isEntryFunctionCC(F.getCallingConv()))
24543fd244bSPierre van Houtryve     return false;
24643fd244bSPierre van Houtryve   return !F.hasLocalLinkage() ||
24743fd244bSPierre van Houtryve          F.hasAddressTaken(/*PutOffender=*/nullptr,
24843fd244bSPierre van Houtryve                            /*IgnoreCallbackUses=*/false,
24943fd244bSPierre van Houtryve                            /*IgnoreAssumeLikeCalls=*/true,
25043fd244bSPierre van Houtryve                            /*IgnoreLLVMUsed=*/true,
25143fd244bSPierre van Houtryve                            /*IgnoreARCAttachedCall=*/false,
25243fd244bSPierre van Houtryve                            /*IgnoreCastedDirectCall=*/true);
25343fd244bSPierre van Houtryve }
25443fd244bSPierre van Houtryve 
2559347b66cSPierre van Houtryve //===----------------------------------------------------------------------===//
2569347b66cSPierre van Houtryve // Graph-based Module Representation
2579347b66cSPierre van Houtryve //===----------------------------------------------------------------------===//
2589347b66cSPierre van Houtryve 
2599347b66cSPierre van Houtryve /// AMDGPUSplitModule's view of the source Module, as a graph of all components
2609347b66cSPierre van Houtryve /// that can be split into different modules.
26143fd244bSPierre van Houtryve ///
2629347b66cSPierre van Houtryve /// The most trivial instance of this graph is just the CallGraph of the module,
2639347b66cSPierre van Houtryve /// but it is not guaranteed that the graph is strictly equal to the CG. It
2649347b66cSPierre van Houtryve /// currently always is but it's designed in a way that would eventually allow
2659347b66cSPierre van Houtryve /// us to create abstract nodes, or nodes for different entities such as global
2669347b66cSPierre van Houtryve /// variables or any other meaningful constraint we must consider.
2679347b66cSPierre van Houtryve ///
2689347b66cSPierre van Houtryve /// The graph is only mutable by this class, and is generally not modified
2699347b66cSPierre van Houtryve /// after \ref SplitGraph::buildGraph runs. No consumers of the graph can
2709347b66cSPierre van Houtryve /// mutate it.
2719347b66cSPierre van Houtryve class SplitGraph {
2729347b66cSPierre van Houtryve public:
2739347b66cSPierre van Houtryve   class Node;
2749347b66cSPierre van Houtryve 
2759347b66cSPierre van Houtryve   enum class EdgeKind : uint8_t {
2769347b66cSPierre van Houtryve     /// The nodes are related through a direct call. This is a "strong" edge as
2779347b66cSPierre van Houtryve     /// it means the Src will directly reference the Dst.
2789347b66cSPierre van Houtryve     DirectCall,
2799347b66cSPierre van Houtryve     /// The nodes are related through an indirect call.
2809347b66cSPierre van Houtryve     /// This is a "weaker" edge and is only considered when traversing the graph
2819347b66cSPierre van Houtryve     /// starting from a kernel. We need this edge for resource usage analysis.
2829347b66cSPierre van Houtryve     ///
2839347b66cSPierre van Houtryve     /// The reason why we have this edge in the first place is due to how
2849347b66cSPierre van Houtryve     /// AMDGPUResourceUsageAnalysis works. In the presence of an indirect call,
2859347b66cSPierre van Houtryve     /// the resource usage of the kernel containing the indirect call is the
2869347b66cSPierre van Houtryve     /// max resource usage of all functions that can be indirectly called.
2879347b66cSPierre van Houtryve     IndirectCall,
2889347b66cSPierre van Houtryve   };
2899347b66cSPierre van Houtryve 
2909347b66cSPierre van Houtryve   /// An edge between two nodes. Edges are directional, and tagged with a
2919347b66cSPierre van Houtryve   /// "kind".
2929347b66cSPierre van Houtryve   struct Edge {
2939347b66cSPierre van Houtryve     Edge(Node *Src, Node *Dst, EdgeKind Kind)
2949347b66cSPierre van Houtryve         : Src(Src), Dst(Dst), Kind(Kind) {}
2959347b66cSPierre van Houtryve 
2969347b66cSPierre van Houtryve     Node *Src; ///< Source
2979347b66cSPierre van Houtryve     Node *Dst; ///< Destination
2989347b66cSPierre van Houtryve     EdgeKind Kind;
2999347b66cSPierre van Houtryve   };
3009347b66cSPierre van Houtryve 
3019347b66cSPierre van Houtryve   using EdgesVec = SmallVector<const Edge *, 0>;
3029347b66cSPierre van Houtryve   using edges_iterator = EdgesVec::const_iterator;
3039347b66cSPierre van Houtryve   using nodes_iterator = const Node *const *;
3049347b66cSPierre van Houtryve 
3059347b66cSPierre van Houtryve   SplitGraph(const Module &M, const FunctionsCostMap &CostMap,
3069347b66cSPierre van Houtryve              CostType ModuleCost)
3079347b66cSPierre van Houtryve       : M(M), CostMap(CostMap), ModuleCost(ModuleCost) {}
3089347b66cSPierre van Houtryve 
3099347b66cSPierre van Houtryve   void buildGraph(CallGraph &CG);
3109347b66cSPierre van Houtryve 
3119347b66cSPierre van Houtryve #ifndef NDEBUG
3129347b66cSPierre van Houtryve   bool verifyGraph() const;
3139347b66cSPierre van Houtryve #endif
3149347b66cSPierre van Houtryve 
3159347b66cSPierre van Houtryve   bool empty() const { return Nodes.empty(); }
3169347b66cSPierre van Houtryve   const iterator_range<nodes_iterator> nodes() const {
3179347b66cSPierre van Houtryve     return {Nodes.begin(), Nodes.end()};
31843fd244bSPierre van Houtryve   }
3199347b66cSPierre van Houtryve   const Node &getNode(unsigned ID) const { return *Nodes[ID]; }
3209347b66cSPierre van Houtryve 
3219347b66cSPierre van Houtryve   unsigned getNumNodes() const { return Nodes.size(); }
3229347b66cSPierre van Houtryve   BitVector createNodesBitVector() const { return BitVector(Nodes.size()); }
3239347b66cSPierre van Houtryve 
3249347b66cSPierre van Houtryve   const Module &getModule() const { return M; }
3259347b66cSPierre van Houtryve 
3269347b66cSPierre van Houtryve   CostType getModuleCost() const { return ModuleCost; }
3279347b66cSPierre van Houtryve   CostType getCost(const Function &F) const { return CostMap.at(&F); }
3289347b66cSPierre van Houtryve 
3299347b66cSPierre van Houtryve   /// \returns the aggregated cost of all nodes in \p BV (bits set to 1 = node
3309347b66cSPierre van Houtryve   /// IDs).
3319347b66cSPierre van Houtryve   CostType calculateCost(const BitVector &BV) const;
3329347b66cSPierre van Houtryve 
3339347b66cSPierre van Houtryve private:
3349347b66cSPierre van Houtryve   /// Retrieves the node for \p GV in \p Cache, or creates a new node for it and
3359347b66cSPierre van Houtryve   /// updates \p Cache.
3369347b66cSPierre van Houtryve   Node &getNode(DenseMap<const GlobalValue *, Node *> &Cache,
3379347b66cSPierre van Houtryve                 const GlobalValue &GV);
3389347b66cSPierre van Houtryve 
3399347b66cSPierre van Houtryve   // Create a new edge between two nodes and add it to both nodes.
3409347b66cSPierre van Houtryve   const Edge &createEdge(Node &Src, Node &Dst, EdgeKind EK);
3419347b66cSPierre van Houtryve 
3429347b66cSPierre van Houtryve   const Module &M;
3439347b66cSPierre van Houtryve   const FunctionsCostMap &CostMap;
3449347b66cSPierre van Houtryve   CostType ModuleCost;
3459347b66cSPierre van Houtryve 
3469347b66cSPierre van Houtryve   // Final list of nodes with stable ordering.
3479347b66cSPierre van Houtryve   SmallVector<Node *> Nodes;
3489347b66cSPierre van Houtryve 
3499347b66cSPierre van Houtryve   SpecificBumpPtrAllocator<Node> NodesPool;
3509347b66cSPierre van Houtryve 
3519347b66cSPierre van Houtryve   // Edges are trivially destructible objects, so as a small optimization we
3529347b66cSPierre van Houtryve   // use a BumpPtrAllocator which avoids destructor calls but also makes
3539347b66cSPierre van Houtryve   // allocation faster.
3549347b66cSPierre van Houtryve   static_assert(
3559347b66cSPierre van Houtryve       std::is_trivially_destructible_v<Edge>,
3569347b66cSPierre van Houtryve       "Edge must be trivially destructible to use the BumpPtrAllocator");
3579347b66cSPierre van Houtryve   BumpPtrAllocator EdgesPool;
3589347b66cSPierre van Houtryve };
3599347b66cSPierre van Houtryve 
3609347b66cSPierre van Houtryve /// Nodes in the SplitGraph contain both incoming, and outgoing edges.
3619347b66cSPierre van Houtryve /// Incoming edges have this node as their Dst, and Outgoing ones have this node
3629347b66cSPierre van Houtryve /// as their Src.
3639347b66cSPierre van Houtryve ///
3649347b66cSPierre van Houtryve /// Edge objects are shared by both nodes in Src/Dst. They provide immediate
3659347b66cSPierre van Houtryve /// feedback on how two nodes are related, and in which direction they are
3669347b66cSPierre van Houtryve /// related, which is valuable information to make splitting decisions.
3679347b66cSPierre van Houtryve ///
3689347b66cSPierre van Houtryve /// Nodes are fundamentally abstract, and any consumers of the graph should
3699347b66cSPierre van Houtryve /// treat them as such. While a node will be a function most of the time, we
3709347b66cSPierre van Houtryve /// could also create nodes for any other reason. In the future, we could have
3719347b66cSPierre van Houtryve /// single nodes for multiple functions, or nodes for GVs, etc.
3729347b66cSPierre van Houtryve class SplitGraph::Node {
3739347b66cSPierre van Houtryve   friend class SplitGraph;
3749347b66cSPierre van Houtryve 
3759347b66cSPierre van Houtryve public:
3769347b66cSPierre van Houtryve   Node(unsigned ID, const GlobalValue &GV, CostType IndividualCost,
3779347b66cSPierre van Houtryve        bool IsNonCopyable)
3789347b66cSPierre van Houtryve       : ID(ID), GV(GV), IndividualCost(IndividualCost),
3799347b66cSPierre van Houtryve         IsNonCopyable(IsNonCopyable), IsEntryFnCC(false), IsGraphEntry(false) {
3809347b66cSPierre van Houtryve     if (auto *Fn = dyn_cast<Function>(&GV))
3819347b66cSPierre van Houtryve       IsEntryFnCC = AMDGPU::isEntryFunctionCC(Fn->getCallingConv());
38243fd244bSPierre van Houtryve   }
38343fd244bSPierre van Houtryve 
3849347b66cSPierre van Houtryve   /// An 0-indexed ID for the node. The maximum ID (exclusive) is the number of
3859347b66cSPierre van Houtryve   /// nodes in the graph. This ID can be used as an index in a BitVector.
3869347b66cSPierre van Houtryve   unsigned getID() const { return ID; }
387c9b6e01bSPierre van Houtryve 
3889347b66cSPierre van Houtryve   const Function &getFunction() const { return cast<Function>(GV); }
3899347b66cSPierre van Houtryve 
3909347b66cSPierre van Houtryve   /// \returns the cost to import this component into a given module, not
3919347b66cSPierre van Houtryve   /// accounting for any dependencies that may need to be imported as well.
3929347b66cSPierre van Houtryve   CostType getIndividualCost() const { return IndividualCost; }
3939347b66cSPierre van Houtryve 
3949347b66cSPierre van Houtryve   bool isNonCopyable() const { return IsNonCopyable; }
3959347b66cSPierre van Houtryve   bool isEntryFunctionCC() const { return IsEntryFnCC; }
3969347b66cSPierre van Houtryve 
3979347b66cSPierre van Houtryve   /// \returns whether this is an entry point in the graph. Entry points are
3989347b66cSPierre van Houtryve   /// defined as follows: if you take all entry points in the graph, and iterate
3999347b66cSPierre van Houtryve   /// their dependencies, you are guaranteed to visit all nodes in the graph at
4009347b66cSPierre van Houtryve   /// least once.
4019347b66cSPierre van Houtryve   bool isGraphEntryPoint() const { return IsGraphEntry; }
4029347b66cSPierre van Houtryve 
4039347b66cSPierre van Houtryve   StringRef getName() const { return GV.getName(); }
4049347b66cSPierre van Houtryve 
4059347b66cSPierre van Houtryve   bool hasAnyIncomingEdges() const { return IncomingEdges.size(); }
4069347b66cSPierre van Houtryve   bool hasAnyIncomingEdgesOfKind(EdgeKind EK) const {
4079347b66cSPierre van Houtryve     return any_of(IncomingEdges, [&](const auto *E) { return E->Kind == EK; });
4089347b66cSPierre van Houtryve   }
4099347b66cSPierre van Houtryve 
4109347b66cSPierre van Houtryve   bool hasAnyOutgoingEdges() const { return OutgoingEdges.size(); }
4119347b66cSPierre van Houtryve   bool hasAnyOutgoingEdgesOfKind(EdgeKind EK) const {
4129347b66cSPierre van Houtryve     return any_of(OutgoingEdges, [&](const auto *E) { return E->Kind == EK; });
4139347b66cSPierre van Houtryve   }
4149347b66cSPierre van Houtryve 
4159347b66cSPierre van Houtryve   iterator_range<edges_iterator> incoming_edges() const {
4169347b66cSPierre van Houtryve     return IncomingEdges;
4179347b66cSPierre van Houtryve   }
4189347b66cSPierre van Houtryve 
4199347b66cSPierre van Houtryve   iterator_range<edges_iterator> outgoing_edges() const {
4209347b66cSPierre van Houtryve     return OutgoingEdges;
4219347b66cSPierre van Houtryve   }
4229347b66cSPierre van Houtryve 
4239347b66cSPierre van Houtryve   bool shouldFollowIndirectCalls() const { return isEntryFunctionCC(); }
4249347b66cSPierre van Houtryve 
4259347b66cSPierre van Houtryve   /// Visit all children of this node in a recursive fashion. Also visits Self.
4269347b66cSPierre van Houtryve   /// If \ref shouldFollowIndirectCalls returns false, then this only follows
4279347b66cSPierre van Houtryve   /// DirectCall edges.
4289347b66cSPierre van Houtryve   ///
4299347b66cSPierre van Houtryve   /// \param Visitor Visitor Function.
4309347b66cSPierre van Houtryve   void visitAllDependencies(std::function<void(const Node &)> Visitor) const;
4319347b66cSPierre van Houtryve 
4329347b66cSPierre van Houtryve   /// Adds the depedencies of this node in \p BV by setting the bit
4339347b66cSPierre van Houtryve   /// corresponding to each node.
4349347b66cSPierre van Houtryve   ///
4359347b66cSPierre van Houtryve   /// Implemented using \ref visitAllDependencies, hence it follows the same
4369347b66cSPierre van Houtryve   /// rules regarding dependencies traversal.
4379347b66cSPierre van Houtryve   ///
4389347b66cSPierre van Houtryve   /// \param[out] BV The bitvector where the bits should be set.
4399347b66cSPierre van Houtryve   void getDependencies(BitVector &BV) const {
4409347b66cSPierre van Houtryve     visitAllDependencies([&](const Node &N) { BV.set(N.getID()); });
4419347b66cSPierre van Houtryve   }
4429347b66cSPierre van Houtryve 
4439347b66cSPierre van Houtryve private:
4449347b66cSPierre van Houtryve   void markAsGraphEntry() { IsGraphEntry = true; }
4459347b66cSPierre van Houtryve 
4469347b66cSPierre van Houtryve   unsigned ID;
4479347b66cSPierre van Houtryve   const GlobalValue &GV;
4489347b66cSPierre van Houtryve   CostType IndividualCost;
4499347b66cSPierre van Houtryve   bool IsNonCopyable : 1;
4509347b66cSPierre van Houtryve   bool IsEntryFnCC : 1;
4519347b66cSPierre van Houtryve   bool IsGraphEntry : 1;
4529347b66cSPierre van Houtryve 
4539347b66cSPierre van Houtryve   // TODO: Use a single sorted vector (with all incoming/outgoing edges grouped
4549347b66cSPierre van Houtryve   // together)
4559347b66cSPierre van Houtryve   EdgesVec IncomingEdges;
4569347b66cSPierre van Houtryve   EdgesVec OutgoingEdges;
4579347b66cSPierre van Houtryve };
4589347b66cSPierre van Houtryve 
4599347b66cSPierre van Houtryve void SplitGraph::Node::visitAllDependencies(
4609347b66cSPierre van Houtryve     std::function<void(const Node &)> Visitor) const {
4619347b66cSPierre van Houtryve   const bool FollowIndirect = shouldFollowIndirectCalls();
4629347b66cSPierre van Houtryve   // FIXME: If this can access SplitGraph in the future, use a BitVector
4639347b66cSPierre van Houtryve   // instead.
4649347b66cSPierre van Houtryve   DenseSet<const Node *> Seen;
4659347b66cSPierre van Houtryve   SmallVector<const Node *, 8> WorkList({this});
46643fd244bSPierre van Houtryve   while (!WorkList.empty()) {
4679347b66cSPierre van Houtryve     const Node *CurN = WorkList.pop_back_val();
4689347b66cSPierre van Houtryve     if (auto [It, Inserted] = Seen.insert(CurN); !Inserted)
4699347b66cSPierre van Houtryve       continue;
47043fd244bSPierre van Houtryve 
4719347b66cSPierre van Houtryve     Visitor(*CurN);
47243fd244bSPierre van Houtryve 
4739347b66cSPierre van Houtryve     for (const Edge *E : CurN->outgoing_edges()) {
4749347b66cSPierre van Houtryve       if (!FollowIndirect && E->Kind == EdgeKind::IndirectCall)
4759347b66cSPierre van Houtryve         continue;
4769347b66cSPierre van Houtryve       WorkList.push_back(E->Dst);
4779347b66cSPierre van Houtryve     }
4789347b66cSPierre van Houtryve   }
4799347b66cSPierre van Houtryve }
4809347b66cSPierre van Houtryve 
481b3a8400aSPierre van Houtryve /// Checks if \p I has MD_callees and if it does, parse it and put the function
482b3a8400aSPierre van Houtryve /// in \p Callees.
483b3a8400aSPierre van Houtryve ///
484b3a8400aSPierre van Houtryve /// \returns true if there was metadata and it was parsed correctly. false if
485b3a8400aSPierre van Houtryve /// there was no MD or if it contained unknown entries and parsing failed.
486b3a8400aSPierre van Houtryve /// If this returns false, \p Callees will contain incomplete information
487b3a8400aSPierre van Houtryve /// and must not be used.
488b3a8400aSPierre van Houtryve static bool handleCalleesMD(const Instruction &I,
489b3a8400aSPierre van Houtryve                             SetVector<Function *> &Callees) {
490b3a8400aSPierre van Houtryve   auto *MD = I.getMetadata(LLVMContext::MD_callees);
491b3a8400aSPierre van Houtryve   if (!MD)
492b3a8400aSPierre van Houtryve     return false;
493b3a8400aSPierre van Houtryve 
494b3a8400aSPierre van Houtryve   for (const auto &Op : MD->operands()) {
495b3a8400aSPierre van Houtryve     Function *Callee = mdconst::extract_or_null<Function>(Op);
496b3a8400aSPierre van Houtryve     if (!Callee)
497b3a8400aSPierre van Houtryve       return false;
498b3a8400aSPierre van Houtryve     Callees.insert(Callee);
499b3a8400aSPierre van Houtryve   }
500b3a8400aSPierre van Houtryve 
501b3a8400aSPierre van Houtryve   return true;
502b3a8400aSPierre van Houtryve }
503b3a8400aSPierre van Houtryve 
5049347b66cSPierre van Houtryve void SplitGraph::buildGraph(CallGraph &CG) {
5059347b66cSPierre van Houtryve   SplitModuleTimer SMT("buildGraph", "graph construction");
5069347b66cSPierre van Houtryve   LLVM_DEBUG(
5079347b66cSPierre van Houtryve       dbgs()
5089347b66cSPierre van Houtryve       << "[build graph] constructing graph representation of the input\n");
5099347b66cSPierre van Houtryve 
510d656b206SPierre van Houtryve   // FIXME(?): Is the callgraph really worth using if we have to iterate the
511d656b206SPierre van Houtryve   // function again whenever it fails to give us enough information?
512d656b206SPierre van Houtryve 
5139347b66cSPierre van Houtryve   // We build the graph by just iterating all functions in the module and
5149347b66cSPierre van Houtryve   // working on their direct callees. At the end, all nodes should be linked
5159347b66cSPierre van Houtryve   // together as expected.
5169347b66cSPierre van Houtryve   DenseMap<const GlobalValue *, Node *> Cache;
5179347b66cSPierre van Houtryve   SmallVector<const Function *> FnsWithIndirectCalls, IndirectlyCallableFns;
5189347b66cSPierre van Houtryve   for (const Function &Fn : M) {
5199347b66cSPierre van Houtryve     if (Fn.isDeclaration())
5209347b66cSPierre van Houtryve       continue;
5219347b66cSPierre van Houtryve 
5229347b66cSPierre van Houtryve     // Look at direct callees and create the necessary edges in the graph.
523d656b206SPierre van Houtryve     SetVector<const Function *> DirectCallees;
524d656b206SPierre van Houtryve     bool CallsExternal = false;
5259347b66cSPierre van Houtryve     for (auto &CGEntry : *CG[&Fn]) {
52643fd244bSPierre van Houtryve       auto *CGNode = CGEntry.second;
527d656b206SPierre van Houtryve       if (auto *Callee = CGNode->getFunction()) {
5289347b66cSPierre van Houtryve         if (!Callee->isDeclaration())
529d656b206SPierre van Houtryve           DirectCallees.insert(Callee);
530d656b206SPierre van Houtryve       } else if (CGNode == CG.getCallsExternalNode())
531d656b206SPierre van Houtryve         CallsExternal = true;
53243fd244bSPierre van Houtryve     }
53343fd244bSPierre van Houtryve 
5349347b66cSPierre van Houtryve     // Keep track of this function if it contains an indirect call and/or if it
5359347b66cSPierre van Houtryve     // can be indirectly called.
536d656b206SPierre van Houtryve     if (CallsExternal) {
537d656b206SPierre van Houtryve       LLVM_DEBUG(dbgs() << "  [!] callgraph is incomplete for ";
538d656b206SPierre van Houtryve                  Fn.printAsOperand(dbgs());
539d656b206SPierre van Houtryve                  dbgs() << " - analyzing function\n");
540d656b206SPierre van Houtryve 
541b3a8400aSPierre van Houtryve       SetVector<Function *> KnownCallees;
542b3a8400aSPierre van Houtryve       bool HasUnknownIndirectCall = false;
543d656b206SPierre van Houtryve       for (const auto &Inst : instructions(Fn)) {
544d656b206SPierre van Houtryve         // look at all calls without a direct callee.
545b3a8400aSPierre van Houtryve         const auto *CB = dyn_cast<CallBase>(&Inst);
546b3a8400aSPierre van Houtryve         if (!CB || CB->getCalledFunction())
547b3a8400aSPierre van Houtryve           continue;
548b3a8400aSPierre van Houtryve 
549d656b206SPierre van Houtryve         // inline assembly can be ignored, unless InlineAsmIsIndirectCall is
550d656b206SPierre van Houtryve         // true.
551d656b206SPierre van Houtryve         if (CB->isInlineAsm()) {
552d656b206SPierre van Houtryve           LLVM_DEBUG(dbgs() << "    found inline assembly\n");
553d656b206SPierre van Houtryve           continue;
554d656b206SPierre van Houtryve         }
555d656b206SPierre van Houtryve 
556b3a8400aSPierre van Houtryve         if (handleCalleesMD(Inst, KnownCallees))
557b3a8400aSPierre van Houtryve           continue;
558b3a8400aSPierre van Houtryve         // If we failed to parse any !callees MD, or some was missing,
559b3a8400aSPierre van Houtryve         // the entire KnownCallees list is now unreliable.
560b3a8400aSPierre van Houtryve         KnownCallees.clear();
561b3a8400aSPierre van Houtryve 
562b3a8400aSPierre van Houtryve         // Everything else is handled conservatively. If we fall into the
563b3a8400aSPierre van Houtryve         // conservative case don't bother analyzing further.
564b3a8400aSPierre van Houtryve         HasUnknownIndirectCall = true;
565d656b206SPierre van Houtryve         break;
566d656b206SPierre van Houtryve       }
567d656b206SPierre van Houtryve 
568b3a8400aSPierre van Houtryve       if (HasUnknownIndirectCall) {
569d656b206SPierre van Houtryve         LLVM_DEBUG(dbgs() << "    indirect call found\n");
5709347b66cSPierre van Houtryve         FnsWithIndirectCalls.push_back(&Fn);
571b3a8400aSPierre van Houtryve       } else if (!KnownCallees.empty())
572b3a8400aSPierre van Houtryve         DirectCallees.insert(KnownCallees.begin(), KnownCallees.end());
573d656b206SPierre van Houtryve     }
574d656b206SPierre van Houtryve 
575d656b206SPierre van Houtryve     Node &N = getNode(Cache, Fn);
576d656b206SPierre van Houtryve     for (const auto *Callee : DirectCallees)
577d656b206SPierre van Houtryve       createEdge(N, getNode(Cache, *Callee), EdgeKind::DirectCall);
5786345604aSDanial Klimkin 
5799347b66cSPierre van Houtryve     if (canBeIndirectlyCalled(Fn))
5809347b66cSPierre van Houtryve       IndirectlyCallableFns.push_back(&Fn);
5819347b66cSPierre van Houtryve   }
5829347b66cSPierre van Houtryve 
5839347b66cSPierre van Houtryve   // Post-process functions with indirect calls.
5849347b66cSPierre van Houtryve   for (const Function *Fn : FnsWithIndirectCalls) {
5859347b66cSPierre van Houtryve     for (const Function *Candidate : IndirectlyCallableFns) {
5869347b66cSPierre van Houtryve       Node &Src = getNode(Cache, *Fn);
5879347b66cSPierre van Houtryve       Node &Dst = getNode(Cache, *Candidate);
5889347b66cSPierre van Houtryve       createEdge(Src, Dst, EdgeKind::IndirectCall);
5896345604aSDanial Klimkin     }
59043fd244bSPierre van Houtryve   }
59143fd244bSPierre van Houtryve 
5929347b66cSPierre van Houtryve   // Now, find all entry points.
5939347b66cSPierre van Houtryve   SmallVector<Node *, 16> CandidateEntryPoints;
5949347b66cSPierre van Houtryve   BitVector NodesReachableByKernels = createNodesBitVector();
5959347b66cSPierre van Houtryve   for (Node *N : Nodes) {
5969347b66cSPierre van Houtryve     // Functions with an Entry CC are always graph entry points too.
5979347b66cSPierre van Houtryve     if (N->isEntryFunctionCC()) {
5989347b66cSPierre van Houtryve       N->markAsGraphEntry();
5999347b66cSPierre van Houtryve       N->getDependencies(NodesReachableByKernels);
6009347b66cSPierre van Houtryve     } else if (!N->hasAnyIncomingEdgesOfKind(EdgeKind::DirectCall))
6019347b66cSPierre van Houtryve       CandidateEntryPoints.push_back(N);
60243fd244bSPierre van Houtryve   }
60343fd244bSPierre van Houtryve 
6049347b66cSPierre van Houtryve   for (Node *N : CandidateEntryPoints) {
6059347b66cSPierre van Houtryve     // This can be another entry point if it's not reachable by a kernel
6069347b66cSPierre van Houtryve     // TODO: We could sort all of the possible new entries in a stable order
6079347b66cSPierre van Houtryve     // (e.g. by cost), then consume them one by one until
6089347b66cSPierre van Houtryve     // NodesReachableByKernels is all 1s. It'd allow us to avoid
6099347b66cSPierre van Houtryve     // considering some nodes as non-entries in some specific cases.
6109347b66cSPierre van Houtryve     if (!NodesReachableByKernels.test(N->getID()))
6119347b66cSPierre van Houtryve       N->markAsGraphEntry();
6126345604aSDanial Klimkin   }
6136345604aSDanial Klimkin 
6141c025fb0SPierre van Houtryve #ifndef NDEBUG
6159347b66cSPierre van Houtryve   assert(verifyGraph());
6161c025fb0SPierre van Houtryve #endif
61743fd244bSPierre van Houtryve }
61843fd244bSPierre van Houtryve 
6199347b66cSPierre van Houtryve #ifndef NDEBUG
6209347b66cSPierre van Houtryve bool SplitGraph::verifyGraph() const {
6219347b66cSPierre van Houtryve   unsigned ExpectedID = 0;
6229347b66cSPierre van Houtryve   // Exceptionally using a set here in case IDs are messed up.
6239347b66cSPierre van Houtryve   DenseSet<const Node *> SeenNodes;
6249347b66cSPierre van Houtryve   DenseSet<const Function *> SeenFunctionNodes;
6259347b66cSPierre van Houtryve   for (const Node *N : Nodes) {
6269347b66cSPierre van Houtryve     if (N->getID() != (ExpectedID++)) {
6279347b66cSPierre van Houtryve       errs() << "Node IDs are incorrect!\n";
6286345604aSDanial Klimkin       return false;
629c9b6e01bSPierre van Houtryve     }
630c9b6e01bSPierre van Houtryve 
6319347b66cSPierre van Houtryve     if (!SeenNodes.insert(N).second) {
6329347b66cSPierre van Houtryve       errs() << "Node seen more than once!\n";
6339347b66cSPierre van Houtryve       return false;
6349347b66cSPierre van Houtryve     }
6359347b66cSPierre van Houtryve 
6369347b66cSPierre van Houtryve     if (&getNode(N->getID()) != N) {
6379347b66cSPierre van Houtryve       errs() << "getNode doesn't return the right node\n";
6389347b66cSPierre van Houtryve       return false;
6399347b66cSPierre van Houtryve     }
6409347b66cSPierre van Houtryve 
6419347b66cSPierre van Houtryve     for (const Edge *E : N->IncomingEdges) {
6429347b66cSPierre van Houtryve       if (!E->Src || !E->Dst || (E->Dst != N) ||
6439347b66cSPierre van Houtryve           (find(E->Src->OutgoingEdges, E) == E->Src->OutgoingEdges.end())) {
6449347b66cSPierre van Houtryve         errs() << "ill-formed incoming edges\n";
6459347b66cSPierre van Houtryve         return false;
6469347b66cSPierre van Houtryve       }
6479347b66cSPierre van Houtryve     }
6489347b66cSPierre van Houtryve 
6499347b66cSPierre van Houtryve     for (const Edge *E : N->OutgoingEdges) {
6509347b66cSPierre van Houtryve       if (!E->Src || !E->Dst || (E->Src != N) ||
6519347b66cSPierre van Houtryve           (find(E->Dst->IncomingEdges, E) == E->Dst->IncomingEdges.end())) {
6529347b66cSPierre van Houtryve         errs() << "ill-formed outgoing edges\n";
6539347b66cSPierre van Houtryve         return false;
6549347b66cSPierre van Houtryve       }
6559347b66cSPierre van Houtryve     }
6569347b66cSPierre van Houtryve 
6579347b66cSPierre van Houtryve     const Function &Fn = N->getFunction();
6589347b66cSPierre van Houtryve     if (AMDGPU::isEntryFunctionCC(Fn.getCallingConv())) {
6599347b66cSPierre van Houtryve       if (N->hasAnyIncomingEdges()) {
6609347b66cSPierre van Houtryve         errs() << "Kernels cannot have incoming edges\n";
6619347b66cSPierre van Houtryve         return false;
6629347b66cSPierre van Houtryve       }
6639347b66cSPierre van Houtryve     }
6649347b66cSPierre van Houtryve 
6659347b66cSPierre van Houtryve     if (Fn.isDeclaration()) {
6669347b66cSPierre van Houtryve       errs() << "declarations shouldn't have nodes!\n";
6679347b66cSPierre van Houtryve       return false;
6689347b66cSPierre van Houtryve     }
6699347b66cSPierre van Houtryve 
6709347b66cSPierre van Houtryve     auto [It, Inserted] = SeenFunctionNodes.insert(&Fn);
6719347b66cSPierre van Houtryve     if (!Inserted) {
6729347b66cSPierre van Houtryve       errs() << "one function has multiple nodes!\n";
6739347b66cSPierre van Houtryve       return false;
6749347b66cSPierre van Houtryve     }
6759347b66cSPierre van Houtryve   }
6769347b66cSPierre van Houtryve 
6779347b66cSPierre van Houtryve   if (ExpectedID != Nodes.size()) {
6789347b66cSPierre van Houtryve     errs() << "Node IDs out of sync!\n";
6799347b66cSPierre van Houtryve     return false;
6809347b66cSPierre van Houtryve   }
6819347b66cSPierre van Houtryve 
6829347b66cSPierre van Houtryve   if (createNodesBitVector().size() != getNumNodes()) {
6839347b66cSPierre van Houtryve     errs() << "nodes bit vector doesn't have the right size!\n";
6849347b66cSPierre van Houtryve     return false;
6859347b66cSPierre van Houtryve   }
6869347b66cSPierre van Houtryve 
6879347b66cSPierre van Houtryve   // Check we respect the promise of Node::isKernel
6889347b66cSPierre van Houtryve   BitVector BV = createNodesBitVector();
6899347b66cSPierre van Houtryve   for (const Node *N : nodes()) {
6909347b66cSPierre van Houtryve     if (N->isGraphEntryPoint())
6919347b66cSPierre van Houtryve       N->getDependencies(BV);
6929347b66cSPierre van Houtryve   }
6939347b66cSPierre van Houtryve 
6949347b66cSPierre van Houtryve   // Ensure each function in the module has an associated node.
6959347b66cSPierre van Houtryve   for (const auto &Fn : M) {
6969347b66cSPierre van Houtryve     if (!Fn.isDeclaration()) {
6979347b66cSPierre van Houtryve       if (!SeenFunctionNodes.contains(&Fn)) {
6989347b66cSPierre van Houtryve         errs() << "Fn has no associated node in the graph!\n";
6999347b66cSPierre van Houtryve         return false;
7009347b66cSPierre van Houtryve       }
7019347b66cSPierre van Houtryve     }
7029347b66cSPierre van Houtryve   }
7039347b66cSPierre van Houtryve 
7049347b66cSPierre van Houtryve   if (!BV.all()) {
7059347b66cSPierre van Houtryve     errs() << "not all nodes are reachable through the graph's entry points!\n";
7069347b66cSPierre van Houtryve     return false;
7079347b66cSPierre van Houtryve   }
7089347b66cSPierre van Houtryve 
7099347b66cSPierre van Houtryve   return true;
7109347b66cSPierre van Houtryve }
7119347b66cSPierre van Houtryve #endif
7129347b66cSPierre van Houtryve 
7139347b66cSPierre van Houtryve CostType SplitGraph::calculateCost(const BitVector &BV) const {
7149347b66cSPierre van Houtryve   CostType Cost = 0;
7159347b66cSPierre van Houtryve   for (unsigned NodeID : BV.set_bits())
7169347b66cSPierre van Houtryve     Cost += getNode(NodeID).getIndividualCost();
7179347b66cSPierre van Houtryve   return Cost;
7189347b66cSPierre van Houtryve }
7199347b66cSPierre van Houtryve 
7209347b66cSPierre van Houtryve SplitGraph::Node &
7219347b66cSPierre van Houtryve SplitGraph::getNode(DenseMap<const GlobalValue *, Node *> &Cache,
7229347b66cSPierre van Houtryve                     const GlobalValue &GV) {
7239347b66cSPierre van Houtryve   auto &N = Cache[&GV];
7249347b66cSPierre van Houtryve   if (N)
7259347b66cSPierre van Houtryve     return *N;
7269347b66cSPierre van Houtryve 
7279347b66cSPierre van Houtryve   CostType Cost = 0;
7289347b66cSPierre van Houtryve   bool NonCopyable = false;
7299347b66cSPierre van Houtryve   if (const Function *Fn = dyn_cast<Function>(&GV)) {
7309347b66cSPierre van Houtryve     NonCopyable = isNonCopyable(*Fn);
7319347b66cSPierre van Houtryve     Cost = CostMap.at(Fn);
7329347b66cSPierre van Houtryve   }
7339347b66cSPierre van Houtryve   N = new (NodesPool.Allocate()) Node(Nodes.size(), GV, Cost, NonCopyable);
7349347b66cSPierre van Houtryve   Nodes.push_back(N);
7359347b66cSPierre van Houtryve   assert(&getNode(N->getID()) == N);
7369347b66cSPierre van Houtryve   return *N;
7379347b66cSPierre van Houtryve }
7389347b66cSPierre van Houtryve 
7399347b66cSPierre van Houtryve const SplitGraph::Edge &SplitGraph::createEdge(Node &Src, Node &Dst,
7409347b66cSPierre van Houtryve                                                EdgeKind EK) {
7419347b66cSPierre van Houtryve   const Edge *E = new (EdgesPool.Allocate<Edge>(1)) Edge(&Src, &Dst, EK);
7429347b66cSPierre van Houtryve   Src.OutgoingEdges.push_back(E);
7439347b66cSPierre van Houtryve   Dst.IncomingEdges.push_back(E);
7449347b66cSPierre van Houtryve   return *E;
7459347b66cSPierre van Houtryve }
7469347b66cSPierre van Houtryve 
7479347b66cSPierre van Houtryve //===----------------------------------------------------------------------===//
7489347b66cSPierre van Houtryve // Split Proposals
7499347b66cSPierre van Houtryve //===----------------------------------------------------------------------===//
7509347b66cSPierre van Houtryve 
7519347b66cSPierre van Houtryve /// Represents a module splitting proposal.
7529347b66cSPierre van Houtryve ///
7539347b66cSPierre van Houtryve /// Proposals are made of N BitVectors, one for each partition, where each bit
7549347b66cSPierre van Houtryve /// set indicates that the node is present and should be copied inside that
7559347b66cSPierre van Houtryve /// partition.
7569347b66cSPierre van Houtryve ///
7579347b66cSPierre van Houtryve /// Proposals have several metrics attached so they can be compared/sorted,
7589347b66cSPierre van Houtryve /// which the driver to try multiple strategies resultings in multiple proposals
7599347b66cSPierre van Houtryve /// and choose the best one out of them.
7609347b66cSPierre van Houtryve class SplitProposal {
7619347b66cSPierre van Houtryve public:
7629347b66cSPierre van Houtryve   SplitProposal(const SplitGraph &SG, unsigned MaxPartitions) : SG(&SG) {
7639347b66cSPierre van Houtryve     Partitions.resize(MaxPartitions, {0, SG.createNodesBitVector()});
7649347b66cSPierre van Houtryve   }
7659347b66cSPierre van Houtryve 
7669347b66cSPierre van Houtryve   void setName(StringRef NewName) { Name = NewName; }
7679347b66cSPierre van Houtryve   StringRef getName() const { return Name; }
7689347b66cSPierre van Houtryve 
7699347b66cSPierre van Houtryve   const BitVector &operator[](unsigned PID) const {
7709347b66cSPierre van Houtryve     return Partitions[PID].second;
7719347b66cSPierre van Houtryve   }
7729347b66cSPierre van Houtryve 
7739347b66cSPierre van Houtryve   void add(unsigned PID, const BitVector &BV) {
7749347b66cSPierre van Houtryve     Partitions[PID].second |= BV;
7759347b66cSPierre van Houtryve     updateScore(PID);
7769347b66cSPierre van Houtryve   }
7779347b66cSPierre van Houtryve 
7789347b66cSPierre van Houtryve   void print(raw_ostream &OS) const;
7799347b66cSPierre van Houtryve   LLVM_DUMP_METHOD void dump() const { print(dbgs()); }
7809347b66cSPierre van Houtryve 
7819347b66cSPierre van Houtryve   // Find the cheapest partition (lowest cost). In case of ties, always returns
7829347b66cSPierre van Houtryve   // the highest partition number.
7839347b66cSPierre van Houtryve   unsigned findCheapestPartition() const;
7849347b66cSPierre van Houtryve 
7859347b66cSPierre van Houtryve   /// Calculate the CodeSize and Bottleneck scores.
7869347b66cSPierre van Houtryve   void calculateScores();
7879347b66cSPierre van Houtryve 
7889347b66cSPierre van Houtryve #ifndef NDEBUG
7899347b66cSPierre van Houtryve   void verifyCompleteness() const;
7909347b66cSPierre van Houtryve #endif
7919347b66cSPierre van Houtryve 
7929347b66cSPierre van Houtryve   /// Only available after \ref calculateScores is called.
7939347b66cSPierre van Houtryve   ///
7949347b66cSPierre van Houtryve   /// A positive number indicating the % of code duplication that this proposal
7959347b66cSPierre van Houtryve   /// creates. e.g. 0.2 means this proposal adds roughly 20% code size by
7969347b66cSPierre van Houtryve   /// duplicating some functions across partitions.
7979347b66cSPierre van Houtryve   ///
7989347b66cSPierre van Houtryve   /// Value is always rounded up to 3 decimal places.
7999347b66cSPierre van Houtryve   ///
8009347b66cSPierre van Houtryve   /// A perfect score would be 0.0, and anything approaching 1.0 is very bad.
8019347b66cSPierre van Houtryve   double getCodeSizeScore() const { return CodeSizeScore; }
8029347b66cSPierre van Houtryve 
8039347b66cSPierre van Houtryve   /// Only available after \ref calculateScores is called.
8049347b66cSPierre van Houtryve   ///
8059347b66cSPierre van Houtryve   /// A number between [0, 1] which indicates how big of a bottleneck is
8069347b66cSPierre van Houtryve   /// expected from the largest partition.
8079347b66cSPierre van Houtryve   ///
8089347b66cSPierre van Houtryve   /// A score of 1.0 means the biggest partition is as big as the source module,
8099347b66cSPierre van Houtryve   /// so build time will be equal to or greater than the build time of the
8109347b66cSPierre van Houtryve   /// initial input.
8119347b66cSPierre van Houtryve   ///
8129347b66cSPierre van Houtryve   /// Value is always rounded up to 3 decimal places.
8139347b66cSPierre van Houtryve   ///
8149347b66cSPierre van Houtryve   /// This is one of the metrics used to estimate this proposal's build time.
8159347b66cSPierre van Houtryve   double getBottleneckScore() const { return BottleneckScore; }
8169347b66cSPierre van Houtryve 
8179347b66cSPierre van Houtryve private:
8189347b66cSPierre van Houtryve   void updateScore(unsigned PID) {
8199347b66cSPierre van Houtryve     assert(SG);
8209347b66cSPierre van Houtryve     for (auto &[PCost, Nodes] : Partitions) {
8219347b66cSPierre van Houtryve       TotalCost -= PCost;
8229347b66cSPierre van Houtryve       PCost = SG->calculateCost(Nodes);
8239347b66cSPierre van Houtryve       TotalCost += PCost;
8249347b66cSPierre van Houtryve     }
8259347b66cSPierre van Houtryve   }
8269347b66cSPierre van Houtryve 
8279347b66cSPierre van Houtryve   /// \see getCodeSizeScore
8289347b66cSPierre van Houtryve   double CodeSizeScore = 0.0;
8299347b66cSPierre van Houtryve   /// \see getBottleneckScore
8309347b66cSPierre van Houtryve   double BottleneckScore = 0.0;
8319347b66cSPierre van Houtryve   /// Aggregated cost of all partitions
8329347b66cSPierre van Houtryve   CostType TotalCost = 0;
8339347b66cSPierre van Houtryve 
8349347b66cSPierre van Houtryve   const SplitGraph *SG = nullptr;
8359347b66cSPierre van Houtryve   std::string Name;
8369347b66cSPierre van Houtryve 
8379347b66cSPierre van Houtryve   std::vector<std::pair<CostType, BitVector>> Partitions;
8389347b66cSPierre van Houtryve };
8399347b66cSPierre van Houtryve 
8409347b66cSPierre van Houtryve void SplitProposal::print(raw_ostream &OS) const {
8419347b66cSPierre van Houtryve   assert(SG);
8429347b66cSPierre van Houtryve 
8439347b66cSPierre van Houtryve   OS << "[proposal] " << Name << ", total cost:" << TotalCost
8449347b66cSPierre van Houtryve      << ", code size score:" << format("%0.3f", CodeSizeScore)
8459347b66cSPierre van Houtryve      << ", bottleneck score:" << format("%0.3f", BottleneckScore) << '\n';
8469347b66cSPierre van Houtryve   for (const auto &[PID, Part] : enumerate(Partitions)) {
8479347b66cSPierre van Houtryve     const auto &[Cost, NodeIDs] = Part;
8489347b66cSPierre van Houtryve     OS << "  - P" << PID << " nodes:" << NodeIDs.count() << " cost: " << Cost
8499347b66cSPierre van Houtryve        << '|' << formatRatioOf(Cost, SG->getModuleCost()) << "%\n";
8509347b66cSPierre van Houtryve   }
8519347b66cSPierre van Houtryve }
8529347b66cSPierre van Houtryve 
8539347b66cSPierre van Houtryve unsigned SplitProposal::findCheapestPartition() const {
8549347b66cSPierre van Houtryve   assert(!Partitions.empty());
8559347b66cSPierre van Houtryve   CostType CurCost = std::numeric_limits<CostType>::max();
8569347b66cSPierre van Houtryve   unsigned CurPID = InvalidPID;
8579347b66cSPierre van Houtryve   for (const auto &[Idx, Part] : enumerate(Partitions)) {
8589347b66cSPierre van Houtryve     if (Part.first <= CurCost) {
8599347b66cSPierre van Houtryve       CurPID = Idx;
8609347b66cSPierre van Houtryve       CurCost = Part.first;
8619347b66cSPierre van Houtryve     }
8629347b66cSPierre van Houtryve   }
8639347b66cSPierre van Houtryve   assert(CurPID != InvalidPID);
8649347b66cSPierre van Houtryve   return CurPID;
8659347b66cSPierre van Houtryve }
8669347b66cSPierre van Houtryve 
8679347b66cSPierre van Houtryve void SplitProposal::calculateScores() {
8689347b66cSPierre van Houtryve   if (Partitions.empty())
8699347b66cSPierre van Houtryve     return;
8709347b66cSPierre van Houtryve 
8719347b66cSPierre van Houtryve   assert(SG);
8729347b66cSPierre van Houtryve   CostType LargestPCost = 0;
8739347b66cSPierre van Houtryve   for (auto &[PCost, Nodes] : Partitions) {
8749347b66cSPierre van Houtryve     if (PCost > LargestPCost)
8759347b66cSPierre van Houtryve       LargestPCost = PCost;
8769347b66cSPierre van Houtryve   }
8779347b66cSPierre van Houtryve 
8789347b66cSPierre van Houtryve   CostType ModuleCost = SG->getModuleCost();
8799347b66cSPierre van Houtryve   CodeSizeScore = double(TotalCost) / ModuleCost;
8809347b66cSPierre van Houtryve   assert(CodeSizeScore >= 0.0);
8819347b66cSPierre van Houtryve 
8829347b66cSPierre van Houtryve   BottleneckScore = double(LargestPCost) / ModuleCost;
8839347b66cSPierre van Houtryve 
8849347b66cSPierre van Houtryve   CodeSizeScore = std::ceil(CodeSizeScore * 100.0) / 100.0;
8859347b66cSPierre van Houtryve   BottleneckScore = std::ceil(BottleneckScore * 100.0) / 100.0;
8869347b66cSPierre van Houtryve }
8879347b66cSPierre van Houtryve 
8889347b66cSPierre van Houtryve #ifndef NDEBUG
8899347b66cSPierre van Houtryve void SplitProposal::verifyCompleteness() const {
8909347b66cSPierre van Houtryve   if (Partitions.empty())
8919347b66cSPierre van Houtryve     return;
8929347b66cSPierre van Houtryve 
8939347b66cSPierre van Houtryve   BitVector Result = Partitions[0].second;
8949347b66cSPierre van Houtryve   for (const auto &P : drop_begin(Partitions))
8959347b66cSPierre van Houtryve     Result |= P.second;
8969347b66cSPierre van Houtryve   assert(Result.all() && "some nodes are missing from this proposal!");
8979347b66cSPierre van Houtryve }
8989347b66cSPierre van Houtryve #endif
8999347b66cSPierre van Houtryve 
9009347b66cSPierre van Houtryve //===-- RecursiveSearchStrategy -------------------------------------------===//
9019347b66cSPierre van Houtryve 
/// Partitioning algorithm.
///
/// This is a recursive search algorithm that can explore multiple
/// possibilities.
///
/// When a cluster of nodes can go into more than one partition, and we haven't
/// reached maximum search depth, we recurse and explore both options and their
/// consequences. Both branches will yield a proposal, and the driver will grade
/// both and choose the best one.
///
/// If max depth is reached, we will use some heuristics to make a choice. Most
/// of the time we will just use the least-pressured (cheapest) partition, but
/// if a cluster is particularly big and there is a good amount of overlap with
/// an existing partition, we will choose that partition instead.
class RecursiveSearchSplitting {
public:
  /// Callback through which the proposals produced by the search are handed
  /// over (by value).
  using SubmitProposalFn = function_ref<void(SplitProposal)>;

  /// \param SG The graph to partition.
  /// \param NumParts Presumably the number of partitions of each proposal -
  /// confirm against the definition of \ref pickPartition.
  /// \param SubmitProposal Callback receiving the proposals.
  RecursiveSearchSplitting(const SplitGraph &SG, unsigned NumParts,
                           SubmitProposalFn SubmitProposal);

  void run();

private:
  /// An item of the \ref WorkList: a cluster of nodes and cost information
  /// about it.
  struct WorkListEntry {
    WorkListEntry(const BitVector &BV) : Cluster(BV) {}

    unsigned NumNonEntryNodes = 0;
    CostType TotalCost = 0;
    CostType CostExcludingGraphEntryPoints = 0;
    /// The nodes of this cluster (copied from the constructor argument).
    BitVector Cluster;
  };

  /// Collects the clusters of all graph entry points and sorts them so the
  /// most expensive clusters are viewed first. This will merge clusters
  /// together if they share a non-copyable dependency.
  void setupWorkList();

  /// Recursive function that assigns the worklist item at \p Idx into a
  /// partition of \p SP.
  ///
  /// \p Depth is the current search depth. When this value is equal to
  /// \ref MaxDepth, we can no longer recurse.
  ///
  /// This function only recurses if there is more than one possible assignment,
  /// otherwise it is iterative to avoid creating a call stack that is as big as
  /// \ref WorkList.
  void pickPartition(unsigned Depth, unsigned Idx, SplitProposal SP);

  /// \return A pair: first element is the PID of the partition that has the
  /// most similarities with \p Entry, or \ref InvalidPID if no partition was
  /// found with at least one element in common. The second element is the
  /// aggregated cost of all dependencies in common between \p Entry and that
  /// partition.
  std::pair<unsigned, CostType>
  findMostSimilarPartition(const WorkListEntry &Entry, const SplitProposal &SP);

  // Values received by the constructor.
  const SplitGraph &SG;
  unsigned NumParts;
  SubmitProposalFn SubmitProposal;

  // A Cluster is considered large when its cost, excluding entry points,
  // exceeds this value.
  CostType LargeClusterThreshold = 0;
  unsigned NumProposalsSubmitted = 0;
  SmallVector<WorkListEntry> WorkList;
};
9689347b66cSPierre van Houtryve 
9699347b66cSPierre van Houtryve RecursiveSearchSplitting::RecursiveSearchSplitting(
9709347b66cSPierre van Houtryve     const SplitGraph &SG, unsigned NumParts, SubmitProposalFn SubmitProposal)
9719347b66cSPierre van Houtryve     : SG(SG), NumParts(NumParts), SubmitProposal(SubmitProposal) {
9729347b66cSPierre van Houtryve   // arbitrary max value as a safeguard. Anything above 10 will already be
9739347b66cSPierre van Houtryve   // slow, this is just a max value to prevent extreme resource exhaustion or
9749347b66cSPierre van Houtryve   // unbounded run time.
9759347b66cSPierre van Houtryve   if (MaxDepth > 16)
9769347b66cSPierre van Houtryve     report_fatal_error("[amdgpu-split-module] search depth of " +
9779347b66cSPierre van Houtryve                        Twine(MaxDepth) + " is too high!");
9789347b66cSPierre van Houtryve   LargeClusterThreshold =
9799347b66cSPierre van Houtryve       (LargeFnFactor != 0.0)
9809347b66cSPierre van Houtryve           ? CostType(((SG.getModuleCost() / NumParts) * LargeFnFactor))
9819347b66cSPierre van Houtryve           : std::numeric_limits<CostType>::max();
9829347b66cSPierre van Houtryve   LLVM_DEBUG(dbgs() << "[recursive search] large cluster threshold set at "
9839347b66cSPierre van Houtryve                     << LargeClusterThreshold << "\n");
9849347b66cSPierre van Houtryve }
9859347b66cSPierre van Houtryve 
9869347b66cSPierre van Houtryve void RecursiveSearchSplitting::run() {
9879347b66cSPierre van Houtryve   {
9889347b66cSPierre van Houtryve     SplitModuleTimer SMT("recursive_search_prepare", "preparing worklist");
9899347b66cSPierre van Houtryve     setupWorkList();
9909347b66cSPierre van Houtryve   }
9919347b66cSPierre van Houtryve 
9929347b66cSPierre van Houtryve   {
9939347b66cSPierre van Houtryve     SplitModuleTimer SMT("recursive_search_pick", "partitioning");
9949347b66cSPierre van Houtryve     SplitProposal SP(SG, NumParts);
9959347b66cSPierre van Houtryve     pickPartition(/*BranchDepth=*/0, /*Idx=*/0, SP);
9969347b66cSPierre van Houtryve   }
9979347b66cSPierre van Houtryve }
9989347b66cSPierre van Houtryve 
9999347b66cSPierre van Houtryve void RecursiveSearchSplitting::setupWorkList() {
10009347b66cSPierre van Houtryve   // e.g. if A and B are two worklist item, and they both call a non copyable
10019347b66cSPierre van Houtryve   // dependency C, this does:
10029347b66cSPierre van Houtryve   //    A=C
10039347b66cSPierre van Houtryve   //    B=C
10049347b66cSPierre van Houtryve   // => NodeEC will create a single group (A, B, C) and we create a new
10059347b66cSPierre van Houtryve   // WorkList entry for that group.
10069347b66cSPierre van Houtryve 
10079347b66cSPierre van Houtryve   EquivalenceClasses<unsigned> NodeEC;
10089347b66cSPierre van Houtryve   for (const SplitGraph::Node *N : SG.nodes()) {
10099347b66cSPierre van Houtryve     if (!N->isGraphEntryPoint())
10109347b66cSPierre van Houtryve       continue;
10119347b66cSPierre van Houtryve 
10129347b66cSPierre van Houtryve     NodeEC.insert(N->getID());
10139347b66cSPierre van Houtryve     N->visitAllDependencies([&](const SplitGraph::Node &Dep) {
10149347b66cSPierre van Houtryve       if (&Dep != N && Dep.isNonCopyable())
10159347b66cSPierre van Houtryve         NodeEC.unionSets(N->getID(), Dep.getID());
10169347b66cSPierre van Houtryve     });
10179347b66cSPierre van Houtryve   }
10189347b66cSPierre van Houtryve 
10199347b66cSPierre van Houtryve   for (auto I = NodeEC.begin(), E = NodeEC.end(); I != E; ++I) {
10209347b66cSPierre van Houtryve     if (!I->isLeader())
10219347b66cSPierre van Houtryve       continue;
10229347b66cSPierre van Houtryve 
10239347b66cSPierre van Houtryve     BitVector Cluster = SG.createNodesBitVector();
10249347b66cSPierre van Houtryve     for (auto MI = NodeEC.member_begin(I); MI != NodeEC.member_end(); ++MI) {
10259347b66cSPierre van Houtryve       const SplitGraph::Node &N = SG.getNode(*MI);
10269347b66cSPierre van Houtryve       if (N.isGraphEntryPoint())
10279347b66cSPierre van Houtryve         N.getDependencies(Cluster);
10289347b66cSPierre van Houtryve     }
10299347b66cSPierre van Houtryve     WorkList.emplace_back(std::move(Cluster));
10309347b66cSPierre van Houtryve   }
10319347b66cSPierre van Houtryve 
10329347b66cSPierre van Houtryve   // Calculate costs and other useful information.
10339347b66cSPierre van Houtryve   for (WorkListEntry &Entry : WorkList) {
10349347b66cSPierre van Houtryve     for (unsigned NodeID : Entry.Cluster.set_bits()) {
10359347b66cSPierre van Houtryve       const SplitGraph::Node &N = SG.getNode(NodeID);
10369347b66cSPierre van Houtryve       const CostType Cost = N.getIndividualCost();
10379347b66cSPierre van Houtryve 
10389347b66cSPierre van Houtryve       Entry.TotalCost += Cost;
10399347b66cSPierre van Houtryve       if (!N.isGraphEntryPoint()) {
10409347b66cSPierre van Houtryve         Entry.CostExcludingGraphEntryPoints += Cost;
10419347b66cSPierre van Houtryve         ++Entry.NumNonEntryNodes;
10429347b66cSPierre van Houtryve       }
10439347b66cSPierre van Houtryve     }
10449347b66cSPierre van Houtryve   }
10459347b66cSPierre van Houtryve 
10469347b66cSPierre van Houtryve   stable_sort(WorkList, [](const WorkListEntry &A, const WorkListEntry &B) {
10479347b66cSPierre van Houtryve     if (A.TotalCost != B.TotalCost)
10489347b66cSPierre van Houtryve       return A.TotalCost > B.TotalCost;
10499347b66cSPierre van Houtryve 
10509347b66cSPierre van Houtryve     if (A.CostExcludingGraphEntryPoints != B.CostExcludingGraphEntryPoints)
10519347b66cSPierre van Houtryve       return A.CostExcludingGraphEntryPoints > B.CostExcludingGraphEntryPoints;
10529347b66cSPierre van Houtryve 
10539347b66cSPierre van Houtryve     if (A.NumNonEntryNodes != B.NumNonEntryNodes)
10549347b66cSPierre van Houtryve       return A.NumNonEntryNodes > B.NumNonEntryNodes;
10559347b66cSPierre van Houtryve 
10569347b66cSPierre van Houtryve     return A.Cluster.count() > B.Cluster.count();
10579347b66cSPierre van Houtryve   });
10589347b66cSPierre van Houtryve 
10599347b66cSPierre van Houtryve   LLVM_DEBUG({
10609347b66cSPierre van Houtryve     dbgs() << "[recursive search] worklist:\n";
10619347b66cSPierre van Houtryve     for (const auto &[Idx, Entry] : enumerate(WorkList)) {
10629347b66cSPierre van Houtryve       dbgs() << "  - [" << Idx << "]: ";
10639347b66cSPierre van Houtryve       for (unsigned NodeID : Entry.Cluster.set_bits())
10649347b66cSPierre van Houtryve         dbgs() << NodeID << " ";
10659347b66cSPierre van Houtryve       dbgs() << "(total_cost:" << Entry.TotalCost
10669347b66cSPierre van Houtryve              << ", cost_excl_entries:" << Entry.CostExcludingGraphEntryPoints
10679347b66cSPierre van Houtryve              << ")\n";
10689347b66cSPierre van Houtryve     }
10699347b66cSPierre van Houtryve   });
10709347b66cSPierre van Houtryve }
10719347b66cSPierre van Houtryve 
10729347b66cSPierre van Houtryve void RecursiveSearchSplitting::pickPartition(unsigned Depth, unsigned Idx,
10739347b66cSPierre van Houtryve                                              SplitProposal SP) {
10749347b66cSPierre van Houtryve   while (Idx < WorkList.size()) {
10759347b66cSPierre van Houtryve     // Step 1: Determine candidate PIDs.
10769347b66cSPierre van Houtryve     //
10779347b66cSPierre van Houtryve     const WorkListEntry &Entry = WorkList[Idx];
10789347b66cSPierre van Houtryve     const BitVector &Cluster = Entry.Cluster;
10799347b66cSPierre van Houtryve 
10809347b66cSPierre van Houtryve     // Default option is to do load-balancing, AKA assign to least pressured
10819347b66cSPierre van Houtryve     // partition.
10829347b66cSPierre van Houtryve     const unsigned CheapestPID = SP.findCheapestPartition();
10839347b66cSPierre van Houtryve     assert(CheapestPID != InvalidPID);
10849347b66cSPierre van Houtryve 
10859347b66cSPierre van Houtryve     // Explore assigning to the kernel that contains the most dependencies in
10869347b66cSPierre van Houtryve     // common.
10879347b66cSPierre van Houtryve     const auto [MostSimilarPID, SimilarDepsCost] =
10889347b66cSPierre van Houtryve         findMostSimilarPartition(Entry, SP);
10899347b66cSPierre van Houtryve 
10909347b66cSPierre van Houtryve     // We can chose to explore only one path if we only have one valid path, or
10919347b66cSPierre van Houtryve     // if we reached maximum search depth and can no longer branch out.
10929347b66cSPierre van Houtryve     unsigned SinglePIDToTry = InvalidPID;
10939347b66cSPierre van Houtryve     if (MostSimilarPID == InvalidPID) // no similar PID found
10949347b66cSPierre van Houtryve       SinglePIDToTry = CheapestPID;
10959347b66cSPierre van Houtryve     else if (MostSimilarPID == CheapestPID) // both landed on the same PID
10969347b66cSPierre van Houtryve       SinglePIDToTry = CheapestPID;
10979347b66cSPierre van Houtryve     else if (Depth >= MaxDepth) {
10989347b66cSPierre van Houtryve       // We have to choose one path. Use a heuristic to guess which one will be
10999347b66cSPierre van Houtryve       // more appropriate.
11009347b66cSPierre van Houtryve       if (Entry.CostExcludingGraphEntryPoints > LargeClusterThreshold) {
11019347b66cSPierre van Houtryve         // Check if the amount of code in common makes it worth it.
11029347b66cSPierre van Houtryve         assert(SimilarDepsCost && Entry.CostExcludingGraphEntryPoints);
1103345b3319SFraser Cormack         const double Ratio = static_cast<double>(SimilarDepsCost) /
1104345b3319SFraser Cormack                              Entry.CostExcludingGraphEntryPoints;
11059347b66cSPierre van Houtryve         assert(Ratio >= 0.0 && Ratio <= 1.0);
1106345b3319SFraser Cormack         if (Ratio > LargeFnOverlapForMerge) {
11079347b66cSPierre van Houtryve           // For debug, just print "L", so we'll see "L3=P3" for instance, which
11089347b66cSPierre van Houtryve           // will mean we reached max depth and chose P3 based on this
11099347b66cSPierre van Houtryve           // heuristic.
11109347b66cSPierre van Houtryve           LLVM_DEBUG(dbgs() << 'L');
11119347b66cSPierre van Houtryve           SinglePIDToTry = MostSimilarPID;
11129347b66cSPierre van Houtryve         }
11139347b66cSPierre van Houtryve       } else
11149347b66cSPierre van Houtryve         SinglePIDToTry = CheapestPID;
11159347b66cSPierre van Houtryve     }
11169347b66cSPierre van Houtryve 
11179347b66cSPierre van Houtryve     // Step 2: Explore candidates.
11189347b66cSPierre van Houtryve 
11199347b66cSPierre van Houtryve     // When we only explore one possible path, and thus branch depth doesn't
11209347b66cSPierre van Houtryve     // increase, do not recurse, iterate instead.
11219347b66cSPierre van Houtryve     if (SinglePIDToTry != InvalidPID) {
11229347b66cSPierre van Houtryve       LLVM_DEBUG(dbgs() << Idx << "=P" << SinglePIDToTry << ' ');
11239347b66cSPierre van Houtryve       // Only one path to explore, don't clone SP, don't increase depth.
11249347b66cSPierre van Houtryve       SP.add(SinglePIDToTry, Cluster);
11259347b66cSPierre van Houtryve       ++Idx;
11269347b66cSPierre van Houtryve       continue;
11279347b66cSPierre van Houtryve     }
11289347b66cSPierre van Houtryve 
11299347b66cSPierre van Houtryve     assert(MostSimilarPID != InvalidPID);
11309347b66cSPierre van Houtryve 
11319347b66cSPierre van Houtryve     // We explore multiple paths: recurse at increased depth, then stop this
11329347b66cSPierre van Houtryve     // function.
11339347b66cSPierre van Houtryve 
11349347b66cSPierre van Houtryve     LLVM_DEBUG(dbgs() << '\n');
11359347b66cSPierre van Houtryve 
11369347b66cSPierre van Houtryve     // lb = load balancing = put in cheapest partition
11379347b66cSPierre van Houtryve     {
11389347b66cSPierre van Houtryve       SplitProposal BranchSP = SP;
11399347b66cSPierre van Houtryve       LLVM_DEBUG(dbgs().indent(Depth)
11409347b66cSPierre van Houtryve                  << " [lb] " << Idx << "=P" << CheapestPID << "? ");
11419347b66cSPierre van Houtryve       BranchSP.add(CheapestPID, Cluster);
11429347b66cSPierre van Houtryve       pickPartition(Depth + 1, Idx + 1, BranchSP);
11439347b66cSPierre van Houtryve     }
11449347b66cSPierre van Houtryve 
11459347b66cSPierre van Houtryve     // ms = most similar = put in partition with the most in common
11469347b66cSPierre van Houtryve     {
11479347b66cSPierre van Houtryve       SplitProposal BranchSP = SP;
11489347b66cSPierre van Houtryve       LLVM_DEBUG(dbgs().indent(Depth)
11499347b66cSPierre van Houtryve                  << " [ms] " << Idx << "=P" << MostSimilarPID << "? ");
11509347b66cSPierre van Houtryve       BranchSP.add(MostSimilarPID, Cluster);
11519347b66cSPierre van Houtryve       pickPartition(Depth + 1, Idx + 1, BranchSP);
11529347b66cSPierre van Houtryve     }
11539347b66cSPierre van Houtryve 
11549347b66cSPierre van Houtryve     return;
11559347b66cSPierre van Houtryve   }
11569347b66cSPierre van Houtryve 
11579347b66cSPierre van Houtryve   // Step 3: If we assigned all WorkList items, submit the proposal.
11589347b66cSPierre van Houtryve 
11599347b66cSPierre van Houtryve   assert(Idx == WorkList.size());
11609347b66cSPierre van Houtryve   assert(NumProposalsSubmitted <= (2u << MaxDepth) &&
11619347b66cSPierre van Houtryve          "Search got out of bounds?");
11629347b66cSPierre van Houtryve   SP.setName("recursive_search (depth=" + std::to_string(Depth) + ") #" +
11639347b66cSPierre van Houtryve              std::to_string(NumProposalsSubmitted++));
11649347b66cSPierre van Houtryve   LLVM_DEBUG(dbgs() << '\n');
11659347b66cSPierre van Houtryve   SubmitProposal(SP);
11669347b66cSPierre van Houtryve }
11679347b66cSPierre van Houtryve 
11689347b66cSPierre van Houtryve std::pair<unsigned, CostType>
11699347b66cSPierre van Houtryve RecursiveSearchSplitting::findMostSimilarPartition(const WorkListEntry &Entry,
11709347b66cSPierre van Houtryve                                                    const SplitProposal &SP) {
11719347b66cSPierre van Houtryve   if (!Entry.NumNonEntryNodes)
11729347b66cSPierre van Houtryve     return {InvalidPID, 0};
11739347b66cSPierre van Houtryve 
11749347b66cSPierre van Houtryve   // We take the partition that is the most similar using Cost as a metric.
11759347b66cSPierre van Houtryve   // So we take the set of nodes in common, compute their aggregated cost, and
11769347b66cSPierre van Houtryve   // pick the partition with the highest cost in common.
11779347b66cSPierre van Houtryve   unsigned ChosenPID = InvalidPID;
11789347b66cSPierre van Houtryve   CostType ChosenCost = 0;
11799347b66cSPierre van Houtryve   for (unsigned PID = 0; PID < NumParts; ++PID) {
11809347b66cSPierre van Houtryve     BitVector BV = SP[PID];
11819347b66cSPierre van Houtryve     BV &= Entry.Cluster; // FIXME: & doesn't work between BVs?!
11829347b66cSPierre van Houtryve 
11839347b66cSPierre van Houtryve     if (BV.none())
11849347b66cSPierre van Houtryve       continue;
11859347b66cSPierre van Houtryve 
11869347b66cSPierre van Houtryve     const CostType Cost = SG.calculateCost(BV);
11879347b66cSPierre van Houtryve 
11889347b66cSPierre van Houtryve     if (ChosenPID == InvalidPID || ChosenCost < Cost ||
11899347b66cSPierre van Houtryve         (ChosenCost == Cost && PID > ChosenPID)) {
11909347b66cSPierre van Houtryve       ChosenPID = PID;
11919347b66cSPierre van Houtryve       ChosenCost = Cost;
11929347b66cSPierre van Houtryve     }
11939347b66cSPierre van Houtryve   }
11949347b66cSPierre van Houtryve 
11959347b66cSPierre van Houtryve   return {ChosenPID, ChosenCost};
11969347b66cSPierre van Houtryve }
11979347b66cSPierre van Houtryve 
11989347b66cSPierre van Houtryve //===----------------------------------------------------------------------===//
11999347b66cSPierre van Houtryve // DOTGraph Printing Support
12009347b66cSPierre van Houtryve //===----------------------------------------------------------------------===//
12019347b66cSPierre van Houtryve 
12029347b66cSPierre van Houtryve const SplitGraph::Node *mapEdgeToDst(const SplitGraph::Edge *E) {
12039347b66cSPierre van Houtryve   return E->Dst;
12049347b66cSPierre van Houtryve }
12059347b66cSPierre van Houtryve 
12069347b66cSPierre van Houtryve using SplitGraphEdgeDstIterator =
12079347b66cSPierre van Houtryve     mapped_iterator<SplitGraph::edges_iterator, decltype(&mapEdgeToDst)>;
12089347b66cSPierre van Houtryve 
12099347b66cSPierre van Houtryve } // namespace
12109347b66cSPierre van Houtryve 
12119347b66cSPierre van Houtryve template <> struct GraphTraits<SplitGraph> {
12129347b66cSPierre van Houtryve   using NodeRef = const SplitGraph::Node *;
12139347b66cSPierre van Houtryve   using nodes_iterator = SplitGraph::nodes_iterator;
12149347b66cSPierre van Houtryve   using ChildIteratorType = SplitGraphEdgeDstIterator;
12159347b66cSPierre van Houtryve 
12169347b66cSPierre van Houtryve   using EdgeRef = const SplitGraph::Edge *;
12179347b66cSPierre van Houtryve   using ChildEdgeIteratorType = SplitGraph::edges_iterator;
12189347b66cSPierre van Houtryve 
12199347b66cSPierre van Houtryve   static NodeRef getEntryNode(NodeRef N) { return N; }
12209347b66cSPierre van Houtryve 
12219347b66cSPierre van Houtryve   static ChildIteratorType child_begin(NodeRef Ref) {
12229347b66cSPierre van Houtryve     return {Ref->outgoing_edges().begin(), mapEdgeToDst};
12239347b66cSPierre van Houtryve   }
12249347b66cSPierre van Houtryve   static ChildIteratorType child_end(NodeRef Ref) {
12259347b66cSPierre van Houtryve     return {Ref->outgoing_edges().end(), mapEdgeToDst};
12269347b66cSPierre van Houtryve   }
12279347b66cSPierre van Houtryve 
12289347b66cSPierre van Houtryve   static nodes_iterator nodes_begin(const SplitGraph &G) {
12299347b66cSPierre van Houtryve     return G.nodes().begin();
12309347b66cSPierre van Houtryve   }
12319347b66cSPierre van Houtryve   static nodes_iterator nodes_end(const SplitGraph &G) {
12329347b66cSPierre van Houtryve     return G.nodes().end();
12339347b66cSPierre van Houtryve   }
12349347b66cSPierre van Houtryve };
12359347b66cSPierre van Houtryve 
12369347b66cSPierre van Houtryve template <> struct DOTGraphTraits<SplitGraph> : public DefaultDOTGraphTraits {
12379347b66cSPierre van Houtryve   DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}
12389347b66cSPierre van Houtryve 
12399347b66cSPierre van Houtryve   static std::string getGraphName(const SplitGraph &SG) {
12409347b66cSPierre van Houtryve     return SG.getModule().getName().str();
12419347b66cSPierre van Houtryve   }
12429347b66cSPierre van Houtryve 
12439347b66cSPierre van Houtryve   std::string getNodeLabel(const SplitGraph::Node *N, const SplitGraph &SG) {
12449347b66cSPierre van Houtryve     return N->getName().str();
12459347b66cSPierre van Houtryve   }
12469347b66cSPierre van Houtryve 
12479347b66cSPierre van Houtryve   static std::string getNodeDescription(const SplitGraph::Node *N,
12489347b66cSPierre van Houtryve                                         const SplitGraph &SG) {
12499347b66cSPierre van Houtryve     std::string Result;
12509347b66cSPierre van Houtryve     if (N->isEntryFunctionCC())
12519347b66cSPierre van Houtryve       Result += "entry-fn-cc ";
12529347b66cSPierre van Houtryve     if (N->isNonCopyable())
12539347b66cSPierre van Houtryve       Result += "non-copyable ";
12549347b66cSPierre van Houtryve     Result += "cost:" + std::to_string(N->getIndividualCost());
12559347b66cSPierre van Houtryve     return Result;
12569347b66cSPierre van Houtryve   }
12579347b66cSPierre van Houtryve 
12589347b66cSPierre van Houtryve   static std::string getNodeAttributes(const SplitGraph::Node *N,
12599347b66cSPierre van Houtryve                                        const SplitGraph &SG) {
12609347b66cSPierre van Houtryve     return N->hasAnyIncomingEdges() ? "" : "color=\"red\"";
12619347b66cSPierre van Houtryve   }
12629347b66cSPierre van Houtryve 
12639347b66cSPierre van Houtryve   static std::string getEdgeAttributes(const SplitGraph::Node *N,
12649347b66cSPierre van Houtryve                                        SplitGraphEdgeDstIterator EI,
12659347b66cSPierre van Houtryve                                        const SplitGraph &SG) {
12669347b66cSPierre van Houtryve 
12679347b66cSPierre van Houtryve     switch ((*EI.getCurrent())->Kind) {
12689347b66cSPierre van Houtryve     case SplitGraph::EdgeKind::DirectCall:
12699347b66cSPierre van Houtryve       return "";
12709347b66cSPierre van Houtryve     case SplitGraph::EdgeKind::IndirectCall:
12719347b66cSPierre van Houtryve       return "style=\"dashed\"";
12729347b66cSPierre van Houtryve     }
12739347b66cSPierre van Houtryve     llvm_unreachable("Unknown SplitGraph::EdgeKind enum");
12749347b66cSPierre van Houtryve   }
12759347b66cSPierre van Houtryve };
12769347b66cSPierre van Houtryve 
12779347b66cSPierre van Houtryve //===----------------------------------------------------------------------===//
12789347b66cSPierre van Houtryve // Driver
12799347b66cSPierre van Houtryve //===----------------------------------------------------------------------===//
12809347b66cSPierre van Houtryve 
12819347b66cSPierre van Houtryve namespace {
12829347b66cSPierre van Houtryve 
12839347b66cSPierre van Houtryve // If we didn't externalize GVs, then local GVs need to be conservatively
12849347b66cSPierre van Houtryve // imported into every module (including their initializers), and then cleaned
12859347b66cSPierre van Houtryve // up afterwards.
12869347b66cSPierre van Houtryve static bool needsConservativeImport(const GlobalValue *GV) {
12879347b66cSPierre van Houtryve   if (const auto *Var = dyn_cast<GlobalVariable>(GV))
12889347b66cSPierre van Houtryve     return Var->hasLocalLinkage();
12899347b66cSPierre van Houtryve   return isa<GlobalAlias>(GV);
12909347b66cSPierre van Houtryve }
12919347b66cSPierre van Houtryve 
12929347b66cSPierre van Houtryve /// Prints a summary of the partition \p N, represented by module \p M, to \p
12939347b66cSPierre van Houtryve /// OS.
12949347b66cSPierre van Houtryve static void printPartitionSummary(raw_ostream &OS, unsigned N, const Module &M,
12959347b66cSPierre van Houtryve                                   unsigned PartCost, unsigned ModuleCost) {
12969347b66cSPierre van Houtryve   OS << "*** Partition P" << N << " ***\n";
12979347b66cSPierre van Houtryve 
12989347b66cSPierre van Houtryve   for (const auto &Fn : M) {
12999347b66cSPierre van Houtryve     if (!Fn.isDeclaration())
13009347b66cSPierre van Houtryve       OS << " - [function] " << Fn.getName() << "\n";
13019347b66cSPierre van Houtryve   }
13029347b66cSPierre van Houtryve 
13039347b66cSPierre van Houtryve   for (const auto &GV : M.globals()) {
13049347b66cSPierre van Houtryve     if (GV.hasInitializer())
13059347b66cSPierre van Houtryve       OS << " - [global] " << GV.getName() << "\n";
13069347b66cSPierre van Houtryve   }
13079347b66cSPierre van Houtryve 
13089347b66cSPierre van Houtryve   OS << "Partition contains " << formatRatioOf(PartCost, ModuleCost)
13099347b66cSPierre van Houtryve      << "% of the source\n";
13109347b66cSPierre van Houtryve }
13119347b66cSPierre van Houtryve 
13129347b66cSPierre van Houtryve static void evaluateProposal(SplitProposal &Best, SplitProposal New) {
13139347b66cSPierre van Houtryve   SplitModuleTimer SMT("proposal_evaluation", "proposal ranking algorithm");
13149347b66cSPierre van Houtryve 
13159347b66cSPierre van Houtryve   LLVM_DEBUG({
13169347b66cSPierre van Houtryve     New.verifyCompleteness();
13179347b66cSPierre van Houtryve     if (DebugProposalSearch)
13189347b66cSPierre van Houtryve       New.print(dbgs());
13199347b66cSPierre van Houtryve   });
13209347b66cSPierre van Houtryve 
13219347b66cSPierre van Houtryve   const double CurBScore = Best.getBottleneckScore();
13229347b66cSPierre van Houtryve   const double CurCSScore = Best.getCodeSizeScore();
13239347b66cSPierre van Houtryve   const double NewBScore = New.getBottleneckScore();
13249347b66cSPierre van Houtryve   const double NewCSScore = New.getCodeSizeScore();
13259347b66cSPierre van Houtryve 
13269347b66cSPierre van Houtryve   // TODO: Improve this
13279347b66cSPierre van Houtryve   //    We can probably lower the precision of the comparison at first
13289347b66cSPierre van Houtryve   //    e.g. if we have
13299347b66cSPierre van Houtryve   //      - (Current): BScore: 0.489 CSCore 1.105
13309347b66cSPierre van Houtryve   //      - (New): BScore: 0.475 CSCore 1.305
13319347b66cSPierre van Houtryve   //    Currently we'd choose the new one because the bottleneck score is
13329347b66cSPierre van Houtryve   //    lower, but the new one duplicates more code. It may be worth it to
13339347b66cSPierre van Houtryve   //    discard the new proposal as the impact on build time is negligible.
13349347b66cSPierre van Houtryve 
13359347b66cSPierre van Houtryve   // Compare them
13369347b66cSPierre van Houtryve   bool IsBest = false;
13379347b66cSPierre van Houtryve   if (NewBScore < CurBScore)
13389347b66cSPierre van Houtryve     IsBest = true;
13399347b66cSPierre van Houtryve   else if (NewBScore == CurBScore)
13409347b66cSPierre van Houtryve     IsBest = (NewCSScore < CurCSScore); // Use code size as tie breaker.
13419347b66cSPierre van Houtryve 
13429347b66cSPierre van Houtryve   if (IsBest)
13439347b66cSPierre van Houtryve     Best = std::move(New);
13449347b66cSPierre van Houtryve 
13459347b66cSPierre van Houtryve   LLVM_DEBUG(if (DebugProposalSearch) {
13469347b66cSPierre van Houtryve     if (IsBest)
13479347b66cSPierre van Houtryve       dbgs() << "[search] new best proposal!\n";
13489347b66cSPierre van Houtryve     else
13499347b66cSPierre van Houtryve       dbgs() << "[search] discarding - not profitable\n";
13509347b66cSPierre van Houtryve   });
13519347b66cSPierre van Houtryve }
13529347b66cSPierre van Houtryve 
13539347b66cSPierre van Houtryve /// Trivial helper to create an identical copy of \p M.
13549347b66cSPierre van Houtryve static std::unique_ptr<Module> cloneAll(const Module &M) {
13559347b66cSPierre van Houtryve   ValueToValueMapTy VMap;
13569347b66cSPierre van Houtryve   return CloneModule(M, VMap, [&](const GlobalValue *GV) { return true; });
13579347b66cSPierre van Houtryve }
13589347b66cSPierre van Houtryve 
13599347b66cSPierre van Houtryve /// Writes \p SG as a DOTGraph to \ref ModuleDotCfgDir if requested.
13609347b66cSPierre van Houtryve static void writeDOTGraph(const SplitGraph &SG) {
13619347b66cSPierre van Houtryve   if (ModuleDotCfgOutput.empty())
13629347b66cSPierre van Houtryve     return;
13639347b66cSPierre van Houtryve 
13649347b66cSPierre van Houtryve   std::error_code EC;
13659347b66cSPierre van Houtryve   raw_fd_ostream OS(ModuleDotCfgOutput, EC);
13669347b66cSPierre van Houtryve   if (EC) {
13679347b66cSPierre van Houtryve     errs() << "[" DEBUG_TYPE "]: cannot open '" << ModuleDotCfgOutput
13689347b66cSPierre van Houtryve            << "' - DOTGraph will not be printed\n";
13699347b66cSPierre van Houtryve   }
13709347b66cSPierre van Houtryve   WriteGraph(OS, SG, /*ShortName=*/false,
13719347b66cSPierre van Houtryve              /*Title=*/SG.getModule().getName());
13729347b66cSPierre van Houtryve }
13739347b66cSPierre van Houtryve 
1374d95b82c4SPierre van Houtryve static void splitAMDGPUModule(
13759347b66cSPierre van Houtryve     GetTTIFn GetTTI, Module &M, unsigned NumParts,
137643fd244bSPierre van Houtryve     function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) {
137743fd244bSPierre van Houtryve   CallGraph CG(M);
137843fd244bSPierre van Houtryve 
137943fd244bSPierre van Houtryve   // Externalize functions whose address are taken.
138043fd244bSPierre van Houtryve   //
138143fd244bSPierre van Houtryve   // This is needed because partitioning is purely based on calls, but sometimes
138243fd244bSPierre van Houtryve   // a kernel/function may just look at the address of another local function
138343fd244bSPierre van Houtryve   // and not do anything (no calls). After partitioning, that local function may
138443fd244bSPierre van Houtryve   // end up in a different module (so it's just a declaration in the module
138543fd244bSPierre van Houtryve   // where its address is taken), which emits a "undefined hidden symbol" linker
138643fd244bSPierre van Houtryve   // error.
138743fd244bSPierre van Houtryve   //
138843fd244bSPierre van Houtryve   // Additionally, it guides partitioning to not duplicate this function if it's
138943fd244bSPierre van Houtryve   // called directly at some point.
1390d656b206SPierre van Houtryve   //
1391d656b206SPierre van Houtryve   // TODO: Could we be smarter about this ? This makes all functions whose
1392d656b206SPierre van Houtryve   // addresses are taken non-copyable. We should probably model this type of
1393d656b206SPierre van Houtryve   // constraint in the graph and use it to guide splitting, instead of
1394d656b206SPierre van Houtryve   // externalizing like this. Maybe non-copyable should really mean "keep one
1395d656b206SPierre van Houtryve   // visible copy, then internalize all other copies" for some functions?
1396d656b206SPierre van Houtryve   if (!NoExternalizeOnAddrTaken) {
139743fd244bSPierre van Houtryve     for (auto &Fn : M) {
1398d656b206SPierre van Houtryve       // TODO: Should aliases count? Probably not but they're so rare I'm not
1399d656b206SPierre van Houtryve       // sure it's worth fixing.
1400d656b206SPierre van Houtryve       if (Fn.hasLocalLinkage() && Fn.hasAddressTaken()) {
1401d656b206SPierre van Houtryve         LLVM_DEBUG(dbgs() << "[externalize] "; Fn.printAsOperand(dbgs());
1402d656b206SPierre van Houtryve                    dbgs() << " because its address is taken\n");
140343fd244bSPierre van Houtryve         externalize(Fn);
140443fd244bSPierre van Houtryve       }
140543fd244bSPierre van Houtryve     }
1406d656b206SPierre van Houtryve   }
140743fd244bSPierre van Houtryve 
140843fd244bSPierre van Houtryve   // Externalize local GVs, which avoids duplicating their initializers, which
140943fd244bSPierre van Houtryve   // in turns helps keep code size in check.
141043fd244bSPierre van Houtryve   if (!NoExternalizeGlobals) {
141143fd244bSPierre van Houtryve     for (auto &GV : M.globals()) {
141243fd244bSPierre van Houtryve       if (GV.hasLocalLinkage())
14139347b66cSPierre van Houtryve         LLVM_DEBUG(dbgs() << "[externalize] GV " << GV.getName() << '\n');
141443fd244bSPierre van Houtryve       externalize(GV);
141543fd244bSPierre van Houtryve     }
141643fd244bSPierre van Houtryve   }
141743fd244bSPierre van Houtryve 
141843fd244bSPierre van Houtryve   // Start by calculating the cost of every function in the module, as well as
141943fd244bSPierre van Houtryve   // the module's overall cost.
14209347b66cSPierre van Houtryve   FunctionsCostMap FnCosts;
14219347b66cSPierre van Houtryve   const CostType ModuleCost = calculateFunctionCosts(GetTTI, M, FnCosts);
142243fd244bSPierre van Houtryve 
14239347b66cSPierre van Houtryve   // Build the SplitGraph, which represents the module's functions and models
14249347b66cSPierre van Houtryve   // their dependencies accurately.
14259347b66cSPierre van Houtryve   SplitGraph SG(M, FnCosts, ModuleCost);
14269347b66cSPierre van Houtryve   SG.buildGraph(CG);
14279347b66cSPierre van Houtryve 
14289347b66cSPierre van Houtryve   if (SG.empty()) {
14299347b66cSPierre van Houtryve     LLVM_DEBUG(
14309347b66cSPierre van Houtryve         dbgs()
14319347b66cSPierre van Houtryve         << "[!] no nodes in graph, input is empty - no splitting possible\n");
14329347b66cSPierre van Houtryve     ModuleCallback(cloneAll(M));
14339347b66cSPierre van Houtryve     return;
143443fd244bSPierre van Houtryve   }
14351c025fb0SPierre van Houtryve 
14369347b66cSPierre van Houtryve   LLVM_DEBUG({
14379347b66cSPierre van Houtryve     dbgs() << "[graph] nodes:\n";
14389347b66cSPierre van Houtryve     for (const SplitGraph::Node *N : SG.nodes()) {
14399347b66cSPierre van Houtryve       dbgs() << "  - [" << N->getID() << "]: " << N->getName() << " "
1440d656b206SPierre van Houtryve              << (N->isGraphEntryPoint() ? "(entry)" : "") << " "
1441d656b206SPierre van Houtryve              << (N->isNonCopyable() ? "(noncopyable)" : "") << "\n";
14421c025fb0SPierre van Houtryve     }
144343fd244bSPierre van Houtryve   });
144443fd244bSPierre van Houtryve 
14459347b66cSPierre van Houtryve   writeDOTGraph(SG);
14461c025fb0SPierre van Houtryve 
14479347b66cSPierre van Houtryve   LLVM_DEBUG(dbgs() << "[search] testing splitting strategies\n");
144843fd244bSPierre van Houtryve 
14499347b66cSPierre van Houtryve   std::optional<SplitProposal> Proposal;
14509347b66cSPierre van Houtryve   const auto EvaluateProposal = [&](SplitProposal SP) {
14519347b66cSPierre van Houtryve     SP.calculateScores();
14529347b66cSPierre van Houtryve     if (!Proposal)
14539347b66cSPierre van Houtryve       Proposal = std::move(SP);
14549347b66cSPierre van Houtryve     else
14559347b66cSPierre van Houtryve       evaluateProposal(*Proposal, std::move(SP));
145643fd244bSPierre van Houtryve   };
145743fd244bSPierre van Houtryve 
14589347b66cSPierre van Houtryve   // TODO: It would be very easy to create new strategies by just adding a base
14599347b66cSPierre van Houtryve   // class to RecursiveSearchSplitting and abstracting it away.
14609347b66cSPierre van Houtryve   RecursiveSearchSplitting(SG, NumParts, EvaluateProposal).run();
14619347b66cSPierre van Houtryve   LLVM_DEBUG(if (Proposal) dbgs() << "[search done] selected proposal: "
14629347b66cSPierre van Houtryve                                   << Proposal->getName() << "\n";);
14639347b66cSPierre van Houtryve 
14649347b66cSPierre van Houtryve   if (!Proposal) {
14659347b66cSPierre van Houtryve     LLVM_DEBUG(dbgs() << "[!] no proposal made, no splitting possible!\n");
14669347b66cSPierre van Houtryve     ModuleCallback(cloneAll(M));
14679347b66cSPierre van Houtryve     return;
14689347b66cSPierre van Houtryve   }
14699347b66cSPierre van Houtryve 
14709347b66cSPierre van Houtryve   LLVM_DEBUG(Proposal->print(dbgs()););
14719347b66cSPierre van Houtryve 
14729347b66cSPierre van Houtryve   std::optional<raw_fd_ostream> SummariesOS;
14739347b66cSPierre van Houtryve   if (!PartitionSummariesOutput.empty()) {
14749347b66cSPierre van Houtryve     std::error_code EC;
14759347b66cSPierre van Houtryve     SummariesOS.emplace(PartitionSummariesOutput, EC);
14769347b66cSPierre van Houtryve     if (EC)
14779347b66cSPierre van Houtryve       errs() << "[" DEBUG_TYPE "]: cannot open '" << PartitionSummariesOutput
14789347b66cSPierre van Houtryve              << "' - Partition summaries will not be printed\n";
14799347b66cSPierre van Houtryve   }
14809347b66cSPierre van Houtryve 
14819347b66cSPierre van Houtryve   for (unsigned PID = 0; PID < NumParts; ++PID) {
14829347b66cSPierre van Houtryve     SplitModuleTimer SMT2("modules_creation",
14839347b66cSPierre van Houtryve                           "creating modules for each partition");
14849347b66cSPierre van Houtryve     LLVM_DEBUG(dbgs() << "[split] creating new modules\n");
14859347b66cSPierre van Houtryve 
14869347b66cSPierre van Houtryve     DenseSet<const Function *> FnsInPart;
14879347b66cSPierre van Houtryve     for (unsigned NodeID : (*Proposal)[PID].set_bits())
14889347b66cSPierre van Houtryve       FnsInPart.insert(&SG.getNode(NodeID).getFunction());
148943fd244bSPierre van Houtryve 
149043fd244bSPierre van Houtryve     ValueToValueMapTy VMap;
14919347b66cSPierre van Houtryve     CostType PartCost = 0;
149243fd244bSPierre van Houtryve     std::unique_ptr<Module> MPart(
149343fd244bSPierre van Houtryve         CloneModule(M, VMap, [&](const GlobalValue *GV) {
149443fd244bSPierre van Houtryve           // Functions go in their assigned partition.
14959347b66cSPierre van Houtryve           if (const auto *Fn = dyn_cast<Function>(GV)) {
14969347b66cSPierre van Houtryve             if (FnsInPart.contains(Fn)) {
14979347b66cSPierre van Houtryve               PartCost += SG.getCost(*Fn);
149843fd244bSPierre van Houtryve               return true;
14999347b66cSPierre van Houtryve             }
15009347b66cSPierre van Houtryve             return false;
15019347b66cSPierre van Houtryve           }
150243fd244bSPierre van Houtryve 
150343fd244bSPierre van Houtryve           // Everything else goes in the first partition.
15049347b66cSPierre van Houtryve           return needsConservativeImport(GV) || PID == 0;
150543fd244bSPierre van Houtryve         }));
150643fd244bSPierre van Houtryve 
15079347b66cSPierre van Houtryve     // FIXME: Aliases aren't seen often, and their handling isn't perfect so
15089347b66cSPierre van Houtryve     // bugs are possible.
15099347b66cSPierre van Houtryve 
151043fd244bSPierre van Houtryve     // Clean-up conservatively imported GVs without any users.
15119347b66cSPierre van Houtryve     for (auto &GV : make_early_inc_range(MPart->global_values())) {
15129347b66cSPierre van Houtryve       if (needsConservativeImport(&GV) && GV.use_empty())
151343fd244bSPierre van Houtryve         GV.eraseFromParent();
151443fd244bSPierre van Houtryve     }
151543fd244bSPierre van Houtryve 
15169347b66cSPierre van Houtryve     if (SummariesOS)
15179347b66cSPierre van Houtryve       printPartitionSummary(*SummariesOS, PID, *MPart, PartCost, ModuleCost);
15189347b66cSPierre van Houtryve 
15199347b66cSPierre van Houtryve     LLVM_DEBUG(
15209347b66cSPierre van Houtryve         printPartitionSummary(dbgs(), PID, *MPart, PartCost, ModuleCost));
15219347b66cSPierre van Houtryve 
152243fd244bSPierre van Houtryve     ModuleCallback(std::move(MPart));
152343fd244bSPierre van Houtryve   }
152443fd244bSPierre van Houtryve }
1525d95b82c4SPierre van Houtryve } // namespace
1526d95b82c4SPierre van Houtryve 
1527d95b82c4SPierre van Houtryve PreservedAnalyses AMDGPUSplitModulePass::run(Module &M,
1528d95b82c4SPierre van Houtryve                                              ModuleAnalysisManager &MAM) {
15299347b66cSPierre van Houtryve   SplitModuleTimer SMT(
15309347b66cSPierre van Houtryve       "total", "total pass runtime (incl. potentially waiting for lockfile)");
15319347b66cSPierre van Houtryve 
1532d95b82c4SPierre van Houtryve   FunctionAnalysisManager &FAM =
1533d95b82c4SPierre van Houtryve       MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
1534d95b82c4SPierre van Houtryve   const auto TTIGetter = [&FAM](Function &F) -> const TargetTransformInfo & {
1535d95b82c4SPierre van Houtryve     return FAM.getResult<TargetIRAnalysis>(F);
1536d95b82c4SPierre van Houtryve   };
15379347b66cSPierre van Houtryve 
15389347b66cSPierre van Houtryve   bool Done = false;
15399347b66cSPierre van Houtryve #ifndef NDEBUG
15409347b66cSPierre van Houtryve   if (UseLockFile) {
15419347b66cSPierre van Houtryve     SmallString<128> LockFilePath;
15429347b66cSPierre van Houtryve     sys::path::system_temp_directory(/*ErasedOnReboot=*/true, LockFilePath);
15439347b66cSPierre van Houtryve     sys::path::append(LockFilePath, "amdgpu-split-module-debug");
15449347b66cSPierre van Houtryve     LLVM_DEBUG(dbgs() << DEBUG_TYPE " using lockfile '" << LockFilePath
15459347b66cSPierre van Houtryve                       << "'\n");
15469347b66cSPierre van Houtryve 
15479347b66cSPierre van Houtryve     while (true) {
15489347b66cSPierre van Houtryve       llvm::LockFileManager Locked(LockFilePath.str());
15499347b66cSPierre van Houtryve       switch (Locked) {
15509347b66cSPierre van Houtryve       case LockFileManager::LFS_Error:
15519347b66cSPierre van Houtryve         LLVM_DEBUG(
15529347b66cSPierre van Houtryve             dbgs() << "[amdgpu-split-module] unable to acquire lockfile, debug "
15539347b66cSPierre van Houtryve                       "output may be mangled by other processes\n");
15549347b66cSPierre van Houtryve         Locked.unsafeRemoveLockFile();
15559347b66cSPierre van Houtryve         break;
15569347b66cSPierre van Houtryve       case LockFileManager::LFS_Owned:
15579347b66cSPierre van Houtryve         break;
15589347b66cSPierre van Houtryve       case LockFileManager::LFS_Shared: {
15599347b66cSPierre van Houtryve         switch (Locked.waitForUnlock()) {
15609347b66cSPierre van Houtryve         case LockFileManager::Res_Success:
15619347b66cSPierre van Houtryve           break;
15629347b66cSPierre van Houtryve         case LockFileManager::Res_OwnerDied:
15639347b66cSPierre van Houtryve           continue; // try again to get the lock.
15649347b66cSPierre van Houtryve         case LockFileManager::Res_Timeout:
15659347b66cSPierre van Houtryve           LLVM_DEBUG(
15669347b66cSPierre van Houtryve               dbgs()
15679347b66cSPierre van Houtryve               << "[amdgpu-split-module] unable to acquire lockfile, debug "
15689347b66cSPierre van Houtryve                  "output may be mangled by other processes\n");
15699347b66cSPierre van Houtryve           Locked.unsafeRemoveLockFile();
15709347b66cSPierre van Houtryve           break; // give up
1571c9b6e01bSPierre van Houtryve         }
15729347b66cSPierre van Houtryve         break;
15739347b66cSPierre van Houtryve       }
15749347b66cSPierre van Houtryve       }
15759347b66cSPierre van Houtryve 
15769347b66cSPierre van Houtryve       splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);
15779347b66cSPierre van Houtryve       Done = true;
15789347b66cSPierre van Houtryve       break;
15799347b66cSPierre van Houtryve     }
15809347b66cSPierre van Houtryve   }
15819347b66cSPierre van Houtryve #endif
15829347b66cSPierre van Houtryve 
15839347b66cSPierre van Houtryve   if (!Done)
15849347b66cSPierre van Houtryve     splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);
15859347b66cSPierre van Houtryve 
15869347b66cSPierre van Houtryve   // We can change linkage/visibilities in the input, consider that nothing is
15879347b66cSPierre van Houtryve   // preserved just to be safe. This pass runs last anyway.
15889347b66cSPierre van Houtryve   return PreservedAnalyses::none();
15899347b66cSPierre van Houtryve }
15909347b66cSPierre van Houtryve } // namespace llvm
1591