//===- AMDGPUSplitModule.cpp ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Implements a module splitting algorithm designed to support the
/// FullLTO --lto-partitions option for parallel codegen.
///
/// The role of this module splitting pass is the same as
/// lib/Transforms/Utils/SplitModule.cpp: load-balance the module's functions
/// across a set of N partitions to allow for parallel codegen.
///
/// The similarities mostly end here, as this pass achieves load-balancing in a
/// more elaborate fashion which is targeted towards AMDGPU modules. It can take
/// advantage of the structure of AMDGPU modules (which are mostly
/// self-contained) to allow for more efficient splitting without affecting
/// codegen negatively, or causing inaccurate resource usage analysis.
///
/// High-level pass overview:
///   - SplitGraph & associated classes
///     - Graph representation of the module and of the dependencies that
///       matter for splitting.
///   - RecursiveSearchSplitting
///     - Core splitting algorithm.
///   - SplitProposal
///     - Represents a suggested solution for splitting the input module. These
///       solutions can be scored to determine the best one when multiple
///       solutions are available.
///   - Driver/pass "run" function glues everything together.

#include "AMDGPUSplitModule.h"
#include "AMDGPUTargetMachine.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/DOTGraphTraits.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Timer.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#ifndef NDEBUG
#include "llvm/Support/LockFileManager.h"
#endif

#define DEBUG_TYPE "amdgpu-split-module"

namespace llvm {
namespace {

static cl::opt<unsigned> MaxDepth(
    "amdgpu-module-splitting-max-depth",
    cl::desc(
        "maximum search depth. 0 forces a greedy approach. "
        "warning: the algorithm is up to O(2^N), where N is the max depth."),
    cl::init(8));

static cl::opt<float> LargeFnFactor(
    "amdgpu-module-splitting-large-threshold", cl::init(2.0f), cl::Hidden,
    cl::desc(
        "when max depth is reached and we can no longer branch out, this "
        "value determines if a function is worth merging into an already "
        "existing partition to reduce code duplication. This is a factor "
        "of the ideal partition size, e.g. 2.0 means we consider the "
        "function for merging if its cost (including its callees) is 2x the "
        "size of an ideal partition."));

static cl::opt<float> LargeFnOverlapForMerge(
    "amdgpu-module-splitting-merge-threshold", cl::init(0.7f), cl::Hidden,
    cl::desc("when a function is considered for merging into a partition that "
             "already contains some of its callees, do the merge if at least "
             "n% of the code it can reach is already present inside the "
             "partition; e.g. 0.7 means only merge >70%"));
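
// Example with the defaults: splitting into 4 partitions with a total module
// cost of 4000 gives an ideal partition size of 1000. LargeFnFactor=2.0 then
// treats a cluster as "large" once its cost, excluding entry points, exceeds
// 2000, and LargeFnOverlapForMerge=0.7 merges such a cluster into the most
// similar partition only if more than 70% of the code it reaches is already
// present there.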

static cl::opt<bool> NoExternalizeGlobals(
    "amdgpu-module-splitting-no-externalize-globals", cl::Hidden,
    cl::desc("disables externalization of global variables with local "
             "linkage; may cause globals to be duplicated which increases "
             "binary size"));

static cl::opt<bool> NoExternalizeOnAddrTaken(
    "amdgpu-module-splitting-no-externalize-address-taken", cl::Hidden,
    cl::desc(
        "disables externalization of functions whose addresses are taken"));

static cl::opt<std::string>
    ModuleDotCfgOutput("amdgpu-module-splitting-print-module-dotcfg",
                       cl::Hidden,
                       cl::desc("output file to write out the dotgraph "
                                "representation of the input module"));

static cl::opt<std::string> PartitionSummariesOutput(
    "amdgpu-module-splitting-print-partition-summaries", cl::Hidden,
    cl::desc("output file to write out a summary of "
             "the partitions created for each module"));

#ifndef NDEBUG
static cl::opt<bool>
    UseLockFile("amdgpu-module-splitting-serial-execution", cl::Hidden,
                cl::desc("use a lock file so only one process in the system "
                         "can run this pass at once. useful to avoid mangled "
                         "debug output in multithreaded environments."));

static cl::opt<bool>
    DebugProposalSearch("amdgpu-module-splitting-debug-proposal-search",
                        cl::Hidden,
                        cl::desc("print all proposals received and whether "
                                 "they were rejected or accepted"));
#endif

struct SplitModuleTimer : NamedRegionTimer {
  SplitModuleTimer(StringRef Name, StringRef Desc)
      : NamedRegionTimer(Name, Desc, DEBUG_TYPE, "AMDGPU Module Splitting",
                         TimePassesIsEnabled) {}
};

//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//

using CostType = InstructionCost::CostType;
using FunctionsCostMap = DenseMap<const Function *, CostType>;
using GetTTIFn = function_ref<const TargetTransformInfo &(Function &)>;
static constexpr unsigned InvalidPID = -1;

/// \param Num numerator
/// \param Dem denominator
/// \returns a printable object to print the ratio (Num/Dem) as a percentage,
/// formatted with "%0.2f".
static auto formatRatioOf(CostType Num, CostType Dem) {
  return format("%0.2f", (static_cast<double>(Num) / Dem) * 100);
}

/// Checks whether a given function is non-copyable.
///
/// Non-copyable functions cannot be cloned into multiple partitions, and only
/// one copy of the function can be present across all partitions.
///
/// External functions fall into this category. If we were to clone them, we
/// would end up with multiple symbol definitions and a very unhappy linker.
static bool isNonCopyable(const Function &F) {
  assert(AMDGPU::isEntryFunctionCC(F.getCallingConv())
             ? F.hasExternalLinkage()
             : true && "Kernel w/o external linkage?");
  return F.hasExternalLinkage() || !F.isDefinitionExact();
}

/// If \p GV has local linkage, make it external + hidden.
static void externalize(GlobalValue &GV) {
  if (GV.hasLocalLinkage()) {
    GV.setLinkage(GlobalValue::ExternalLinkage);
    GV.setVisibility(GlobalValue::HiddenVisibility);
  }

  // Unnamed entities must be named consistently between modules. setName will
  // give a distinct name to each such entity.
  if (!GV.hasName())
    GV.setName("__llvmsplit_unnamed");
}

/// Cost analysis function. Calculates the cost of each function in \p M.
///
/// \param GetTTI Abstract getter for TargetTransformInfo.
/// \param M Module to analyze.
/// \param [out] CostMap Resulting Function -> Cost map.
/// \return The module's total cost.
static CostType calculateFunctionCosts(GetTTIFn GetTTI, Module &M,
                                       FunctionsCostMap &CostMap) {
  SplitModuleTimer SMT("calculateFunctionCosts", "cost analysis");

  LLVM_DEBUG(dbgs() << "[cost analysis] calculating function costs\n");
  CostType ModuleCost = 0;
  [[maybe_unused]] CostType KernelCost = 0;

  for (auto &Fn : M) {
    if (Fn.isDeclaration())
      continue;

    CostType FnCost = 0;
    const auto &TTI = GetTTI(Fn);
    for (const auto &BB : Fn) {
      for (const auto &I : BB) {
        auto Cost =
            TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
        assert(Cost != InstructionCost::getMax());
        // Assume expensive if we can't tell the cost of an instruction.
        CostType CostVal =
            Cost.getValue().value_or(TargetTransformInfo::TCC_Expensive);
        assert((FnCost + CostVal) >= FnCost && "Overflow!");
        FnCost += CostVal;
      }
    }

    assert(FnCost != 0);

    CostMap[&Fn] = FnCost;
    assert((ModuleCost + FnCost) >= ModuleCost && "Overflow!");
    ModuleCost += FnCost;

    if (AMDGPU::isEntryFunctionCC(Fn.getCallingConv()))
      KernelCost += FnCost;
  }

  if (CostMap.empty())
    return 0;

  assert(ModuleCost);
  LLVM_DEBUG({
    const CostType FnCost = ModuleCost - KernelCost;
    dbgs() << " - total module cost is " << ModuleCost << ". kernels cost "
           << KernelCost << " ("
           << format("%0.2f", (float(KernelCost) / ModuleCost) * 100)
           << "% of the module), functions cost " << FnCost << " ("
           << format("%0.2f", (float(FnCost) / ModuleCost) * 100)
           << "% of the module)\n";
  });

  return ModuleCost;
}

/// \return true if \p F can be indirectly called
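/// Declarations and entry functions never qualify: declarations have no body
/// to place in a partition, and kernels are launched by the runtime rather
/// than called from device code.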
static bool canBeIndirectlyCalled(const Function &F) {
  if (F.isDeclaration() || AMDGPU::isEntryFunctionCC(F.getCallingConv()))
    return false;
  return !F.hasLocalLinkage() ||
         F.hasAddressTaken(/*PutOffender=*/nullptr,
                           /*IgnoreCallbackUses=*/false,
                           /*IgnoreAssumeLikeCalls=*/true,
                           /*IgnoreLLVMUsed=*/true,
                           /*IgnoreARCAttachedCall=*/false,
                           /*IgnoreCastedDirectCall=*/true);
}

//===----------------------------------------------------------------------===//
// Graph-based Module Representation
//===----------------------------------------------------------------------===//

/// AMDGPUSplitModule's view of the source Module, as a graph of all components
/// that can be split into different modules.
///
/// The most trivial instance of this graph is just the CallGraph of the module,
/// but it is not guaranteed that the graph is strictly equal to the CG. It
/// currently always is but it's designed in a way that would eventually allow
/// us to create abstract nodes, or nodes for different entities such as global
/// variables or any other meaningful constraint we must consider.
///
/// The graph is only mutable by this class, and is generally not modified
/// after \ref SplitGraph::buildGraph runs. No consumers of the graph can
/// mutate it.
class SplitGraph {
public:
  class Node;

  enum class EdgeKind : uint8_t {
    /// The nodes are related through a direct call. This is a "strong" edge as
    /// it means the Src will directly reference the Dst.
    DirectCall,
    /// The nodes are related through an indirect call.
    /// This is a "weaker" edge and is only considered when traversing the graph
    /// starting from a kernel. We need this edge for resource usage analysis.
    ///
    /// The reason why we have this edge in the first place is due to how
    /// AMDGPUResourceUsageAnalysis works. In the presence of an indirect call,
    /// the resource usage of the kernel containing the indirect call is the
    /// max resource usage of all functions that can be indirectly called.
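    ///
    /// e.g. if a kernel K contains an indirect call and F can be indirectly
    /// called, K gets an IndirectCall edge to F, so F is imported alongside
    /// K and that max resource usage can still be computed.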
    IndirectCall,
  };

  /// An edge between two nodes. Edges are directional, and tagged with a
  /// "kind".
  struct Edge {
    Edge(Node *Src, Node *Dst, EdgeKind Kind)
        : Src(Src), Dst(Dst), Kind(Kind) {}

    Node *Src; ///< Source
    Node *Dst; ///< Destination
    EdgeKind Kind;
  };

  using EdgesVec = SmallVector<const Edge *, 0>;
  using edges_iterator = EdgesVec::const_iterator;
  using nodes_iterator = const Node *const *;

  SplitGraph(const Module &M, const FunctionsCostMap &CostMap,
             CostType ModuleCost)
      : M(M), CostMap(CostMap), ModuleCost(ModuleCost) {}

  void buildGraph(CallGraph &CG);

#ifndef NDEBUG
  bool verifyGraph() const;
#endif

  bool empty() const { return Nodes.empty(); }
  const iterator_range<nodes_iterator> nodes() const {
    return {Nodes.begin(), Nodes.end()};
  }
  const Node &getNode(unsigned ID) const { return *Nodes[ID]; }

  unsigned getNumNodes() const { return Nodes.size(); }
  BitVector createNodesBitVector() const { return BitVector(Nodes.size()); }

  const Module &getModule() const { return M; }

  CostType getModuleCost() const { return ModuleCost; }
  CostType getCost(const Function &F) const { return CostMap.at(&F); }

  /// \returns the aggregated cost of all nodes in \p BV (bits set to 1 = node
  /// IDs).
  CostType calculateCost(const BitVector &BV) const;

private:
  /// Retrieves the node for \p GV in \p Cache, or creates a new node for it
  /// and updates \p Cache.
  Node &getNode(DenseMap<const GlobalValue *, Node *> &Cache,
                const GlobalValue &GV);

  // Create a new edge between two nodes and add it to both nodes.
  const Edge &createEdge(Node &Src, Node &Dst, EdgeKind EK);

  const Module &M;
  const FunctionsCostMap &CostMap;
  CostType ModuleCost;

  // Final list of nodes with stable ordering.
  SmallVector<Node *> Nodes;

  SpecificBumpPtrAllocator<Node> NodesPool;

  // Edges are trivially destructible objects, so as a small optimization we
  // use a BumpPtrAllocator which avoids destructor calls but also makes
  // allocation faster.
  static_assert(
      std::is_trivially_destructible_v<Edge>,
      "Edge must be trivially destructible to use the BumpPtrAllocator");
  BumpPtrAllocator EdgesPool;
};

/// Nodes in the SplitGraph contain both incoming and outgoing edges.
/// Incoming edges have this node as their Dst, and outgoing ones have this
/// node as their Src.
///
/// Edge objects are shared by both nodes in Src/Dst. They provide immediate
/// feedback on how two nodes are related, and in which direction they are
/// related, which is valuable information to make splitting decisions.
///
/// Nodes are fundamentally abstract, and any consumers of the graph should
/// treat them as such. While a node will be a function most of the time, we
/// could also create nodes for any other reason. In the future, we could have
/// single nodes for multiple functions, or nodes for GVs, etc.
class SplitGraph::Node {
  friend class SplitGraph;

public:
  Node(unsigned ID, const GlobalValue &GV, CostType IndividualCost,
       bool IsNonCopyable)
      : ID(ID), GV(GV), IndividualCost(IndividualCost),
        IsNonCopyable(IsNonCopyable), IsEntryFnCC(false), IsGraphEntry(false) {
    if (auto *Fn = dyn_cast<Function>(&GV))
      IsEntryFnCC = AMDGPU::isEntryFunctionCC(Fn->getCallingConv());
  }

  /// A 0-indexed ID for the node. The maximum ID (exclusive) is the number of
  /// nodes in the graph. This ID can be used as an index in a BitVector.
  unsigned getID() const { return ID; }

  const Function &getFunction() const { return cast<Function>(GV); }

  /// \returns the cost to import this component into a given module, not
  /// accounting for any dependencies that may need to be imported as well.
  CostType getIndividualCost() const { return IndividualCost; }

  bool isNonCopyable() const { return IsNonCopyable; }
  bool isEntryFunctionCC() const { return IsEntryFnCC; }

  /// \returns whether this is an entry point in the graph. Entry points are
  /// defined as follows: if you take all entry points in the graph, and iterate
  /// their dependencies, you are guaranteed to visit all nodes in the graph at
  /// least once.
  bool isGraphEntryPoint() const { return IsGraphEntry; }

  StringRef getName() const { return GV.getName(); }

  bool hasAnyIncomingEdges() const { return IncomingEdges.size(); }
  bool hasAnyIncomingEdgesOfKind(EdgeKind EK) const {
    return any_of(IncomingEdges, [&](const auto *E) { return E->Kind == EK; });
  }

  bool hasAnyOutgoingEdges() const { return OutgoingEdges.size(); }
  bool hasAnyOutgoingEdgesOfKind(EdgeKind EK) const {
    return any_of(OutgoingEdges, [&](const auto *E) { return E->Kind == EK; });
  }

  iterator_range<edges_iterator> incoming_edges() const {
    return IncomingEdges;
  }

  iterator_range<edges_iterator> outgoing_edges() const {
    return OutgoingEdges;
  }

  bool shouldFollowIndirectCalls() const { return isEntryFunctionCC(); }

  /// Visit all children of this node in a recursive fashion. Also visits Self.
  /// If \ref shouldFollowIndirectCalls returns false, then this only follows
  /// DirectCall edges.
  ///
  /// \param Visitor Visitor function.
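  ///
  /// For example, collecting the IDs of all reachable nodes (which is what
  /// \ref getDependencies does):
  /// \code
  ///   N.visitAllDependencies([&](const Node &Dep) { BV.set(Dep.getID()); });
  /// \endcode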
  void visitAllDependencies(std::function<void(const Node &)> Visitor) const;

  /// Adds the dependencies of this node to \p BV by setting the bit
  /// corresponding to each node.
  ///
  /// Implemented using \ref visitAllDependencies, hence it follows the same
  /// rules regarding dependencies traversal.
  ///
  /// \param[out] BV The bitvector where the bits should be set.
  void getDependencies(BitVector &BV) const {
    visitAllDependencies([&](const Node &N) { BV.set(N.getID()); });
  }

private:
  void markAsGraphEntry() { IsGraphEntry = true; }

  unsigned ID;
  const GlobalValue &GV;
  CostType IndividualCost;
  bool IsNonCopyable : 1;
  bool IsEntryFnCC : 1;
  bool IsGraphEntry : 1;

  // TODO: Use a single sorted vector (with all incoming/outgoing edges grouped
  // together)
  EdgesVec IncomingEdges;
  EdgesVec OutgoingEdges;
};

void SplitGraph::Node::visitAllDependencies(
    std::function<void(const Node &)> Visitor) const {
  const bool FollowIndirect = shouldFollowIndirectCalls();
  // FIXME: If this can access SplitGraph in the future, use a BitVector
  // instead.
  DenseSet<const Node *> Seen;
  SmallVector<const Node *, 8> WorkList({this});
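  // Iterative depth-first traversal; Seen guards against revisiting nodes
  // when the call graph contains cycles.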
  while (!WorkList.empty()) {
    const Node *CurN = WorkList.pop_back_val();
    if (auto [It, Inserted] = Seen.insert(CurN); !Inserted)
      continue;

    Visitor(*CurN);

    for (const Edge *E : CurN->outgoing_edges()) {
      if (!FollowIndirect && E->Kind == EdgeKind::IndirectCall)
        continue;
      WorkList.push_back(E->Dst);
    }
  }
}

void SplitGraph::buildGraph(CallGraph &CG) {
  SplitModuleTimer SMT("buildGraph", "graph construction");
  LLVM_DEBUG(
      dbgs()
      << "[build graph] constructing graph representation of the input\n");

  // FIXME(?): Is the callgraph really worth using if we have to iterate the
  // function again whenever it fails to give us enough information?

  // We build the graph by just iterating all functions in the module and
  // working on their direct callees. At the end, all nodes should be linked
  // together as expected.
  DenseMap<const GlobalValue *, Node *> Cache;
  SmallVector<const Function *> FnsWithIndirectCalls, IndirectlyCallableFns;
  for (const Function &Fn : M) {
    if (Fn.isDeclaration())
      continue;

    // Look at direct callees and create the necessary edges in the graph.
    SetVector<const Function *> DirectCallees;
    bool CallsExternal = false;
    for (auto &CGEntry : *CG[&Fn]) {
      auto *CGNode = CGEntry.second;
      if (auto *Callee = CGNode->getFunction()) {
        if (!Callee->isDeclaration())
          DirectCallees.insert(Callee);
      } else if (CGNode == CG.getCallsExternalNode())
        CallsExternal = true;
    }

    // Keep track of this function if it contains an indirect call and/or if it
    // can be indirectly called.
    if (CallsExternal) {
      LLVM_DEBUG(dbgs() << "  [!] callgraph is incomplete for ";
                 Fn.printAsOperand(dbgs());
                 dbgs() << " - analyzing function\n");

      bool HasIndirectCall = false;
      for (const auto &Inst : instructions(Fn)) {
        // look at all calls without a direct callee.
        if (const auto *CB = dyn_cast<CallBase>(&Inst);
            CB && !CB->getCalledFunction()) {
          // inline assembly can be ignored, unless InlineAsmIsIndirectCall is
          // true.
          if (CB->isInlineAsm()) {
            LLVM_DEBUG(dbgs() << "    found inline assembly\n");
            continue;
          }

          // everything else is handled conservatively.
          HasIndirectCall = true;
          break;
        }
      }

      if (HasIndirectCall) {
        LLVM_DEBUG(dbgs() << "    indirect call found\n");
        FnsWithIndirectCalls.push_back(&Fn);
      }
    }

    Node &N = getNode(Cache, Fn);
    for (const auto *Callee : DirectCallees)
      createEdge(N, getNode(Cache, *Callee), EdgeKind::DirectCall);

    if (canBeIndirectlyCalled(Fn))
      IndirectlyCallableFns.push_back(&Fn);
  }

  // Post-process functions with indirect calls.
  for (const Function *Fn : FnsWithIndirectCalls) {
    for (const Function *Candidate : IndirectlyCallableFns) {
      Node &Src = getNode(Cache, *Fn);
      Node &Dst = getNode(Cache, *Candidate);
      createEdge(Src, Dst, EdgeKind::IndirectCall);
    }
  }

  // Now, find all entry points.
  SmallVector<Node *, 16> CandidateEntryPoints;
  BitVector NodesReachableByKernels = createNodesBitVector();
  for (Node *N : Nodes) {
    // Functions with an Entry CC are always graph entry points too.
    if (N->isEntryFunctionCC()) {
      N->markAsGraphEntry();
      N->getDependencies(NodesReachableByKernels);
    } else if (!N->hasAnyIncomingEdgesOfKind(EdgeKind::DirectCall))
      CandidateEntryPoints.push_back(N);
  }

  for (Node *N : CandidateEntryPoints) {
    // This can be another entry point if it's not reachable by a kernel.
    // TODO: We could sort all of the possible new entries in a stable order
    // (e.g. by cost), then consume them one by one until
    // NodesReachableByKernels is all 1s. It'd allow us to avoid
    // considering some nodes as non-entries in some specific cases.
    if (!NodesReachableByKernels.test(N->getID()))
      N->markAsGraphEntry();
  }

#ifndef NDEBUG
  assert(verifyGraph());
#endif
}

#ifndef NDEBUG
bool SplitGraph::verifyGraph() const {
  unsigned ExpectedID = 0;
  // Exceptionally using a set here in case IDs are messed up.
  DenseSet<const Node *> SeenNodes;
  DenseSet<const Function *> SeenFunctionNodes;
  for (const Node *N : Nodes) {
    if (N->getID() != (ExpectedID++)) {
      errs() << "Node IDs are incorrect!\n";
      return false;
    }

    if (!SeenNodes.insert(N).second) {
      errs() << "Node seen more than once!\n";
      return false;
    }

    if (&getNode(N->getID()) != N) {
      errs() << "getNode doesn't return the right node\n";
      return false;
    }

    for (const Edge *E : N->IncomingEdges) {
      if (!E->Src || !E->Dst || (E->Dst != N) ||
          (find(E->Src->OutgoingEdges, E) == E->Src->OutgoingEdges.end())) {
        errs() << "ill-formed incoming edges\n";
        return false;
      }
    }

    for (const Edge *E : N->OutgoingEdges) {
      if (!E->Src || !E->Dst || (E->Src != N) ||
          (find(E->Dst->IncomingEdges, E) == E->Dst->IncomingEdges.end())) {
        errs() << "ill-formed outgoing edges\n";
        return false;
      }
    }

    const Function &Fn = N->getFunction();
    if (AMDGPU::isEntryFunctionCC(Fn.getCallingConv())) {
      if (N->hasAnyIncomingEdges()) {
        errs() << "Kernels cannot have incoming edges\n";
        return false;
      }
    }

    if (Fn.isDeclaration()) {
      errs() << "declarations shouldn't have nodes!\n";
      return false;
    }

    auto [It, Inserted] = SeenFunctionNodes.insert(&Fn);
    if (!Inserted) {
      errs() << "one function has multiple nodes!\n";
      return false;
    }
  }

  if (ExpectedID != Nodes.size()) {
    errs() << "Node IDs out of sync!\n";
    return false;
  }

  if (createNodesBitVector().size() != getNumNodes()) {
    errs() << "nodes bit vector doesn't have the right size!\n";
    return false;
  }

  // Check we respect the promise of Node::isGraphEntryPoint.
  BitVector BV = createNodesBitVector();
  for (const Node *N : nodes()) {
    if (N->isGraphEntryPoint())
      N->getDependencies(BV);
  }

  // Ensure each function in the module has an associated node.
  for (const auto &Fn : M) {
    if (!Fn.isDeclaration()) {
      if (!SeenFunctionNodes.contains(&Fn)) {
        errs() << "Fn has no associated node in the graph!\n";
        return false;
      }
    }
  }

  if (!BV.all()) {
    errs() << "not all nodes are reachable through the graph's entry points!\n";
    return false;
  }

  return true;
}
#endif

CostType SplitGraph::calculateCost(const BitVector &BV) const {
  CostType Cost = 0;
  for (unsigned NodeID : BV.set_bits())
    Cost += getNode(NodeID).getIndividualCost();
  return Cost;
}

SplitGraph::Node &
SplitGraph::getNode(DenseMap<const GlobalValue *, Node *> &Cache,
                    const GlobalValue &GV) {
  auto &N = Cache[&GV];
  if (N)
    return *N;

  CostType Cost = 0;
  bool NonCopyable = false;
  if (const Function *Fn = dyn_cast<Function>(&GV)) {
    NonCopyable = isNonCopyable(*Fn);
    Cost = CostMap.at(Fn);
  }
  N = new (NodesPool.Allocate()) Node(Nodes.size(), GV, Cost, NonCopyable);
  Nodes.push_back(N);
  assert(&getNode(N->getID()) == N);
  return *N;
}

const SplitGraph::Edge &SplitGraph::createEdge(Node &Src, Node &Dst,
                                               EdgeKind EK) {
  const Edge *E = new (EdgesPool.Allocate<Edge>(1)) Edge(&Src, &Dst, EK);
  Src.OutgoingEdges.push_back(E);
  Dst.IncomingEdges.push_back(E);
  return *E;
}

//===----------------------------------------------------------------------===//
// Split Proposals
//===----------------------------------------------------------------------===//

/// Represents a module splitting proposal.
///
/// Proposals are made of N BitVectors, one for each partition, where each bit
/// set indicates that the node is present and should be copied inside that
/// partition.
///
/// Proposals have several metrics attached so they can be compared/sorted,
/// which allows the driver to try multiple strategies resulting in multiple
/// proposals, and choose the best one out of them.
class SplitProposal {
public:
  SplitProposal(const SplitGraph &SG, unsigned MaxPartitions) : SG(&SG) {
    Partitions.resize(MaxPartitions, {0, SG.createNodesBitVector()});
  }

  void setName(StringRef NewName) { Name = NewName; }
  StringRef getName() const { return Name; }

  const BitVector &operator[](unsigned PID) const {
    return Partitions[PID].second;
  }

  void add(unsigned PID, const BitVector &BV) {
    Partitions[PID].second |= BV;
    updateScore(PID);
  }

  void print(raw_ostream &OS) const;
  LLVM_DUMP_METHOD void dump() const { print(dbgs()); }

  // Find the cheapest partition (lowest cost). In case of ties, always returns
  // the highest partition number.
  unsigned findCheapestPartition() const;

  /// Calculate the CodeSize and Bottleneck scores.
  void calculateScores();

#ifndef NDEBUG
  void verifyCompleteness() const;
#endif

  /// Only available after \ref calculateScores is called.
  ///
  /// The aggregated cost of all partitions divided by the cost of the source
  /// module, indicating how much code duplication this proposal creates.
  /// e.g. 1.2 means the partitions add up to roughly 120% of the source
  /// module's size, so about 20% of the code was duplicated across
  /// partitions.
  ///
  /// Value is always rounded up to 2 decimal places.
  ///
  /// A perfect score would be 1.0 (no duplication); the higher the score, the
  /// more code is duplicated.
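  ///
  /// Example: a module of cost 100 split into three partitions of cost
  /// 40/40/40 has a total cost of 120, hence a CodeSizeScore of 1.2.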
  double getCodeSizeScore() const { return CodeSizeScore; }

  /// Only available after \ref calculateScores is called.
  ///
  /// A number between [0, 1] which indicates how big of a bottleneck is
  /// expected from the largest partition.
  ///
  /// A score of 1.0 means the biggest partition is as big as the source module,
  /// so build time will be equal to or greater than the build time of the
  /// initial input.
  ///
  /// Value is always rounded up to 2 decimal places.
  ///
  /// This is one of the metrics used to estimate this proposal's build time.
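  ///
  /// Example: with the 40/40/40 split above, the largest partition costs 40,
  /// so the BottleneckScore is 40/100 = 0.4.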
  double getBottleneckScore() const { return BottleneckScore; }

private:
  void updateScore(unsigned PID) {
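    // Note: this recomputes the cost of every partition, not just PID's.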
    assert(SG);
    for (auto &[PCost, Nodes] : Partitions) {
      TotalCost -= PCost;
      PCost = SG->calculateCost(Nodes);
      TotalCost += PCost;
    }
  }

  /// \see getCodeSizeScore
  double CodeSizeScore = 0.0;
  /// \see getBottleneckScore
  double BottleneckScore = 0.0;
  /// Aggregated cost of all partitions
  CostType TotalCost = 0;

  const SplitGraph *SG = nullptr;
  std::string Name;

  std::vector<std::pair<CostType, BitVector>> Partitions;
};

void SplitProposal::print(raw_ostream &OS) const {
  assert(SG);

  OS << "[proposal] " << Name << ", total cost:" << TotalCost
     << ", code size score:" << format("%0.3f", CodeSizeScore)
     << ", bottleneck score:" << format("%0.3f", BottleneckScore) << '\n';
  for (const auto &[PID, Part] : enumerate(Partitions)) {
    const auto &[Cost, NodeIDs] = Part;
    OS << "  - P" << PID << " nodes:" << NodeIDs.count() << " cost: " << Cost
       << '|' << formatRatioOf(Cost, SG->getModuleCost()) << "%\n";
  }
}

unsigned SplitProposal::findCheapestPartition() const {
  assert(!Partitions.empty());
  CostType CurCost = std::numeric_limits<CostType>::max();
  unsigned CurPID = InvalidPID;
  for (const auto &[Idx, Part] : enumerate(Partitions)) {
    if (Part.first <= CurCost) {
      CurPID = Idx;
      CurCost = Part.first;
    }
  }
  assert(CurPID != InvalidPID);
  return CurPID;
}

void SplitProposal::calculateScores() {
  if (Partitions.empty())
    return;

  assert(SG);
  CostType LargestPCost = 0;
  for (auto &[PCost, Nodes] : Partitions) {
    if (PCost > LargestPCost)
      LargestPCost = PCost;
  }

  CostType ModuleCost = SG->getModuleCost();
  CodeSizeScore = double(TotalCost) / ModuleCost;
  assert(CodeSizeScore >= 0.0);

  BottleneckScore = double(LargestPCost) / ModuleCost;

  CodeSizeScore = std::ceil(CodeSizeScore * 100.0) / 100.0;
  BottleneckScore = std::ceil(BottleneckScore * 100.0) / 100.0;
}

#ifndef NDEBUG
void SplitProposal::verifyCompleteness() const {
  if (Partitions.empty())
    return;

  BitVector Result = Partitions[0].second;
  for (const auto &P : drop_begin(Partitions))
    Result |= P.second;
  assert(Result.all() && "some nodes are missing from this proposal!");
}
#endif

//===-- RecursiveSearchStrategy -------------------------------------------===//

/// Partitioning algorithm.
///
/// This is a recursive search algorithm that can explore multiple
/// possibilities.
///
/// When a cluster of nodes can go into more than one partition, and we haven't
/// reached maximum search depth, we recurse and explore both options and their
/// consequences. Both branches will yield a proposal, and the driver will grade
/// both and choose the best one.
///
/// If max depth is reached, we will use some heuristics to make a choice. Most
/// of the time we will just use the least-pressured (cheapest) partition, but
/// if a cluster is particularly big and there is a good amount of overlap with
/// an existing partition, we will choose that partition instead.
class RecursiveSearchSplitting {
public:
  using SubmitProposalFn = function_ref<void(SplitProposal)>;

  RecursiveSearchSplitting(const SplitGraph &SG, unsigned NumParts,
                           SubmitProposalFn SubmitProposal);

  void run();

private:
  struct WorkListEntry {
    WorkListEntry(const BitVector &BV) : Cluster(BV) {}

    unsigned NumNonEntryNodes = 0;
    CostType TotalCost = 0;
    CostType CostExcludingGraphEntryPoints = 0;
    BitVector Cluster;
  };

  /// Collects all graph entry points' clusters and sorts them so the most
  /// expensive clusters are viewed first. This will merge clusters together if
  /// they share a non-copyable dependency.
  void setupWorkList();

  /// Recursive function that assigns the worklist item at \p Idx into a
  /// partition of \p SP.
  ///
  /// \p Depth is the current search depth. When this value is equal to
  /// \ref MaxDepth, we can no longer recurse.
  ///
  /// This function only recurses if there is more than one possible assignment,
  /// otherwise it is iterative to avoid creating a call stack that is as big as
  /// \ref WorkList.
  void pickPartition(unsigned Depth, unsigned Idx, SplitProposal SP);

  /// \return A pair: first element is the PID of the partition that has the
  /// most similarities with \p Entry, or \ref InvalidPID if no partition was
  /// found with at least one element in common. The second element is the
  /// aggregated cost of all dependencies in common between \p Entry and that
  /// partition.
  std::pair<unsigned, CostType>
  findMostSimilarPartition(const WorkListEntry &Entry, const SplitProposal &SP);

  const SplitGraph &SG;
  unsigned NumParts;
  SubmitProposalFn SubmitProposal;

  // A Cluster is considered large when its cost, excluding entry points,
  // exceeds this value.
  CostType LargeClusterThreshold = 0;
  unsigned NumProposalsSubmitted = 0;
  SmallVector<WorkListEntry> WorkList;
};

RecursiveSearchSplitting::RecursiveSearchSplitting(
    const SplitGraph &SG, unsigned NumParts, SubmitProposalFn SubmitProposal)
    : SG(SG), NumParts(NumParts), SubmitProposal(SubmitProposal) {
  // Arbitrary max value as a safeguard. Anything above 10 will already be
  // slow, this is just a max value to prevent extreme resource exhaustion or
  // unbounded run time.
  if (MaxDepth > 16)
    report_fatal_error("[amdgpu-split-module] search depth of " +
                       Twine(MaxDepth) + " is too high!");
  LargeClusterThreshold =
      (LargeFnFactor != 0.0)
          ? CostType(((SG.getModuleCost() / NumParts) * LargeFnFactor))
          : std::numeric_limits<CostType>::max();
  LLVM_DEBUG(dbgs() << "[recursive search] large cluster threshold set at "
                    << LargeClusterThreshold << "\n");
}

void RecursiveSearchSplitting::run() {
  {
    SplitModuleTimer SMT("recursive_search_prepare", "preparing worklist");
    setupWorkList();
  }

  {
    SplitModuleTimer SMT("recursive_search_pick", "partitioning");
    SplitProposal SP(SG, NumParts);
    pickPartition(/*Depth=*/0, /*Idx=*/0, SP);
  }
}

void RecursiveSearchSplitting::setupWorkList() {
  // e.g. if A and B are two worklist items, and they both call a non-copyable
  // dependency C, this does:
  //    A=C
  //    B=C
  // => NodeEC will create a single group (A, B, C) and we create a new
  // WorkList entry for that group.

  EquivalenceClasses<unsigned> NodeEC;
  for (const SplitGraph::Node *N : SG.nodes()) {
    if (!N->isGraphEntryPoint())
      continue;

    NodeEC.insert(N->getID());
    N->visitAllDependencies([&](const SplitGraph::Node &Dep) {
      if (&Dep != N && Dep.isNonCopyable())
        NodeEC.unionSets(N->getID(), Dep.getID());
    });
  }

  for (auto I = NodeEC.begin(), E = NodeEC.end(); I != E; ++I) {
    if (!I->isLeader())
      continue;

    BitVector Cluster = SG.createNodesBitVector();
    for (auto MI = NodeEC.member_begin(I); MI != NodeEC.member_end(); ++MI) {
      const SplitGraph::Node &N = SG.getNode(*MI);
      if (N.isGraphEntryPoint())
        N.getDependencies(Cluster);
    }
    WorkList.emplace_back(std::move(Cluster));
  }

  // Calculate costs and other useful information.
  for (WorkListEntry &Entry : WorkList) {
    for (unsigned NodeID : Entry.Cluster.set_bits()) {
      const SplitGraph::Node &N = SG.getNode(NodeID);
      const CostType Cost = N.getIndividualCost();

      Entry.TotalCost += Cost;
      if (!N.isGraphEntryPoint()) {
        Entry.CostExcludingGraphEntryPoints += Cost;
        ++Entry.NumNonEntryNodes;
      }
    }
  }

  stable_sort(WorkList, [](const WorkListEntry &A, const WorkListEntry &B) {
    if (A.TotalCost != B.TotalCost)
      return A.TotalCost > B.TotalCost;

    if (A.CostExcludingGraphEntryPoints != B.CostExcludingGraphEntryPoints)
      return A.CostExcludingGraphEntryPoints > B.CostExcludingGraphEntryPoints;

    if (A.NumNonEntryNodes != B.NumNonEntryNodes)
      return A.NumNonEntryNodes > B.NumNonEntryNodes;

    return A.Cluster.count() > B.Cluster.count();
  });

  LLVM_DEBUG({
    dbgs() << "[recursive search] worklist:\n";
    for (const auto &[Idx, Entry] : enumerate(WorkList)) {
      dbgs() << "  - [" << Idx << "]: ";
      for (unsigned NodeID : Entry.Cluster.set_bits())
        dbgs() << NodeID << " ";
      dbgs() << "(total_cost:" << Entry.TotalCost
             << ", cost_excl_entries:" << Entry.CostExcludingGraphEntryPoints
             << ")\n";
    }
  });
}

void RecursiveSearchSplitting::pickPartition(unsigned Depth, unsigned Idx,
                                             SplitProposal SP) {
  while (Idx < WorkList.size()) {
    // Step 1: Determine candidate PIDs.
    //
    const WorkListEntry &Entry = WorkList[Idx];
    const BitVector &Cluster = Entry.Cluster;

    // Default option is to do load-balancing, AKA assign to the least
    // pressured partition.
    const unsigned CheapestPID = SP.findCheapestPartition();
    assert(CheapestPID != InvalidPID);

    // Explore assigning to the partition that contains the most dependencies
    // in common.
    const auto [MostSimilarPID, SimilarDepsCost] =
        findMostSimilarPartition(Entry, SP);

    // We can choose to explore only one path if we only have one valid path,
    // or if we reached maximum search depth and can no longer branch out.
    unsigned SinglePIDToTry = InvalidPID;
    if (MostSimilarPID == InvalidPID) // no similar PID found
      SinglePIDToTry = CheapestPID;
    else if (MostSimilarPID == CheapestPID) // both landed on the same PID
      SinglePIDToTry = CheapestPID;
    else if (Depth >= MaxDepth) {
      // We have to choose one path. Use a heuristic to guess which one will be
      // more appropriate.
      if (Entry.CostExcludingGraphEntryPoints > LargeClusterThreshold) {
        // Check if the amount of code in common makes it worth it.
        assert(SimilarDepsCost && Entry.CostExcludingGraphEntryPoints);
        const double Ratio = static_cast<double>(SimilarDepsCost) /
                             Entry.CostExcludingGraphEntryPoints;
        assert(Ratio >= 0.0 && Ratio <= 1.0);
        if (Ratio > LargeFnOverlapForMerge) {
          // For debug, just print "L", so we'll see "L3=P3" for instance, which
          // will mean we reached max depth and chose P3 based on this
          // heuristic.
          LLVM_DEBUG(dbgs() << 'L');
          SinglePIDToTry = MostSimilarPID;
        }
      } else
        SinglePIDToTry = CheapestPID;
    }

    // Step 2: Explore candidates.

    // When we only explore one possible path, and thus branch depth doesn't
    // increase, do not recurse, iterate instead.
    if (SinglePIDToTry != InvalidPID) {
      LLVM_DEBUG(dbgs() << Idx << "=P" << SinglePIDToTry << ' ');
      // Only one path to explore, don't clone SP, don't increase depth.
      SP.add(SinglePIDToTry, Cluster);
      ++Idx;
      continue;
    }

    assert(MostSimilarPID != InvalidPID);

    // We explore multiple paths: recurse at increased depth, then stop this
    // function.

    LLVM_DEBUG(dbgs() << '\n');

    // lb = load balancing = put in cheapest partition
    {
      SplitProposal BranchSP = SP;
      LLVM_DEBUG(dbgs().indent(Depth)
                 << " [lb] " << Idx << "=P" << CheapestPID << "? ");
      BranchSP.add(CheapestPID, Cluster);
      pickPartition(Depth + 1, Idx + 1, BranchSP);
    }

    // ms = most similar = put in partition with the most in common
    {
      SplitProposal BranchSP = SP;
      LLVM_DEBUG(dbgs().indent(Depth)
                 << " [ms] " << Idx << "=P" << MostSimilarPID << "? ");
      BranchSP.add(MostSimilarPID, Cluster);
      pickPartition(Depth + 1, Idx + 1, BranchSP);
    }

    return;
  }

  // Step 3: If we assigned all WorkList items, submit the proposal.

  assert(Idx == WorkList.size());
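  // Each worklist item branches at most two ways (load-balancing vs. most
  // similar), and branching stops past MaxDepth, so at most 2^(MaxDepth+1)
  // proposals can ever be submitted.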
  assert(NumProposalsSubmitted <= (2u << MaxDepth) &&
         "Search got out of bounds?");
  SP.setName("recursive_search (depth=" + std::to_string(Depth) + ") #" +
             std::to_string(NumProposalsSubmitted++));
  LLVM_DEBUG(dbgs() << '\n');
  SubmitProposal(SP);
}

std::pair<unsigned, CostType>
RecursiveSearchSplitting::findMostSimilarPartition(const WorkListEntry &Entry,
                                                   const SplitProposal &SP) {
  if (!Entry.NumNonEntryNodes)
    return {InvalidPID, 0};

  // We take the partition that is the most similar using Cost as a metric.
  // So we take the set of nodes in common, compute their aggregated cost, and
  // pick the partition with the highest cost in common.
  unsigned ChosenPID = InvalidPID;
  CostType ChosenCost = 0;
  for (unsigned PID = 0; PID < NumParts; ++PID) {
    BitVector BV = SP[PID];
    BV &= Entry.Cluster; // FIXME: & doesn't work between BVs?!

    if (BV.none())
      continue;

    const CostType Cost = SG.calculateCost(BV);

    if (ChosenPID == InvalidPID || ChosenCost < Cost ||
        (ChosenCost == Cost && PID > ChosenPID)) {
      ChosenPID = PID;
      ChosenCost = Cost;
    }
  }

  return {ChosenPID, ChosenCost};
}

//===----------------------------------------------------------------------===//
// DOTGraph Printing Support
//===----------------------------------------------------------------------===//

const SplitGraph::Node *mapEdgeToDst(const SplitGraph::Edge *E) {
  return E->Dst;
}

using SplitGraphEdgeDstIterator =
    mapped_iterator<SplitGraph::edges_iterator, decltype(&mapEdgeToDst)>;

} // namespace

template <> struct GraphTraits<SplitGraph> {
  using NodeRef = const SplitGraph::Node *;
  using nodes_iterator = SplitGraph::nodes_iterator;
  using ChildIteratorType = SplitGraphEdgeDstIterator;

  using EdgeRef = const SplitGraph::Edge *;
  using ChildEdgeIteratorType = SplitGraph::edges_iterator;

  static NodeRef getEntryNode(NodeRef N) { return N; }

  static ChildIteratorType child_begin(NodeRef Ref) {
    return {Ref->outgoing_edges().begin(), mapEdgeToDst};
  }
  static ChildIteratorType child_end(NodeRef Ref) {
    return {Ref->outgoing_edges().end(), mapEdgeToDst};
  }

  static nodes_iterator nodes_begin(const SplitGraph &G) {
    return G.nodes().begin();
  }
  static nodes_iterator nodes_end(const SplitGraph &G) {
    return G.nodes().end();
  }
};

template <> struct DOTGraphTraits<SplitGraph> : public DefaultDOTGraphTraits {
  DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}

  static std::string getGraphName(const SplitGraph &SG) {
    return SG.getModule().getName().str();
  }

  std::string getNodeLabel(const SplitGraph::Node *N, const SplitGraph &SG) {
    return N->getName().str();
  }

  static std::string getNodeDescription(const SplitGraph::Node *N,
                                        const SplitGraph &SG) {
    std::string Result;
    if (N->isEntryFunctionCC())
      Result += "entry-fn-cc ";
    if (N->isNonCopyable())
      Result += "non-copyable ";
    Result += "cost:" + std::to_string(N->getIndividualCost());
    return Result;
  }

  static std::string getNodeAttributes(const SplitGraph::Node *N,
                                       const SplitGraph &SG) {
    return N->hasAnyIncomingEdges() ? "" : "color=\"red\"";
  }

  static std::string getEdgeAttributes(const SplitGraph::Node *N,
                                       SplitGraphEdgeDstIterator EI,
                                       const SplitGraph &SG) {
    switch ((*EI.getCurrent())->Kind) {
    case SplitGraph::EdgeKind::DirectCall:
      return "";
    case SplitGraph::EdgeKind::IndirectCall:
      return "style=\"dashed\"";
    }
    llvm_unreachable("Unknown SplitGraph::EdgeKind enum");
  }
};

//===----------------------------------------------------------------------===//
// Driver
//===----------------------------------------------------------------------===//

namespace {

// If we didn't externalize GVs, then local GVs need to be conservatively
// imported into every module (including their initializers), and then cleaned
// up afterwards.
static bool needsConservativeImport(const GlobalValue *GV) {
  if (const auto *Var = dyn_cast<GlobalVariable>(GV))
    return Var->hasLocalLinkage();
  return isa<GlobalAlias>(GV);
}

/// Prints a summary of the partition \p N, represented by module \p M, to \p
/// OS.
static void printPartitionSummary(raw_ostream &OS, unsigned N, const Module &M,
                                  unsigned PartCost, unsigned ModuleCost) {
  OS << "*** Partition P" << N << " ***\n";

  for (const auto &Fn : M) {
    if (!Fn.isDeclaration())
      OS << " - [function] " << Fn.getName() << "\n";
  }

  for (const auto &GV : M.globals()) {
    if (GV.hasInitializer())
      OS << " - [global] " << GV.getName() << "\n";
  }

  OS << "Partition contains " << formatRatioOf(PartCost, ModuleCost)
     << "% of the source\n";
}

static void evaluateProposal(SplitProposal &Best, SplitProposal New) {
  SplitModuleTimer SMT("proposal_evaluation", "proposal ranking algorithm");

  LLVM_DEBUG({
    New.verifyCompleteness();
    if (DebugProposalSearch)
      New.print(dbgs());
  });

  const double CurBScore = Best.getBottleneckScore();
  const double CurCSScore = Best.getCodeSizeScore();
  const double NewBScore = New.getBottleneckScore();
  const double NewCSScore = New.getCodeSizeScore();

  // TODO: Improve this
  //    We can probably lower the precision of the comparison at first
  //    e.g. if we have
  //      - (Current): BScore: 0.489 CSCore 1.105
  //      - (New): BScore: 0.475 CSCore 1.305
  //    Currently we'd choose the new one because the bottleneck score is
  //    lower, but the new one duplicates more code. It may be worth it to
  //    discard the new proposal as the impact on build time is negligible.

  // Compare them
  bool IsBest = false;
  if (NewBScore < CurBScore)
    IsBest = true;
  else if (NewBScore == CurBScore)
    IsBest = (NewCSScore < CurCSScore); // Use code size as tie breaker.

  if (IsBest)
    Best = std::move(New);

  LLVM_DEBUG(if (DebugProposalSearch) {
    if (IsBest)
      dbgs() << "[search] new best proposal!\n";
    else
      dbgs() << "[search] discarding - not profitable\n";
  });
}

/// Trivial helper to create an identical copy of \p M.
static std::unique_ptr<Module> cloneAll(const Module &M) {
  ValueToValueMapTy VMap;
  return CloneModule(M, VMap, [&](const GlobalValue *GV) { return true; });
}

/// Writes \p SG as a DOTGraph to \ref ModuleDotCfgOutput if requested.
static void writeDOTGraph(const SplitGraph &SG) {
  if (ModuleDotCfgOutput.empty())
    return;

  std::error_code EC;
  raw_fd_ostream OS(ModuleDotCfgOutput, EC);
  if (EC) {
    errs() << "[" DEBUG_TYPE "]: cannot open '" << ModuleDotCfgOutput
           << "' - DOTGraph will not be printed\n";
    return;
  }
  WriteGraph(OS, SG, /*ShortName=*/false,
             /*Title=*/SG.getModule().getName());
}

static void splitAMDGPUModule(
    GetTTIFn GetTTI, Module &M, unsigned NumParts,
    function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) {
  CallGraph CG(M);

  // Externalize functions whose addresses are taken.
  //
  // This is needed because partitioning is purely based on calls, but sometimes
  // a kernel/function may just look at the address of another local function
  // and not do anything (no calls). After partitioning, that local function may
  // end up in a different module (so it's just a declaration in the module
  // where its address is taken), which emits an "undefined hidden symbol"
  // linker error.
  //
  // Additionally, it guides partitioning to not duplicate this function if it's
  // called directly at some point.
  //
  // TODO: Could we be smarter about this? This makes all functions whose
  // addresses are taken non-copyable. We should probably model this type of
  // constraint in the graph and use it to guide splitting, instead of
  // externalizing like this. Maybe non-copyable should really mean "keep one
  // visible copy, then internalize all other copies" for some functions?
  if (!NoExternalizeOnAddrTaken) {
    for (auto &Fn : M) {
      // TODO: Should aliases count? Probably not but they're so rare I'm not
      // sure it's worth fixing.
      if (Fn.hasLocalLinkage() && Fn.hasAddressTaken()) {
        LLVM_DEBUG(dbgs() << "[externalize] "; Fn.printAsOperand(dbgs());
                   dbgs() << " because its address is taken\n");
        externalize(Fn);
      }
    }
  }

  // Externalize local GVs, which avoids duplicating their initializers, which
  // in turn helps keep code size in check.
  if (!NoExternalizeGlobals) {
    for (auto &GV : M.globals()) {
      if (GV.hasLocalLinkage())
        LLVM_DEBUG(dbgs() << "[externalize] GV " << GV.getName() << '\n');
      externalize(GV);
    }
  }

  // Start by calculating the cost of every function in the module, as well as
  // the module's overall cost.
  FunctionsCostMap FnCosts;
  const CostType ModuleCost = calculateFunctionCosts(GetTTI, M, FnCosts);

  // Build the SplitGraph, which represents the module's functions and models
  // their dependencies accurately.
  SplitGraph SG(M, FnCosts, ModuleCost);
  SG.buildGraph(CG);

  if (SG.empty()) {
    LLVM_DEBUG(
        dbgs()
        << "[!] no nodes in graph, input is empty - no splitting possible\n");
    ModuleCallback(cloneAll(M));
    return;
  }

  LLVM_DEBUG({
    dbgs() << "[graph] nodes:\n";
    for (const SplitGraph::Node *N : SG.nodes()) {
      dbgs() << "  - [" << N->getID() << "]: " << N->getName() << " "
             << (N->isGraphEntryPoint() ? "(entry)" : "") << " "
             << (N->isNonCopyable() ? "(noncopyable)" : "") << "\n";
    }
  });

  writeDOTGraph(SG);

  LLVM_DEBUG(dbgs() << "[search] testing splitting strategies\n");

  std::optional<SplitProposal> Proposal;
  const auto EvaluateProposal = [&](SplitProposal SP) {
    SP.calculateScores();
    if (!Proposal)
      Proposal = std::move(SP);
    else
      evaluateProposal(*Proposal, std::move(SP));
  };

  // TODO: It would be very easy to create new strategies by just adding a base
  // class to RecursiveSearchSplitting and abstracting it away.
  RecursiveSearchSplitting(SG, NumParts, EvaluateProposal).run();
  LLVM_DEBUG(if (Proposal) dbgs() << "[search done] selected proposal: "
                                  << Proposal->getName() << "\n";);

  if (!Proposal) {
    LLVM_DEBUG(dbgs() << "[!] no proposal made, no splitting possible!\n");
    ModuleCallback(cloneAll(M));
    return;
  }

  LLVM_DEBUG(Proposal->print(dbgs()));

  std::optional<raw_fd_ostream> SummariesOS;
  if (!PartitionSummariesOutput.empty()) {
    std::error_code EC;
    SummariesOS.emplace(PartitionSummariesOutput, EC);
    if (EC)
      errs() << "[" DEBUG_TYPE "]: cannot open '" << PartitionSummariesOutput
             << "' - Partition summaries will not be printed\n";
  }

  for (unsigned PID = 0; PID < NumParts; ++PID) {
    SplitModuleTimer SMT2("modules_creation",
                          "creating modules for each partition");
    LLVM_DEBUG(dbgs() << "[split] creating new modules\n");

    DenseSet<const Function *> FnsInPart;
    for (unsigned NodeID : (*Proposal)[PID].set_bits())
      FnsInPart.insert(&SG.getNode(NodeID).getFunction());

    ValueToValueMapTy VMap;
    CostType PartCost = 0;
    std::unique_ptr<Module> MPart(
        CloneModule(M, VMap, [&](const GlobalValue *GV) {
          // Functions go in their assigned partition.
          if (const auto *Fn = dyn_cast<Function>(GV)) {
            if (FnsInPart.contains(Fn)) {
              PartCost += SG.getCost(*Fn);
              return true;
            }
            return false;
          }

          // Everything else is either conservatively imported into every
          // partition (see needsConservativeImport), or only kept in the
          // first partition.
          return needsConservativeImport(GV) || PID == 0;
        }));

    // FIXME: Aliases aren't seen often, and their handling isn't perfect so
    // bugs are possible.

    // Clean-up conservatively imported GVs without any users.
    for (auto &GV : make_early_inc_range(MPart->global_values())) {
      if (needsConservativeImport(&GV) && GV.use_empty())
        GV.eraseFromParent();
    }

    if (SummariesOS)
      printPartitionSummary(*SummariesOS, PID, *MPart, PartCost, ModuleCost);

    LLVM_DEBUG(
        printPartitionSummary(dbgs(), PID, *MPart, PartCost, ModuleCost));

    ModuleCallback(std::move(MPart));
  }
}
} // namespace

PreservedAnalyses AMDGPUSplitModulePass::run(Module &M,
                                             ModuleAnalysisManager &MAM) {
  SplitModuleTimer SMT(
      "total", "total pass runtime (incl. potentially waiting for lockfile)");

  FunctionAnalysisManager &FAM =
      MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  const auto TTIGetter = [&FAM](Function &F) -> const TargetTransformInfo & {
    return FAM.getResult<TargetIRAnalysis>(F);
  };

  bool Done = false;
#ifndef NDEBUG
  if (UseLockFile) {
    SmallString<128> LockFilePath;
    sys::path::system_temp_directory(/*ErasedOnReboot=*/true, LockFilePath);
    sys::path::append(LockFilePath, "amdgpu-split-module-debug");
    LLVM_DEBUG(dbgs() << DEBUG_TYPE " using lockfile '" << LockFilePath
                      << "'\n");

    while (true) {
      llvm::LockFileManager Locked(LockFilePath.str());
      switch (Locked) {
      case LockFileManager::LFS_Error:
        LLVM_DEBUG(
            dbgs() << "[amdgpu-split-module] unable to acquire lockfile, debug "
                      "output may be mangled by other processes\n");
        Locked.unsafeRemoveLockFile();
        break;
      case LockFileManager::LFS_Owned:
        break;
      case LockFileManager::LFS_Shared: {
        switch (Locked.waitForUnlock()) {
        case LockFileManager::Res_Success:
          break;
        case LockFileManager::Res_OwnerDied:
          continue; // try again to get the lock.
        case LockFileManager::Res_Timeout:
          LLVM_DEBUG(
              dbgs()
              << "[amdgpu-split-module] unable to acquire lockfile, debug "
                 "output may be mangled by other processes\n");
          Locked.unsafeRemoveLockFile();
          break; // give up
        }
        break;
      }
      }

      splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);
      Done = true;
      break;
    }
  }
#endif

  if (!Done)
    splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);

  // We can change linkage/visibilities in the input, so consider that nothing
  // is preserved just to be safe. This pass runs last anyway.
  return PreservedAnalyses::none();
}
} // namespace llvm