//===- AMDGPUSplitModule.cpp ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Implements a module splitting algorithm designed to support the
/// FullLTO --lto-partitions option for parallel codegen.
///
/// The role of this module splitting pass is the same as
/// lib/Transforms/Utils/SplitModule.cpp: load-balance the module's functions
/// across a set of N partitions to allow for parallel codegen.
///
/// The similarities mostly end here, as this pass achieves load-balancing in a
/// more elaborate fashion which is targeted towards AMDGPU modules. It can take
/// advantage of the structure of AMDGPU modules (which are mostly
/// self-contained) to allow for more efficient splitting without affecting
/// codegen negatively, or causing inaccurate resource usage analysis.
///
/// High-level pass overview:
///   - SplitGraph & associated classes
///      - Graph representation of the module and of the dependencies that
///      matter for splitting.
///   - RecursiveSearchSplitting
///     - Core splitting algorithm.
///   - SplitProposal
///     - Represents a suggested solution for splitting the input module. These
///     solutions can be scored to determine the best one when multiple
///     solutions are available.
///   - Driver/pass "run" function glues everything together.

#include "AMDGPUSplitModule.h"
#include "AMDGPUTargetMachine.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DOTGraphTraits.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Timer.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <cassert>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>

#ifndef NDEBUG
#include "llvm/Support/LockFileManager.h"
#endif

#define DEBUG_TYPE "amdgpu-split-module"

namespace llvm {
namespace {

static cl::opt<unsigned> MaxDepth(
    "amdgpu-module-splitting-max-depth",
    cl::desc(
        "maximum search depth. 0 forces a greedy approach. "
        "warning: the algorithm is up to O(2^N), where N is the max depth."),
    cl::init(8));

static cl::opt<float> LargeFnFactor(
    "amdgpu-module-splitting-large-threshold", cl::init(2.0f), cl::Hidden,
    cl::desc(
        "when max depth is reached and we can no longer branch out, this "
        "value determines if a function is worth merging into an already "
        "existing partition to reduce code duplication. This is a factor "
        "of the ideal partition size, e.g. 2.0 means we consider the "
        "function for merging if its cost (including its callees) is 2x the "
        "size of an ideal partition."));

static cl::opt<float> LargeFnOverlapForMerge(
    "amdgpu-module-splitting-merge-threshold", cl::init(0.7f), cl::Hidden,
    cl::desc("when a function is considered for merging into a partition that "
             "already contains some of its callees, do the merge if at least "
             "n% of the code it can reach is already present inside the "
             "partition; e.g. 0.7 means only merge >70%"));

static cl::opt<bool> NoExternalizeGlobals(
    "amdgpu-module-splitting-no-externalize-globals", cl::Hidden,
    cl::desc("disables externalization of global variables with local linkage; "
             "may cause globals to be duplicated which increases binary size"));

static cl::opt<bool> NoExternalizeOnAddrTaken(
    "amdgpu-module-splitting-no-externalize-address-taken", cl::Hidden,
    cl::desc(
        "disables externalization of functions whose addresses are taken"));

static cl::opt<std::string>
    ModuleDotCfgOutput("amdgpu-module-splitting-print-module-dotcfg",
                       cl::Hidden,
                       cl::desc("output file to write out the dotgraph "
                                "representation of the input module"));

static cl::opt<std::string> PartitionSummariesOutput(
    "amdgpu-module-splitting-print-partition-summaries", cl::Hidden,
    cl::desc("output file to write out a summary of "
             "the partitions created for each module"));

#ifndef NDEBUG
static cl::opt<bool>
    UseLockFile("amdgpu-module-splitting-serial-execution", cl::Hidden,
                cl::desc("use a lock file so only one process in the system "
                         "can run this pass at once. useful to avoid mangled "
                         "debug output in multithreaded environments."));

static cl::opt<bool>
    DebugProposalSearch("amdgpu-module-splitting-debug-proposal-search",
                        cl::Hidden,
                        cl::desc("print all proposals received and whether "
                                 "they were rejected or accepted"));
#endif

struct SplitModuleTimer : NamedRegionTimer {
  SplitModuleTimer(StringRef Name, StringRef Desc)
      : NamedRegionTimer(Name, Desc, DEBUG_TYPE, "AMDGPU Module Splitting",
                         TimePassesIsEnabled) {}
};

//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//

using CostType = InstructionCost::CostType;
using FunctionsCostMap = DenseMap<const Function *, CostType>;
using GetTTIFn = function_ref<const TargetTransformInfo &(Function &)>;
static constexpr unsigned InvalidPID = -1;

/// \param Num numerator
/// \param Dem denominator
/// \returns a printable object that prints (Num/Dem) as a percentage, using
/// "%0.2f" formatting.
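///
/// For example (illustrative values): formatRatioOf(1, 3) prints as "33.33",
/// and formatRatioOf(5, 0) prints as "500.00" thanks to the division-by-zero
/// guard below. Callers typically append a '%' to the output.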
static auto formatRatioOf(CostType Num, CostType Dem) {
  CostType DemOr1 = Dem ? Dem : 1;
  return format("%0.2f", (static_cast<double>(Num) / DemOr1) * 100);
}

/// Checks whether a given function is non-copyable.
///
/// Non-copyable functions cannot be cloned into multiple partitions, and only
/// one copy of the function can be present across all partitions.
///
/// External functions fall into this category. If we were to clone them, we
/// would end up with multiple symbol definitions and a very unhappy linker.
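///
/// For instance (a sketch): a kernel with external linkage must live in
/// exactly one partition, while an internal helper called by two kernels may
/// be cloned into both kernels' partitions.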
static bool isNonCopyable(const Function &F) {
  assert(AMDGPU::isEntryFunctionCC(F.getCallingConv())
             ? F.hasExternalLinkage()
             : true && "Kernel w/o external linkage?");
  return F.hasExternalLinkage() || !F.isDefinitionExact();
}

/// If \p GV has local linkage, make it external + hidden.
static void externalize(GlobalValue &GV) {
  if (GV.hasLocalLinkage()) {
    GV.setLinkage(GlobalValue::ExternalLinkage);
    GV.setVisibility(GlobalValue::HiddenVisibility);
  }

  // Unnamed entities must be named consistently between modules. setName will
  // give a distinct name to each such entity.
  if (!GV.hasName())
    GV.setName("__llvmsplit_unnamed");
}

/// Cost analysis function. Calculates the cost of each function in \p M
///
/// \param GetTTI Abstract getter for TargetTransformInfo.
/// \param M Module to analyze.
/// \param[out] CostMap Resulting Function -> Cost map.
/// \return The module's total cost.
static CostType calculateFunctionCosts(GetTTIFn GetTTI, Module &M,
                                       FunctionsCostMap &CostMap) {
  SplitModuleTimer SMT("calculateFunctionCosts", "cost analysis");

  LLVM_DEBUG(dbgs() << "[cost analysis] calculating function costs\n");
  CostType ModuleCost = 0;
  [[maybe_unused]] CostType KernelCost = 0;

  for (auto &Fn : M) {
    if (Fn.isDeclaration())
      continue;

    CostType FnCost = 0;
    const auto &TTI = GetTTI(Fn);
    for (const auto &BB : Fn) {
      for (const auto &I : BB) {
        auto Cost =
            TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
        assert(Cost != InstructionCost::getMax());
        // Assume expensive if we can't tell the cost of an instruction.
        CostType CostVal =
            Cost.getValue().value_or(TargetTransformInfo::TCC_Expensive);
        assert((FnCost + CostVal) >= FnCost && "Overflow!");
        FnCost += CostVal;
      }
    }

    assert(FnCost != 0);

    CostMap[&Fn] = FnCost;
    assert((ModuleCost + FnCost) >= ModuleCost && "Overflow!");
    ModuleCost += FnCost;

    if (AMDGPU::isEntryFunctionCC(Fn.getCallingConv()))
      KernelCost += FnCost;
  }

  if (CostMap.empty())
    return 0;

  assert(ModuleCost);
  LLVM_DEBUG({
    const CostType FnCost = ModuleCost - KernelCost;
    dbgs() << " - total module cost is " << ModuleCost << ". kernels cost "
           << KernelCost << " ("
           << format("%0.2f", (float(KernelCost) / ModuleCost) * 100)
           << "% of the module), functions cost " << FnCost << " ("
           << format("%0.2f", (float(FnCost) / ModuleCost) * 100)
           << "% of the module)\n";
  });

  return ModuleCost;
}

/// \return true if \p F can be indirectly called
static bool canBeIndirectlyCalled(const Function &F) {
  if (F.isDeclaration() || AMDGPU::isEntryFunctionCC(F.getCallingConv()))
    return false;
  return !F.hasLocalLinkage() ||
         F.hasAddressTaken(/*PutOffender=*/nullptr,
                           /*IgnoreCallbackUses=*/false,
                           /*IgnoreAssumeLikeCalls=*/true,
                           /*IgnoreLLVMUsed=*/true,
                           /*IgnoreARCAttachedCall=*/false,
                           /*IgnoreCastedDirectCall=*/true);
}

//===----------------------------------------------------------------------===//
// Graph-based Module Representation
//===----------------------------------------------------------------------===//

/// AMDGPUSplitModule's view of the source Module, as a graph of all components
/// that can be split into different modules.
///
/// The most trivial instance of this graph is just the CallGraph of the module,
/// but it is not guaranteed that the graph is strictly equal to the CG. It
/// currently always is but it's designed in a way that would eventually allow
/// us to create abstract nodes, or nodes for different entities such as global
/// variables or any other meaningful constraint we must consider.
///
/// The graph is only mutable by this class, and is generally not modified
/// after \ref SplitGraph::buildGraph runs. No consumers of the graph can
/// mutate it.
class SplitGraph {
public:
  class Node;

  enum class EdgeKind : uint8_t {
    /// The nodes are related through a direct call. This is a "strong" edge as
    /// it means the Src will directly reference the Dst.
    DirectCall,
    /// The nodes are related through an indirect call.
    /// This is a "weaker" edge and is only considered when traversing the graph
    /// starting from a kernel. We need this edge for resource usage analysis.
    ///
    /// The reason why we have this edge in the first place is due to how
    /// AMDGPUResourceUsageAnalysis works. In the presence of an indirect call,
    /// the resource usage of the kernel containing the indirect call is the
    /// max resource usage of all functions that can be indirectly called.
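    ///
    /// For example (hypothetical functions): if a kernel K contains an
    /// indirect call and A and B can be indirectly called, K gets the
    /// IndirectCall edges K -> A and K -> B so that both candidates follow K
    /// into its partition.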
    IndirectCall,
  };

  /// An edge between two nodes. Edges are directional, and tagged with a
  /// "kind".
  struct Edge {
    Edge(Node *Src, Node *Dst, EdgeKind Kind)
        : Src(Src), Dst(Dst), Kind(Kind) {}

    Node *Src; ///< Source
    Node *Dst; ///< Destination
    EdgeKind Kind;
  };

  using EdgesVec = SmallVector<const Edge *, 0>;
  using edges_iterator = EdgesVec::const_iterator;
  using nodes_iterator = const Node *const *;

  SplitGraph(const Module &M, const FunctionsCostMap &CostMap,
             CostType ModuleCost)
      : M(M), CostMap(CostMap), ModuleCost(ModuleCost) {}

  void buildGraph(CallGraph &CG);

#ifndef NDEBUG
  bool verifyGraph() const;
#endif

  bool empty() const { return Nodes.empty(); }
  const iterator_range<nodes_iterator> nodes() const {
    return {Nodes.begin(), Nodes.end()};
  }
  const Node &getNode(unsigned ID) const { return *Nodes[ID]; }

  unsigned getNumNodes() const { return Nodes.size(); }
  BitVector createNodesBitVector() const { return BitVector(Nodes.size()); }

  const Module &getModule() const { return M; }

  CostType getModuleCost() const { return ModuleCost; }
  CostType getCost(const Function &F) const { return CostMap.at(&F); }

  /// \returns the aggregated cost of all nodes in \p BV (bits set to 1 = node
  /// IDs).
  CostType calculateCost(const BitVector &BV) const;

private:
  /// Retrieves the node for \p GV in \p Cache, or creates a new node for it and
  /// updates \p Cache.
  Node &getNode(DenseMap<const GlobalValue *, Node *> &Cache,
                const GlobalValue &GV);

  // Create a new edge between two nodes and add it to both nodes.
  const Edge &createEdge(Node &Src, Node &Dst, EdgeKind EK);

  const Module &M;
  const FunctionsCostMap &CostMap;
  CostType ModuleCost;

  // Final list of nodes with stable ordering.
  SmallVector<Node *> Nodes;

  SpecificBumpPtrAllocator<Node> NodesPool;

  // Edges are trivially destructible objects, so as a small optimization we
  // use a BumpPtrAllocator which avoids destructor calls but also makes
  // allocation faster.
  static_assert(
      std::is_trivially_destructible_v<Edge>,
      "Edge must be trivially destructible to use the BumpPtrAllocator");
  BumpPtrAllocator EdgesPool;
};

/// Nodes in the SplitGraph contain both incoming and outgoing edges.
/// Incoming edges have this node as their Dst, and Outgoing ones have this node
/// as their Src.
///
/// Edge objects are shared by both nodes in Src/Dst. They provide immediate
/// feedback on how two nodes are related, and in which direction they are
/// related, which is valuable information to make splitting decisions.
///
/// Nodes are fundamentally abstract, and any consumers of the graph should
/// treat them as such. While a node will be a function most of the time, we
/// could also create nodes for any other reason. In the future, we could have
/// single nodes for multiple functions, or nodes for GVs, etc.
class SplitGraph::Node {
  friend class SplitGraph;

public:
  Node(unsigned ID, const GlobalValue &GV, CostType IndividualCost,
       bool IsNonCopyable)
      : ID(ID), GV(GV), IndividualCost(IndividualCost),
        IsNonCopyable(IsNonCopyable), IsEntryFnCC(false), IsGraphEntry(false) {
    if (auto *Fn = dyn_cast<Function>(&GV))
      IsEntryFnCC = AMDGPU::isEntryFunctionCC(Fn->getCallingConv());
  }

  /// A 0-indexed ID for the node. The maximum ID (exclusive) is the number of
  /// nodes in the graph. This ID can be used as an index in a BitVector.
  unsigned getID() const { return ID; }

  const Function &getFunction() const { return cast<Function>(GV); }

  /// \returns the cost to import this component into a given module, not
  /// accounting for any dependencies that may need to be imported as well.
  CostType getIndividualCost() const { return IndividualCost; }

  bool isNonCopyable() const { return IsNonCopyable; }
  bool isEntryFunctionCC() const { return IsEntryFnCC; }

  /// \returns whether this is an entry point in the graph. Entry points are
  /// defined as follows: if you take all entry points in the graph, and iterate
  /// their dependencies, you are guaranteed to visit all nodes in the graph at
  /// least once.
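  ///
  /// For instance, kernels are always graph entry points, and a non-kernel
  /// function that no kernel can reach is marked as an entry point too (see
  /// \ref SplitGraph::buildGraph) so it is not lost during splitting.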
  bool isGraphEntryPoint() const { return IsGraphEntry; }

  StringRef getName() const { return GV.getName(); }

  bool hasAnyIncomingEdges() const { return IncomingEdges.size(); }
  bool hasAnyIncomingEdgesOfKind(EdgeKind EK) const {
    return any_of(IncomingEdges, [&](const auto *E) { return E->Kind == EK; });
  }

  bool hasAnyOutgoingEdges() const { return OutgoingEdges.size(); }
  bool hasAnyOutgoingEdgesOfKind(EdgeKind EK) const {
    return any_of(OutgoingEdges, [&](const auto *E) { return E->Kind == EK; });
  }

  iterator_range<edges_iterator> incoming_edges() const {
    return IncomingEdges;
  }

  iterator_range<edges_iterator> outgoing_edges() const {
    return OutgoingEdges;
  }

  bool shouldFollowIndirectCalls() const { return isEntryFunctionCC(); }

  /// Visit all children of this node in a recursive fashion. Also visits Self.
  /// If \ref shouldFollowIndirectCalls returns false, then this only follows
  /// DirectCall edges.
  ///
  /// \param Visitor Visitor Function.
  void visitAllDependencies(std::function<void(const Node &)> Visitor) const;

  /// Adds the dependencies of this node in \p BV by setting the bit
  /// corresponding to each node.
  ///
  /// Implemented using \ref visitAllDependencies, hence it follows the same
  /// rules regarding dependencies traversal.
  ///
  /// \param[out] BV The bitvector where the bits should be set.
  void getDependencies(BitVector &BV) const {
    visitAllDependencies([&](const Node &N) { BV.set(N.getID()); });
  }

private:
  void markAsGraphEntry() { IsGraphEntry = true; }

  unsigned ID;
  const GlobalValue &GV;
  CostType IndividualCost;
  bool IsNonCopyable : 1;
  bool IsEntryFnCC : 1;
  bool IsGraphEntry : 1;

  // TODO: Use a single sorted vector (with all incoming/outgoing edges grouped
  // together)
  EdgesVec IncomingEdges;
  EdgesVec OutgoingEdges;
};

void SplitGraph::Node::visitAllDependencies(
    std::function<void(const Node &)> Visitor) const {
  const bool FollowIndirect = shouldFollowIndirectCalls();
  // FIXME: If this can access SplitGraph in the future, use a BitVector
  // instead.
  DenseSet<const Node *> Seen;
  SmallVector<const Node *, 8> WorkList({this});
  while (!WorkList.empty()) {
    const Node *CurN = WorkList.pop_back_val();
    if (auto [It, Inserted] = Seen.insert(CurN); !Inserted)
      continue;

    Visitor(*CurN);

    for (const Edge *E : CurN->outgoing_edges()) {
      if (!FollowIndirect && E->Kind == EdgeKind::IndirectCall)
        continue;
      WorkList.push_back(E->Dst);
    }
  }
}

/// Checks if \p I has MD_callees metadata and, if it does, parses it and puts
/// the functions in \p Callees.
///
/// \returns true if there was metadata and it was parsed correctly. false if
/// there was no MD or if it contained unknown entries and parsing failed.
/// If this returns false, \p Callees will contain incomplete information
/// and must not be used.
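///
/// For illustration, IR of the following shape (hypothetical names) parses
/// successfully and adds \c @a and \c @b to \p Callees:
/// \code
///   call void %fptr(), !callees !0
///   ...
///   !0 = !{ptr @a, ptr @b}
/// \endcode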
static bool handleCalleesMD(const Instruction &I,
                            SetVector<Function *> &Callees) {
  auto *MD = I.getMetadata(LLVMContext::MD_callees);
  if (!MD)
    return false;

  for (const auto &Op : MD->operands()) {
    Function *Callee = mdconst::extract_or_null<Function>(Op);
    if (!Callee)
      return false;
    Callees.insert(Callee);
  }

  return true;
}

void SplitGraph::buildGraph(CallGraph &CG) {
  SplitModuleTimer SMT("buildGraph", "graph construction");
  LLVM_DEBUG(
      dbgs()
      << "[build graph] constructing graph representation of the input\n");

  // FIXME(?): Is the callgraph really worth using if we have to iterate the
  // function again whenever it fails to give us enough information?

  // We build the graph by just iterating all functions in the module and
  // working on their direct callees. At the end, all nodes should be linked
  // together as expected.
  DenseMap<const GlobalValue *, Node *> Cache;
  SmallVector<const Function *> FnsWithIndirectCalls, IndirectlyCallableFns;
  for (const Function &Fn : M) {
    if (Fn.isDeclaration())
      continue;

    // Look at direct callees and create the necessary edges in the graph.
    SetVector<const Function *> DirectCallees;
    bool CallsExternal = false;
    for (auto &CGEntry : *CG[&Fn]) {
      auto *CGNode = CGEntry.second;
      if (auto *Callee = CGNode->getFunction()) {
        if (!Callee->isDeclaration())
          DirectCallees.insert(Callee);
      } else if (CGNode == CG.getCallsExternalNode())
        CallsExternal = true;
    }

    // Keep track of this function if it contains an indirect call and/or if it
    // can be indirectly called.
    if (CallsExternal) {
      LLVM_DEBUG(dbgs() << "  [!] callgraph is incomplete for ";
                 Fn.printAsOperand(dbgs());
                 dbgs() << " - analyzing function\n");

      SetVector<Function *> KnownCallees;
      bool HasUnknownIndirectCall = false;
      for (const auto &Inst : instructions(Fn)) {
        // look at all calls without a direct callee.
        const auto *CB = dyn_cast<CallBase>(&Inst);
        if (!CB || CB->getCalledFunction())
          continue;

        // Inline assembly is ignored: we assume it cannot result in an
        // indirect call.
        if (CB->isInlineAsm()) {
          LLVM_DEBUG(dbgs() << "    found inline assembly\n");
          continue;
        }

        if (handleCalleesMD(Inst, KnownCallees))
          continue;
        // If we failed to parse any !callees MD, or some was missing,
        // the entire KnownCallees list is now unreliable.
        KnownCallees.clear();

        // Everything else is handled conservatively. If we fall into the
        // conservative case don't bother analyzing further.
        HasUnknownIndirectCall = true;
        break;
      }

      if (HasUnknownIndirectCall) {
        LLVM_DEBUG(dbgs() << "    indirect call found\n");
        FnsWithIndirectCalls.push_back(&Fn);
      } else if (!KnownCallees.empty())
        DirectCallees.insert(KnownCallees.begin(), KnownCallees.end());
    }

    Node &N = getNode(Cache, Fn);
    for (const auto *Callee : DirectCallees)
      createEdge(N, getNode(Cache, *Callee), EdgeKind::DirectCall);

    if (canBeIndirectlyCalled(Fn))
      IndirectlyCallableFns.push_back(&Fn);
  }

  // Post-process functions with indirect calls.
  for (const Function *Fn : FnsWithIndirectCalls) {
    for (const Function *Candidate : IndirectlyCallableFns) {
      Node &Src = getNode(Cache, *Fn);
      Node &Dst = getNode(Cache, *Candidate);
      createEdge(Src, Dst, EdgeKind::IndirectCall);
    }
  }

  // Now, find all entry points.
  SmallVector<Node *, 16> CandidateEntryPoints;
  BitVector NodesReachableByKernels = createNodesBitVector();
  for (Node *N : Nodes) {
    // Functions with an Entry CC are always graph entry points too.
    if (N->isEntryFunctionCC()) {
      N->markAsGraphEntry();
      N->getDependencies(NodesReachableByKernels);
    } else if (!N->hasAnyIncomingEdgesOfKind(EdgeKind::DirectCall))
      CandidateEntryPoints.push_back(N);
  }

  for (Node *N : CandidateEntryPoints) {
    // This can be another entry point if it's not reachable by a kernel
    // TODO: We could sort all of the possible new entries in a stable order
    // (e.g. by cost), then consume them one by one until
    // NodesReachableByKernels is all 1s. It'd allow us to avoid
    // considering some nodes as non-entries in some specific cases.
    if (!NodesReachableByKernels.test(N->getID()))
      N->markAsGraphEntry();
  }

#ifndef NDEBUG
  assert(verifyGraph());
#endif
}

#ifndef NDEBUG
bool SplitGraph::verifyGraph() const {
  unsigned ExpectedID = 0;
  // Exceptionally using a set here in case IDs are messed up.
  DenseSet<const Node *> SeenNodes;
  DenseSet<const Function *> SeenFunctionNodes;
  for (const Node *N : Nodes) {
    if (N->getID() != (ExpectedID++)) {
      errs() << "Node IDs are incorrect!\n";
      return false;
    }

    if (!SeenNodes.insert(N).second) {
      errs() << "Node seen more than once!\n";
      return false;
    }

    if (&getNode(N->getID()) != N) {
      errs() << "getNode doesn't return the right node\n";
      return false;
    }

    for (const Edge *E : N->IncomingEdges) {
      if (!E->Src || !E->Dst || (E->Dst != N) ||
          (find(E->Src->OutgoingEdges, E) == E->Src->OutgoingEdges.end())) {
        errs() << "ill-formed incoming edges\n";
        return false;
      }
    }

    for (const Edge *E : N->OutgoingEdges) {
      if (!E->Src || !E->Dst || (E->Src != N) ||
          (find(E->Dst->IncomingEdges, E) == E->Dst->IncomingEdges.end())) {
        errs() << "ill-formed outgoing edges\n";
        return false;
      }
    }

    const Function &Fn = N->getFunction();
    if (AMDGPU::isEntryFunctionCC(Fn.getCallingConv())) {
      if (N->hasAnyIncomingEdges()) {
        errs() << "Kernels cannot have incoming edges\n";
        return false;
      }
    }

    if (Fn.isDeclaration()) {
      errs() << "declarations shouldn't have nodes!\n";
      return false;
    }

    auto [It, Inserted] = SeenFunctionNodes.insert(&Fn);
    if (!Inserted) {
      errs() << "one function has multiple nodes!\n";
      return false;
    }
  }

  if (ExpectedID != Nodes.size()) {
    errs() << "Node IDs out of sync!\n";
    return false;
  }

  if (createNodesBitVector().size() != getNumNodes()) {
    errs() << "nodes bit vector doesn't have the right size!\n";
    return false;
  }

  // Check we respect the promise of Node::isGraphEntryPoint.
  BitVector BV = createNodesBitVector();
  for (const Node *N : nodes()) {
    if (N->isGraphEntryPoint())
      N->getDependencies(BV);
  }

  // Ensure each function in the module has an associated node.
  for (const auto &Fn : M) {
    if (!Fn.isDeclaration()) {
      if (!SeenFunctionNodes.contains(&Fn)) {
        errs() << "Fn has no associated node in the graph!\n";
        return false;
      }
    }
  }

  if (!BV.all()) {
    errs() << "not all nodes are reachable through the graph's entry points!\n";
    return false;
  }

  return true;
}
#endif

CostType SplitGraph::calculateCost(const BitVector &BV) const {
  CostType Cost = 0;
  for (unsigned NodeID : BV.set_bits())
    Cost += getNode(NodeID).getIndividualCost();
  return Cost;
}

SplitGraph::Node &
SplitGraph::getNode(DenseMap<const GlobalValue *, Node *> &Cache,
                    const GlobalValue &GV) {
  auto &N = Cache[&GV];
  if (N)
    return *N;

  CostType Cost = 0;
  bool NonCopyable = false;
  if (const Function *Fn = dyn_cast<Function>(&GV)) {
    NonCopyable = isNonCopyable(*Fn);
    Cost = CostMap.at(Fn);
  }
  N = new (NodesPool.Allocate()) Node(Nodes.size(), GV, Cost, NonCopyable);
  Nodes.push_back(N);
  assert(&getNode(N->getID()) == N);
  return *N;
}

const SplitGraph::Edge &SplitGraph::createEdge(Node &Src, Node &Dst,
                                               EdgeKind EK) {
  const Edge *E = new (EdgesPool.Allocate<Edge>(1)) Edge(&Src, &Dst, EK);
  Src.OutgoingEdges.push_back(E);
  Dst.IncomingEdges.push_back(E);
  return *E;
}

//===----------------------------------------------------------------------===//
// Split Proposals
//===----------------------------------------------------------------------===//

/// Represents a module splitting proposal.
///
/// Proposals are made of N BitVectors, one for each partition, where each bit
/// set indicates that the node is present and should be copied inside that
/// partition.
///
/// Proposals have several metrics attached so they can be compared/sorted,
/// which allows the driver to try multiple strategies, gather the resulting
/// proposals, and choose the best one out of them.
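///
/// For illustration (hypothetical numbers): with 2 partitions over 4 nodes, a
/// proposal could be P0 = {1,0,1,1} and P1 = {0,1,1,0}. Node 2's bit is set
/// in both BitVectors, meaning it gets duplicated into both partitions.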
class SplitProposal {
public:
  SplitProposal(const SplitGraph &SG, unsigned MaxPartitions) : SG(&SG) {
    Partitions.resize(MaxPartitions, {0, SG.createNodesBitVector()});
  }

  void setName(StringRef NewName) { Name = NewName; }
  StringRef getName() const { return Name; }

  const BitVector &operator[](unsigned PID) const {
    return Partitions[PID].second;
  }

  void add(unsigned PID, const BitVector &BV) {
    Partitions[PID].second |= BV;
    updateScore(PID);
  }

  void print(raw_ostream &OS) const;
  LLVM_DUMP_METHOD void dump() const { print(dbgs()); }

  // Find the cheapest partition (lowest cost). In case of ties, always returns
  // the highest partition number.
  unsigned findCheapestPartition() const;

  /// Calculate the CodeSize and Bottleneck scores.
  void calculateScores();

#ifndef NDEBUG
  void verifyCompleteness() const;
#endif

  /// Only available after \ref calculateScores is called.
  ///
  /// A number indicating how much code duplication this proposal creates,
  /// expressed as the aggregated cost of all partitions over the cost of the
  /// input module. e.g. 1.2 means the partitions are, together, roughly 20%
  /// bigger than the input module because some functions were duplicated
  /// across partitions.
  ///
  /// Value is always rounded up to 2 decimal places.
  ///
  /// A perfect score would be 1.0; the higher the score, the more code is
  /// duplicated.
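  ///
  /// For example (hypothetical costs): if the input module's cost is 1000 and
  /// the partitions' costs sum to 1200, the score is 1.2, i.e. roughly 20% of
  /// extra code size caused by duplication.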
  double getCodeSizeScore() const { return CodeSizeScore; }

  /// Only available after \ref calculateScores is called.
  ///
  /// A number between [0, 1] which indicates how big of a bottleneck is
  /// expected from the largest partition.
  ///
  /// A score of 1.0 means the biggest partition is as big as the source module,
  /// so build time will be equal to or greater than the build time of the
  /// initial input.
  ///
  /// Value is always rounded up to 2 decimal places.
  ///
  /// This is one of the metrics used to estimate this proposal's build time.
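  ///
  /// For example (hypothetical costs): with 4 partitions, a perfectly balanced
  /// proposal scores 0.25, while a score of 0.5 means the largest partition
  /// alone costs half as much as the whole input module.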
  double getBottleneckScore() const { return BottleneckScore; }

private:
  void updateScore(unsigned PID) {
    assert(SG);
    for (auto &[PCost, Nodes] : Partitions) {
      TotalCost -= PCost;
      PCost = SG->calculateCost(Nodes);
      TotalCost += PCost;
    }
  }

  /// \see getCodeSizeScore
  double CodeSizeScore = 0.0;
  /// \see getBottleneckScore
  double BottleneckScore = 0.0;
  /// Aggregated cost of all partitions
  CostType TotalCost = 0;

  const SplitGraph *SG = nullptr;
  std::string Name;

  std::vector<std::pair<CostType, BitVector>> Partitions;
};

void SplitProposal::print(raw_ostream &OS) const {
  assert(SG);

  OS << "[proposal] " << Name << ", total cost:" << TotalCost
     << ", code size score:" << format("%0.3f", CodeSizeScore)
     << ", bottleneck score:" << format("%0.3f", BottleneckScore) << '\n';
  for (const auto &[PID, Part] : enumerate(Partitions)) {
    const auto &[Cost, NodeIDs] = Part;
    OS << "  - P" << PID << " nodes:" << NodeIDs.count() << " cost: " << Cost
       << '|' << formatRatioOf(Cost, SG->getModuleCost()) << "%\n";
  }
}

unsigned SplitProposal::findCheapestPartition() const {
  assert(!Partitions.empty());
  CostType CurCost = std::numeric_limits<CostType>::max();
  unsigned CurPID = InvalidPID;
  for (const auto &[Idx, Part] : enumerate(Partitions)) {
    if (Part.first <= CurCost) {
      CurPID = Idx;
      CurCost = Part.first;
    }
  }
  assert(CurPID != InvalidPID);
  return CurPID;
}

void SplitProposal::calculateScores() {
  if (Partitions.empty())
    return;

  assert(SG);
  CostType LargestPCost = 0;
  for (auto &[PCost, Nodes] : Partitions) {
    if (PCost > LargestPCost)
      LargestPCost = PCost;
  }

  CostType ModuleCost = SG->getModuleCost();
  CodeSizeScore = double(TotalCost) / ModuleCost;
  assert(CodeSizeScore >= 0.0);

  BottleneckScore = double(LargestPCost) / ModuleCost;

  CodeSizeScore = std::ceil(CodeSizeScore * 100.0) / 100.0;
  BottleneckScore = std::ceil(BottleneckScore * 100.0) / 100.0;
}

#ifndef NDEBUG
void SplitProposal::verifyCompleteness() const {
  if (Partitions.empty())
    return;

  BitVector Result = Partitions[0].second;
  for (const auto &P : drop_begin(Partitions))
    Result |= P.second;
  assert(Result.all() && "some nodes are missing from this proposal!");
}
#endif

//===-- RecursiveSearchStrategy -------------------------------------------===//

/// Partitioning algorithm.
///
/// This is a recursive search algorithm that can explore multiple possibilities.
///
/// When a cluster of nodes can go into more than one partition, and we haven't
/// reached maximum search depth, we recurse and explore both options and their
/// consequences. Both branches will yield a proposal, and the driver will grade
/// both and choose the best one.
///
/// If max depth is reached, we will use some heuristics to make a choice. Most
/// of the time we will just use the least-pressured (cheapest) partition, but
/// if a cluster is particularly big and there is a good amount of overlap with
/// an existing partition, we will choose that partition instead.
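///
/// A sketch of the search space: every time a cluster could reasonably go
/// into two different partitions, the search forks into two branches, so up
/// to roughly 2^MaxDepth proposals can be produced for a single module.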
class RecursiveSearchSplitting {
public:
  using SubmitProposalFn = function_ref<void(SplitProposal)>;

  RecursiveSearchSplitting(const SplitGraph &SG, unsigned NumParts,
                           SubmitProposalFn SubmitProposal);

  void run();

private:
  struct WorkListEntry {
    WorkListEntry(const BitVector &BV) : Cluster(BV) {}

    unsigned NumNonEntryNodes = 0;
    CostType TotalCost = 0;
    CostType CostExcludingGraphEntryPoints = 0;
    BitVector Cluster;
  };

  /// Collects all graph entry points' clusters and sorts them so the most
  /// expensive clusters are processed first. This will merge clusters together
  /// if they share a non-copyable dependency.
  void setupWorkList();

  /// Recursive function that assigns the worklist item at \p Idx into a
  /// partition of \p SP.
  ///
  /// \p Depth is the current search depth. When this value is equal to
  /// \ref MaxDepth, we can no longer recurse.
  ///
  /// This function only recurses if there is more than one possible assignment,
  /// otherwise it is iterative to avoid creating a call stack that is as big as
  /// \ref WorkList.
  void pickPartition(unsigned Depth, unsigned Idx, SplitProposal SP);

  /// \return A pair: first element is the PID of the partition that has the
  /// most similarities with \p Entry, or \ref InvalidPID if no partition was
  /// found with at least one element in common. The second element is the
  /// aggregated cost of all dependencies in common between \p Entry and that
  /// partition.
  std::pair<unsigned, CostType>
  findMostSimilarPartition(const WorkListEntry &Entry, const SplitProposal &SP);

  const SplitGraph &SG;
  unsigned NumParts;
  SubmitProposalFn SubmitProposal;

  // A Cluster is considered large when its cost, excluding entry points,
  // exceeds this value.
  CostType LargeClusterThreshold = 0;
  unsigned NumProposalsSubmitted = 0;
  SmallVector<WorkListEntry> WorkList;
};

RecursiveSearchSplitting::RecursiveSearchSplitting(
    const SplitGraph &SG, unsigned NumParts, SubmitProposalFn SubmitProposal)
    : SG(SG), NumParts(NumParts), SubmitProposal(SubmitProposal) {
  // Arbitrary max value as a safeguard. Anything above 10 will already be
  // slow; this is just a cap to prevent extreme resource exhaustion or
  // unbounded run time.
  if (MaxDepth > 16)
    report_fatal_error("[amdgpu-split-module] search depth of " +
                       Twine(MaxDepth) + " is too high!");
  LargeClusterThreshold =
      (LargeFnFactor != 0.0)
          ? CostType(((SG.getModuleCost() / NumParts) * LargeFnFactor))
          : std::numeric_limits<CostType>::max();
  LLVM_DEBUG(dbgs() << "[recursive search] large cluster threshold set at "
                    << LargeClusterThreshold << "\n");
}

void RecursiveSearchSplitting::run() {
  {
    SplitModuleTimer SMT("recursive_search_prepare", "preparing worklist");
    setupWorkList();
  }

  {
    SplitModuleTimer SMT("recursive_search_pick", "partitioning");
    SplitProposal SP(SG, NumParts);
    pickPartition(/*BranchDepth=*/0, /*Idx=*/0, SP);
  }
}

void RecursiveSearchSplitting::setupWorkList() {
  // e.g. if A and B are two worklist items, and they both call a non-copyable
  // dependency C, this does:
  //    A=C
  //    B=C
  // => NodeEC will create a single group (A, B, C) and we create a new
  // WorkList entry for that group.

  EquivalenceClasses<unsigned> NodeEC;
  for (const SplitGraph::Node *N : SG.nodes()) {
    if (!N->isGraphEntryPoint())
      continue;

    NodeEC.insert(N->getID());
    N->visitAllDependencies([&](const SplitGraph::Node &Dep) {
      if (&Dep != N && Dep.isNonCopyable())
        NodeEC.unionSets(N->getID(), Dep.getID());
    });
  }

  for (auto I = NodeEC.begin(), E = NodeEC.end(); I != E; ++I) {
    if (!I->isLeader())
      continue;

    BitVector Cluster = SG.createNodesBitVector();
    for (auto MI = NodeEC.member_begin(I); MI != NodeEC.member_end(); ++MI) {
      const SplitGraph::Node &N = SG.getNode(*MI);
      if (N.isGraphEntryPoint())
        N.getDependencies(Cluster);
    }
    WorkList.emplace_back(std::move(Cluster));
  }

  // Calculate costs and other useful information.
  for (WorkListEntry &Entry : WorkList) {
    for (unsigned NodeID : Entry.Cluster.set_bits()) {
      const SplitGraph::Node &N = SG.getNode(NodeID);
      const CostType Cost = N.getIndividualCost();

      Entry.TotalCost += Cost;
      if (!N.isGraphEntryPoint()) {
        Entry.CostExcludingGraphEntryPoints += Cost;
        ++Entry.NumNonEntryNodes;
      }
    }
  }

  stable_sort(WorkList, [](const WorkListEntry &A, const WorkListEntry &B) {
    if (A.TotalCost != B.TotalCost)
      return A.TotalCost > B.TotalCost;

    if (A.CostExcludingGraphEntryPoints != B.CostExcludingGraphEntryPoints)
      return A.CostExcludingGraphEntryPoints > B.CostExcludingGraphEntryPoints;

    if (A.NumNonEntryNodes != B.NumNonEntryNodes)
      return A.NumNonEntryNodes > B.NumNonEntryNodes;

    return A.Cluster.count() > B.Cluster.count();
  });

  LLVM_DEBUG({
    dbgs() << "[recursive search] worklist:\n";
    for (const auto &[Idx, Entry] : enumerate(WorkList)) {
      dbgs() << "  - [" << Idx << "]: ";
      for (unsigned NodeID : Entry.Cluster.set_bits())
        dbgs() << NodeID << " ";
      dbgs() << "(total_cost:" << Entry.TotalCost
             << ", cost_excl_entries:" << Entry.CostExcludingGraphEntryPoints
             << ")\n";
    }
  });
}

void RecursiveSearchSplitting::pickPartition(unsigned Depth, unsigned Idx,
                                             SplitProposal SP) {
  while (Idx < WorkList.size()) {
    // Step 1: Determine candidate PIDs.
    //
    const WorkListEntry &Entry = WorkList[Idx];
    const BitVector &Cluster = Entry.Cluster;

    // Default option is to do load-balancing, AKA assign to least pressured
    // partition.
    const unsigned CheapestPID = SP.findCheapestPartition();
    assert(CheapestPID != InvalidPID);

    // Explore assigning to the kernel that contains the most dependencies in
    // common.
    const auto [MostSimilarPID, SimilarDepsCost] =
        findMostSimilarPartition(Entry, SP);

    // We can choose to explore only one path if we only have one valid path, or
    // if we reached maximum search depth and can no longer branch out.
    unsigned SinglePIDToTry = InvalidPID;
    if (MostSimilarPID == InvalidPID) // no similar PID found
      SinglePIDToTry = CheapestPID;
    else if (MostSimilarPID == CheapestPID) // both landed on the same PID
      SinglePIDToTry = CheapestPID;
    else if (Depth >= MaxDepth) {
      // We have to choose one path. Use a heuristic to guess which one will be
      // more appropriate.
      if (Entry.CostExcludingGraphEntryPoints > LargeClusterThreshold) {
        // Check if the amount of code in common makes it worth it.
        assert(SimilarDepsCost && Entry.CostExcludingGraphEntryPoints);
        const double Ratio = static_cast<double>(SimilarDepsCost) /
                             Entry.CostExcludingGraphEntryPoints;
        assert(Ratio >= 0.0 && Ratio <= 1.0);
        if (LargeFnOverlapForMerge > Ratio) {
          // For debug, just print "L", so we'll see "L3=P3" for instance, which
          // will mean we reached max depth and chose P3 based on this
          // heuristic.
          LLVM_DEBUG(dbgs() << 'L');
          SinglePIDToTry = MostSimilarPID;
        }
      } else
        SinglePIDToTry = CheapestPID;
    }

    // Step 2: Explore candidates.

    // When we only explore one possible path, and thus branch depth doesn't
    // increase, do not recurse, iterate instead.
    if (SinglePIDToTry != InvalidPID) {
      LLVM_DEBUG(dbgs() << Idx << "=P" << SinglePIDToTry << ' ');
      // Only one path to explore, don't clone SP, don't increase depth.
      SP.add(SinglePIDToTry, Cluster);
      ++Idx;
      continue;
    }

    assert(MostSimilarPID != InvalidPID);

    // We explore multiple paths: recurse at increased depth, then stop this
    // function.

    LLVM_DEBUG(dbgs() << '\n');

    // lb = load balancing = put in cheapest partition
    {
      SplitProposal BranchSP = SP;
      LLVM_DEBUG(dbgs().indent(Depth)
                 << " [lb] " << Idx << "=P" << CheapestPID << "? ");
      BranchSP.add(CheapestPID, Cluster);
      pickPartition(Depth + 1, Idx + 1, BranchSP);
    }

    // ms = most similar = put in partition with the most in common
    {
      SplitProposal BranchSP = SP;
      LLVM_DEBUG(dbgs().indent(Depth)
                 << " [ms] " << Idx << "=P" << MostSimilarPID << "? ");
      BranchSP.add(MostSimilarPID, Cluster);
      pickPartition(Depth + 1, Idx + 1, BranchSP);
    }

    return;
  }

  // Step 3: If we assigned all WorkList items, submit the proposal.

  assert(Idx == WorkList.size());
  assert(NumProposalsSubmitted <= (2u << MaxDepth) &&
         "Search got out of bounds?");
  SP.setName("recursive_search (depth=" + std::to_string(Depth) + ") #" +
             std::to_string(NumProposalsSubmitted++));
  LLVM_DEBUG(dbgs() << '\n');
  SubmitProposal(SP);
}

std::pair<unsigned, CostType>
RecursiveSearchSplitting::findMostSimilarPartition(const WorkListEntry &Entry,
                                                   const SplitProposal &SP) {
  if (!Entry.NumNonEntryNodes)
    return {InvalidPID, 0};

  // We take the partition that is the most similar using Cost as a metric.
  // So we take the set of nodes in common, compute their aggregated cost, and
  // pick the partition with the highest cost in common.
  unsigned ChosenPID = InvalidPID;
  CostType ChosenCost = 0;
  for (unsigned PID = 0; PID < NumParts; ++PID) {
    BitVector BV = SP[PID];
    BV &= Entry.Cluster; // FIXME: & doesn't work between BVs?!

    if (BV.none())
      continue;

    const CostType Cost = SG.calculateCost(BV);

    if (ChosenPID == InvalidPID || ChosenCost < Cost ||
        (ChosenCost == Cost && PID > ChosenPID)) {
      ChosenPID = PID;
      ChosenCost = Cost;
    }
  }

  return {ChosenPID, ChosenCost};
}

//===----------------------------------------------------------------------===//
// DOTGraph Printing Support
//===----------------------------------------------------------------------===//

const SplitGraph::Node *mapEdgeToDst(const SplitGraph::Edge *E) {
  return E->Dst;
}

using SplitGraphEdgeDstIterator =
    mapped_iterator<SplitGraph::edges_iterator, decltype(&mapEdgeToDst)>;

} // namespace

template <> struct GraphTraits<SplitGraph> {
  using NodeRef = const SplitGraph::Node *;
  using nodes_iterator = SplitGraph::nodes_iterator;
  using ChildIteratorType = SplitGraphEdgeDstIterator;

  using EdgeRef = const SplitGraph::Edge *;
  using ChildEdgeIteratorType = SplitGraph::edges_iterator;

  static NodeRef getEntryNode(NodeRef N) { return N; }

  static ChildIteratorType child_begin(NodeRef Ref) {
    return {Ref->outgoing_edges().begin(), mapEdgeToDst};
  }
  static ChildIteratorType child_end(NodeRef Ref) {
    return {Ref->outgoing_edges().end(), mapEdgeToDst};
  }

  static nodes_iterator nodes_begin(const SplitGraph &G) {
    return G.nodes().begin();
  }
  static nodes_iterator nodes_end(const SplitGraph &G) {
    return G.nodes().end();
  }
};

template <> struct DOTGraphTraits<SplitGraph> : public DefaultDOTGraphTraits {
  DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}

  static std::string getGraphName(const SplitGraph &SG) {
    return SG.getModule().getName().str();
  }

  std::string getNodeLabel(const SplitGraph::Node *N, const SplitGraph &SG) {
    return N->getName().str();
  }

  static std::string getNodeDescription(const SplitGraph::Node *N,
                                        const SplitGraph &SG) {
    std::string Result;
    if (N->isEntryFunctionCC())
      Result += "entry-fn-cc ";
    if (N->isNonCopyable())
      Result += "non-copyable ";
    Result += "cost:" + std::to_string(N->getIndividualCost());
    return Result;
  }

  static std::string getNodeAttributes(const SplitGraph::Node *N,
                                       const SplitGraph &SG) {
    return N->hasAnyIncomingEdges() ? "" : "color=\"red\"";
  }

  static std::string getEdgeAttributes(const SplitGraph::Node *N,
                                       SplitGraphEdgeDstIterator EI,
                                       const SplitGraph &SG) {

    switch ((*EI.getCurrent())->Kind) {
    case SplitGraph::EdgeKind::DirectCall:
      return "";
    case SplitGraph::EdgeKind::IndirectCall:
      return "style=\"dashed\"";
    }
    llvm_unreachable("Unknown SplitGraph::EdgeKind enum");
  }
};

//===----------------------------------------------------------------------===//
// Driver
//===----------------------------------------------------------------------===//

namespace {

// If we didn't externalize GVs, then local GVs need to be conservatively
// imported into every module (including their initializers), and then cleaned
// up afterwards.
static bool needsConservativeImport(const GlobalValue *GV) {
  if (const auto *Var = dyn_cast<GlobalVariable>(GV))
    return Var->hasLocalLinkage();
  return isa<GlobalAlias>(GV);
}

/// Prints a summary of the partition \p N, represented by module \p M, to \p
/// OS.
static void printPartitionSummary(raw_ostream &OS, unsigned N, const Module &M,
                                  CostType PartCost, CostType ModuleCost) {
  OS << "*** Partition P" << N << " ***\n";

  for (const auto &Fn : M) {
    if (!Fn.isDeclaration())
      OS << " - [function] " << Fn.getName() << "\n";
  }

  for (const auto &GV : M.globals()) {
    if (GV.hasInitializer())
      OS << " - [global] " << GV.getName() << "\n";
  }

  OS << "Partition contains " << formatRatioOf(PartCost, ModuleCost)
     << "% of the source\n";
}

static void evaluateProposal(SplitProposal &Best, SplitProposal New) {
  SplitModuleTimer SMT("proposal_evaluation", "proposal ranking algorithm");

  LLVM_DEBUG({
    New.verifyCompleteness();
    if (DebugProposalSearch)
      New.print(dbgs());
  });

  const double CurBScore = Best.getBottleneckScore();
  const double CurCSScore = Best.getCodeSizeScore();
  const double NewBScore = New.getBottleneckScore();
  const double NewCSScore = New.getCodeSizeScore();

  // TODO: Improve this
  //    We can probably lower the precision of the comparison at first
  //    e.g. if we have
  //      - (Current): BScore: 0.489 CSCore 1.105
  //      - (New): BScore: 0.475 CSCore 1.305
  //    Currently we'd choose the new one because the bottleneck score is
  //    lower, but the new one duplicates more code. It may be worth it to
  //    discard the new proposal as the impact on build time is negligible.

  // Compare them
  bool IsBest = false;
  if (NewBScore < CurBScore)
    IsBest = true;
  else if (NewBScore == CurBScore)
    IsBest = (NewCSScore < CurCSScore); // Use code size as tie breaker.

  if (IsBest)
    Best = std::move(New);

  LLVM_DEBUG(if (DebugProposalSearch) {
    if (IsBest)
      dbgs() << "[search] new best proposal!\n";
    else
      dbgs() << "[search] discarding - not profitable\n";
  });
}

/// Trivial helper to create an identical copy of \p M.
static std::unique_ptr<Module> cloneAll(const Module &M) {
  ValueToValueMapTy VMap;
  return CloneModule(M, VMap, [&](const GlobalValue *GV) { return true; });
}

/// Writes \p SG as a DOTGraph to \ref ModuleDotCfgOutput if requested.
static void writeDOTGraph(const SplitGraph &SG) {
  if (ModuleDotCfgOutput.empty())
    return;

  std::error_code EC;
  raw_fd_ostream OS(ModuleDotCfgOutput, EC);
  if (EC) {
    errs() << "[" DEBUG_TYPE "]: cannot open '" << ModuleDotCfgOutput
           << "' - DOTGraph will not be printed\n";
    return;
  }
  WriteGraph(OS, SG, /*ShortName=*/false,
             /*Title=*/SG.getModule().getName());
}

static void splitAMDGPUModule(
    GetTTIFn GetTTI, Module &M, unsigned NumParts,
    function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) {
  CallGraph CG(M);

  // Externalize functions whose addresses are taken.
  //
  // This is needed because partitioning is purely based on calls, but sometimes
  // a kernel/function may just look at the address of another local function
  // and not do anything (no calls). After partitioning, that local function may
  // end up in a different module (so it's just a declaration in the module
  // where its address is taken), which emits an "undefined hidden symbol" linker
  // error.
  //
  // Additionally, it guides partitioning to not duplicate this function if it's
  // called directly at some point.
  //
  // TODO: Could we be smarter about this? This makes all functions whose
  // addresses are taken non-copyable. We should probably model this type of
  // constraint in the graph and use it to guide splitting, instead of
  // externalizing like this. Maybe non-copyable should really mean "keep one
  // visible copy, then internalize all other copies" for some functions?
  if (!NoExternalizeOnAddrTaken) {
    for (auto &Fn : M) {
      // TODO: Should aliases count? Probably not but they're so rare I'm not
      // sure it's worth fixing.
      if (Fn.hasLocalLinkage() && Fn.hasAddressTaken()) {
        LLVM_DEBUG(dbgs() << "[externalize] "; Fn.printAsOperand(dbgs());
                   dbgs() << " because its address is taken\n");
        externalize(Fn);
      }
    }
  }

  // Externalize local GVs, which avoids duplicating their initializers, which
  // in turn helps keep code size in check.
  if (!NoExternalizeGlobals) {
    for (auto &GV : M.globals()) {
      if (GV.hasLocalLinkage())
        LLVM_DEBUG(dbgs() << "[externalize] GV " << GV.getName() << '\n');
      externalize(GV);
    }
  }

  // Start by calculating the cost of every function in the module, as well as
  // the module's overall cost.
  FunctionsCostMap FnCosts;
  const CostType ModuleCost = calculateFunctionCosts(GetTTI, M, FnCosts);

  // Build the SplitGraph, which represents the module's functions and models
  // their dependencies accurately.
  SplitGraph SG(M, FnCosts, ModuleCost);
  SG.buildGraph(CG);

  if (SG.empty()) {
    LLVM_DEBUG(
        dbgs()
        << "[!] no nodes in graph, input is empty - no splitting possible\n");
    ModuleCallback(cloneAll(M));
    return;
  }

  LLVM_DEBUG({
    dbgs() << "[graph] nodes:\n";
    for (const SplitGraph::Node *N : SG.nodes()) {
      dbgs() << "  - [" << N->getID() << "]: " << N->getName() << " "
             << (N->isGraphEntryPoint() ? "(entry)" : "") << " "
             << (N->isNonCopyable() ? "(noncopyable)" : "") << "\n";
    }
  });

  writeDOTGraph(SG);

  LLVM_DEBUG(dbgs() << "[search] testing splitting strategies\n");

  std::optional<SplitProposal> Proposal;
  const auto EvaluateProposal = [&](SplitProposal SP) {
    SP.calculateScores();
    if (!Proposal)
      Proposal = std::move(SP);
    else
      evaluateProposal(*Proposal, std::move(SP));
  };

  // TODO: It would be very easy to create new strategies by just adding a base
  // class to RecursiveSearchSplitting and abstracting it away.
  RecursiveSearchSplitting(SG, NumParts, EvaluateProposal).run();
  LLVM_DEBUG(if (Proposal) dbgs() << "[search done] selected proposal: "
                                  << Proposal->getName() << "\n";);

  if (!Proposal) {
    LLVM_DEBUG(dbgs() << "[!] no proposal made, no splitting possible!\n");
    ModuleCallback(cloneAll(M));
    return;
  }

  LLVM_DEBUG(Proposal->print(dbgs()););

  std::optional<raw_fd_ostream> SummariesOS;
  if (!PartitionSummariesOutput.empty()) {
    std::error_code EC;
    SummariesOS.emplace(PartitionSummariesOutput, EC);
    if (EC)
      errs() << "[" DEBUG_TYPE "]: cannot open '" << PartitionSummariesOutput
             << "' - Partition summaries will not be printed\n";
  }

  for (unsigned PID = 0; PID < NumParts; ++PID) {
    SplitModuleTimer SMT2("modules_creation",
                          "creating modules for each partition");
    LLVM_DEBUG(dbgs() << "[split] creating new modules\n");

    DenseSet<const Function *> FnsInPart;
    for (unsigned NodeID : (*Proposal)[PID].set_bits())
      FnsInPart.insert(&SG.getNode(NodeID).getFunction());

    ValueToValueMapTy VMap;
    CostType PartCost = 0;
    std::unique_ptr<Module> MPart(
        CloneModule(M, VMap, [&](const GlobalValue *GV) {
          // Functions go in their assigned partition.
          if (const auto *Fn = dyn_cast<Function>(GV)) {
            if (FnsInPart.contains(Fn)) {
              PartCost += SG.getCost(*Fn);
              return true;
            }
            return false;
          }

          // Everything else goes in the first partition.
          return needsConservativeImport(GV) || PID == 0;
        }));

    // FIXME: Aliases aren't seen often, and their handling isn't perfect so
    // bugs are possible.

    // Clean-up conservatively imported GVs without any users.
    for (auto &GV : make_early_inc_range(MPart->global_values())) {
      if (needsConservativeImport(&GV) && GV.use_empty())
        GV.eraseFromParent();
    }

    if (SummariesOS)
      printPartitionSummary(*SummariesOS, PID, *MPart, PartCost, ModuleCost);

    LLVM_DEBUG(
        printPartitionSummary(dbgs(), PID, *MPart, PartCost, ModuleCost));

    ModuleCallback(std::move(MPart));
  }
}
} // namespace

PreservedAnalyses AMDGPUSplitModulePass::run(Module &M,
                                             ModuleAnalysisManager &MAM) {
  SplitModuleTimer SMT(
      "total", "total pass runtime (incl. potentially waiting for lockfile)");

  FunctionAnalysisManager &FAM =
      MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  const auto TTIGetter = [&FAM](Function &F) -> const TargetTransformInfo & {
    return FAM.getResult<TargetIRAnalysis>(F);
  };

  bool Done = false;
#ifndef NDEBUG
  if (UseLockFile) {
    SmallString<128> LockFilePath;
    sys::path::system_temp_directory(/*ErasedOnReboot=*/true, LockFilePath);
    sys::path::append(LockFilePath, "amdgpu-split-module-debug");
    LLVM_DEBUG(dbgs() << DEBUG_TYPE " using lockfile '" << LockFilePath
                      << "'\n");

    while (true) {
      llvm::LockFileManager Locked(LockFilePath.str());
      switch (Locked) {
      case LockFileManager::LFS_Error:
        LLVM_DEBUG(
            dbgs() << "[amdgpu-split-module] unable to acquire lockfile, debug "
                      "output may be mangled by other processes\n");
        Locked.unsafeRemoveLockFile();
        break;
      case LockFileManager::LFS_Owned:
        break;
      case LockFileManager::LFS_Shared: {
        switch (Locked.waitForUnlock()) {
        case LockFileManager::Res_Success:
          break;
        case LockFileManager::Res_OwnerDied:
          continue; // try again to get the lock.
        case LockFileManager::Res_Timeout:
          LLVM_DEBUG(
              dbgs()
              << "[amdgpu-split-module] unable to acquire lockfile, debug "
                 "output may be mangled by other processes\n");
          Locked.unsafeRemoveLockFile();
          break; // give up
        }
        break;
      }
      }

      splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);
      Done = true;
      break;
    }
  }
#endif

  if (!Done)
    splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);

  // We can change linkage/visibilities in the input, consider that nothing is
  // preserved just to be safe. This pass runs last anyway.
  return PreservedAnalyses::none();
}
} // namespace llvm