xref: /llvm-project/llvm/lib/Transforms/Utils/CodeLayout.cpp (revision 05d167fc201b4f2e96108be0d682f6800a70c23d)
1 //===- CodeLayout.cpp - Implementation of code layout algorithms ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The file implements "cache-aware" layout algorithms of basic blocks and
10 // functions in a binary.
11 //
12 // The algorithm tries to find a layout of nodes (basic blocks) of a given CFG
13 // optimizing jump locality and thus processor I-cache utilization. This is
14 // achieved via increasing the number of fall-through jumps and co-locating
15 // frequently executed nodes together. The name follows the underlying
16 // optimization problem, Extended-TSP, which is a generalization of classical
17 // (maximum) Traveling Salesmen Problem.
18 //
19 // The algorithm is a greedy heuristic that works with chains (ordered lists)
20 // of basic blocks. Initially all chains are isolated basic blocks. On every
21 // iteration, we pick a pair of chains whose merging yields the biggest increase
22 // in the ExtTSP score, which models how i-cache "friendly" a specific chain is.
23 // A pair of chains giving the maximum gain is merged into a new chain. The
24 // procedure stops when there is only one chain left, or when merging does not
25 // increase ExtTSP. In the latter case, the remaining chains are sorted by
26 // density in the decreasing order.
27 //
28 // An important aspect is the way two chains are merged. Unlike earlier
29 // algorithms (e.g., based on the approach of Pettis-Hansen), two
30 // chains, X and Y, are first split into three, X1, X2, and Y. Then we
31 // consider all possible ways of gluing the three chains (e.g., X1YX2, X1X2Y,
32 // X2X1Y, X2YX1, YX1X2, YX2X1) and choose the one producing the largest score.
33 // This improves the quality of the final result (the search space is larger)
34 // while keeping the implementation sufficiently fast.
35 //
36 // Reference:
37 //   * A. Newell and S. Pupyrev, Improved Basic Block Reordering,
38 //     IEEE Transactions on Computers, 2020
39 //     https://arxiv.org/abs/1809.04676
40 //
41 //===----------------------------------------------------------------------===//
42 
43 #include "llvm/Transforms/Utils/CodeLayout.h"
44 #include "llvm/Support/CommandLine.h"
45 #include "llvm/Support/Debug.h"
46 
47 #include <cmath>
48 #include <set>
49 
50 using namespace llvm;
51 using namespace llvm::codelayout;
52 
53 #define DEBUG_TYPE "code-layout"
54 
55 namespace llvm {
56 cl::opt<bool> EnableExtTspBlockPlacement(
57     "enable-ext-tsp-block-placement", cl::Hidden, cl::init(false),
58     cl::desc("Enable machine block placement based on the ext-tsp model, "
59              "optimizing I-cache utilization."));
60 
61 cl::opt<bool> ApplyExtTspWithoutProfile(
62     "ext-tsp-apply-without-profile",
63     cl::desc("Whether to apply ext-tsp placement for instances w/o profile"),
64     cl::init(true), cl::Hidden);
65 } // namespace llvm
66 
67 // Algorithm-specific params for Ext-TSP. The values are tuned for the best
68 // performance of large-scale front-end bound binaries.
69 static cl::opt<double> ForwardWeightCond(
70     "ext-tsp-forward-weight-cond", cl::ReallyHidden, cl::init(0.1),
71     cl::desc("The weight of conditional forward jumps for ExtTSP value"));
72 
73 static cl::opt<double> ForwardWeightUncond(
74     "ext-tsp-forward-weight-uncond", cl::ReallyHidden, cl::init(0.1),
75     cl::desc("The weight of unconditional forward jumps for ExtTSP value"));
76 
77 static cl::opt<double> BackwardWeightCond(
78     "ext-tsp-backward-weight-cond", cl::ReallyHidden, cl::init(0.1),
79     cl::desc("The weight of conditional backward jumps for ExtTSP value"));
80 
81 static cl::opt<double> BackwardWeightUncond(
82     "ext-tsp-backward-weight-uncond", cl::ReallyHidden, cl::init(0.1),
83     cl::desc("The weight of unconditional backward jumps for ExtTSP value"));
84 
85 static cl::opt<double> FallthroughWeightCond(
86     "ext-tsp-fallthrough-weight-cond", cl::ReallyHidden, cl::init(1.0),
87     cl::desc("The weight of conditional fallthrough jumps for ExtTSP value"));
88 
89 static cl::opt<double> FallthroughWeightUncond(
90     "ext-tsp-fallthrough-weight-uncond", cl::ReallyHidden, cl::init(1.05),
91     cl::desc("The weight of unconditional fallthrough jumps for ExtTSP value"));
92 
93 static cl::opt<unsigned> ForwardDistance(
94     "ext-tsp-forward-distance", cl::ReallyHidden, cl::init(1024),
95     cl::desc("The maximum distance (in bytes) of a forward jump for ExtTSP"));
96 
97 static cl::opt<unsigned> BackwardDistance(
98     "ext-tsp-backward-distance", cl::ReallyHidden, cl::init(640),
99     cl::desc("The maximum distance (in bytes) of a backward jump for ExtTSP"));
100 
101 // The maximum size of a chain created by the algorithm. The size is bounded
102 // so that the algorithm can efficiently process extremely large instances.
103 static cl::opt<unsigned>
104     MaxChainSize("ext-tsp-max-chain-size", cl::ReallyHidden, cl::init(512),
105                  cl::desc("The maximum size of a chain to create"));
106 
107 // The maximum size of a chain for splitting. Larger values of the threshold
108 // may yield better quality at the cost of worsen run-time.
109 static cl::opt<unsigned> ChainSplitThreshold(
110     "ext-tsp-chain-split-threshold", cl::ReallyHidden, cl::init(128),
111     cl::desc("The maximum size of a chain to apply splitting"));
112 
113 // The maximum ratio between densities of two chains for merging.
114 static cl::opt<double> MaxMergeDensityRatio(
115     "ext-tsp-max-merge-density-ratio", cl::ReallyHidden, cl::init(100),
116     cl::desc("The maximum ratio between densities of two chains for merging"));
117 
118 // Algorithm-specific options for CDSort.
119 static cl::opt<unsigned> CacheEntries("cdsort-cache-entries", cl::ReallyHidden,
120                                       cl::desc("The size of the cache"));
121 
122 static cl::opt<unsigned> CacheSize("cdsort-cache-size", cl::ReallyHidden,
123                                    cl::desc("The size of a line in the cache"));
124 
125 static cl::opt<unsigned>
126     CDMaxChainSize("cdsort-max-chain-size", cl::ReallyHidden,
127                    cl::desc("The maximum size of a chain to create"));
128 
129 static cl::opt<double> DistancePower(
130     "cdsort-distance-power", cl::ReallyHidden,
131     cl::desc("The power exponent for the distance-based locality"));
132 
133 static cl::opt<double> FrequencyScale(
134     "cdsort-frequency-scale", cl::ReallyHidden,
135     cl::desc("The scale factor for the frequency-based locality"));
136 
137 namespace {
138 
139 // Epsilon for comparison of doubles.
140 constexpr double EPS = 1e-8;
141 
142 // Compute the Ext-TSP score for a given jump.
143 double jumpExtTSPScore(uint64_t JumpDist, uint64_t JumpMaxDist, uint64_t Count,
144                        double Weight) {
145   if (JumpDist > JumpMaxDist)
146     return 0;
147   double Prob = 1.0 - static_cast<double>(JumpDist) / JumpMaxDist;
148   return Weight * Prob * Count;
149 }
150 
151 // Compute the Ext-TSP score for a jump between a given pair of blocks,
152 // using their sizes, (estimated) addresses and the jump execution count.
153 double extTSPScore(uint64_t SrcAddr, uint64_t SrcSize, uint64_t DstAddr,
154                    uint64_t Count, bool IsConditional) {
155   // Fallthrough
156   if (SrcAddr + SrcSize == DstAddr) {
157     return jumpExtTSPScore(0, 1, Count,
158                            IsConditional ? FallthroughWeightCond
159                                          : FallthroughWeightUncond);
160   }
161   // Forward
162   if (SrcAddr + SrcSize < DstAddr) {
163     const uint64_t Dist = DstAddr - (SrcAddr + SrcSize);
164     return jumpExtTSPScore(Dist, ForwardDistance, Count,
165                            IsConditional ? ForwardWeightCond
166                                          : ForwardWeightUncond);
167   }
168   // Backward
169   const uint64_t Dist = SrcAddr + SrcSize - DstAddr;
170   return jumpExtTSPScore(Dist, BackwardDistance, Count,
171                          IsConditional ? BackwardWeightCond
172                                        : BackwardWeightUncond);
173 }
174 
175 /// A type of merging two chains, X and Y. The former chain is split into
176 /// X1 and X2 and then concatenated with Y in the order specified by the type.
177 enum class MergeTypeT : int { X_Y, Y_X, X1_Y_X2, Y_X2_X1, X2_X1_Y };
178 
179 /// The gain of merging two chains, that is, the Ext-TSP score of the merge
180 /// together with the corresponding merge 'type' and 'offset'.
181 struct MergeGainT {
182   explicit MergeGainT() = default;
183   explicit MergeGainT(double Score, size_t MergeOffset, MergeTypeT MergeType)
184       : Score(Score), MergeOffset(MergeOffset), MergeType(MergeType) {}
185 
186   double score() const { return Score; }
187 
188   size_t mergeOffset() const { return MergeOffset; }
189 
190   MergeTypeT mergeType() const { return MergeType; }
191 
192   void setMergeType(MergeTypeT Ty) { MergeType = Ty; }
193 
194   // Returns 'true' iff Other is preferred over this.
195   bool operator<(const MergeGainT &Other) const {
196     return (Other.Score > EPS && Other.Score > Score + EPS);
197   }
198 
199   // Update the current gain if Other is preferred over this.
200   void updateIfLessThan(const MergeGainT &Other) {
201     if (*this < Other)
202       *this = Other;
203   }
204 
205 private:
206   double Score{-1.0};
207   size_t MergeOffset{0};
208   MergeTypeT MergeType{MergeTypeT::X_Y};
209 };
210 
211 struct JumpT;
212 struct ChainT;
213 struct ChainEdge;
214 
215 /// A node in the graph, typically corresponding to a basic block in the CFG or
216 /// a function in the call graph.
217 struct NodeT {
218   NodeT(const NodeT &) = delete;
219   NodeT(NodeT &&) = default;
220   NodeT &operator=(const NodeT &) = delete;
221   NodeT &operator=(NodeT &&) = default;
222 
223   explicit NodeT(size_t Index, uint64_t Size, uint64_t Count)
224       : Index(Index), Size(Size), ExecutionCount(Count) {}
225 
226   bool isEntry() const { return Index == 0; }
227 
228   // Check if Other is a successor of the node.
229   bool isSuccessor(const NodeT *Other) const;
230 
231   // The total execution count of outgoing jumps.
232   uint64_t outCount() const;
233 
234   // The total execution count of incoming jumps.
235   uint64_t inCount() const;
236 
237   // The original index of the node in graph.
238   size_t Index{0};
239   // The index of the node in the current chain.
240   size_t CurIndex{0};
241   // The size of the node in the binary.
242   uint64_t Size{0};
243   // The execution count of the node in the profile data.
244   uint64_t ExecutionCount{0};
245   // The current chain of the node.
246   ChainT *CurChain{nullptr};
247   // The offset of the node in the current chain.
248   mutable uint64_t EstimatedAddr{0};
249   // Forced successor of the node in the graph.
250   NodeT *ForcedSucc{nullptr};
251   // Forced predecessor of the node in the graph.
252   NodeT *ForcedPred{nullptr};
253   // Outgoing jumps from the node.
254   std::vector<JumpT *> OutJumps;
255   // Incoming jumps to the node.
256   std::vector<JumpT *> InJumps;
257 };
258 
259 /// An arc in the graph, typically corresponding to a jump between two nodes.
260 struct JumpT {
261   JumpT(const JumpT &) = delete;
262   JumpT(JumpT &&) = default;
263   JumpT &operator=(const JumpT &) = delete;
264   JumpT &operator=(JumpT &&) = default;
265 
266   explicit JumpT(NodeT *Source, NodeT *Target, uint64_t ExecutionCount)
267       : Source(Source), Target(Target), ExecutionCount(ExecutionCount) {}
268 
269   // Source node of the jump.
270   NodeT *Source;
271   // Target node of the jump.
272   NodeT *Target;
273   // Execution count of the arc in the profile data.
274   uint64_t ExecutionCount{0};
275   // Whether the jump corresponds to a conditional branch.
276   bool IsConditional{false};
277   // The offset of the jump from the source node.
278   uint64_t Offset{0};
279 };
280 
281 /// A chain (ordered sequence) of nodes in the graph.
282 struct ChainT {
283   ChainT(const ChainT &) = delete;
284   ChainT(ChainT &&) = default;
285   ChainT &operator=(const ChainT &) = delete;
286   ChainT &operator=(ChainT &&) = default;
287 
288   explicit ChainT(uint64_t Id, NodeT *Node)
289       : Id(Id), ExecutionCount(Node->ExecutionCount), Size(Node->Size),
290         Nodes(1, Node) {}
291 
292   size_t numBlocks() const { return Nodes.size(); }
293 
294   double density() const { return ExecutionCount / Size; }
295 
296   bool isEntry() const { return Nodes[0]->Index == 0; }
297 
298   bool isCold() const {
299     for (NodeT *Node : Nodes) {
300       if (Node->ExecutionCount > 0)
301         return false;
302     }
303     return true;
304   }
305 
306   ChainEdge *getEdge(ChainT *Other) const {
307     for (const auto &[Chain, ChainEdge] : Edges) {
308       if (Chain == Other)
309         return ChainEdge;
310     }
311     return nullptr;
312   }
313 
314   void removeEdge(ChainT *Other) {
315     auto It = Edges.begin();
316     while (It != Edges.end()) {
317       if (It->first == Other) {
318         Edges.erase(It);
319         return;
320       }
321       It++;
322     }
323   }
324 
325   void addEdge(ChainT *Other, ChainEdge *Edge) {
326     Edges.push_back(std::make_pair(Other, Edge));
327   }
328 
329   void merge(ChainT *Other, std::vector<NodeT *> MergedBlocks) {
330     Nodes = std::move(MergedBlocks);
331     // Update the chain's data.
332     ExecutionCount += Other->ExecutionCount;
333     Size += Other->Size;
334     Id = Nodes[0]->Index;
335     // Update the node's data.
336     for (size_t Idx = 0; Idx < Nodes.size(); Idx++) {
337       Nodes[Idx]->CurChain = this;
338       Nodes[Idx]->CurIndex = Idx;
339     }
340   }
341 
342   void mergeEdges(ChainT *Other);
343 
344   void clear() {
345     Nodes.clear();
346     Nodes.shrink_to_fit();
347     Edges.clear();
348     Edges.shrink_to_fit();
349   }
350 
351   // Unique chain identifier.
352   uint64_t Id;
353   // Cached ext-tsp score for the chain.
354   double Score{0};
355   // The total execution count of the chain. Since the execution count of
356   // a basic block is uint64_t, using doubles here to avoid overflow.
357   double ExecutionCount{0};
358   // The total size of the chain.
359   uint64_t Size{0};
360   // Nodes of the chain.
361   std::vector<NodeT *> Nodes;
362   // Adjacent chains and corresponding edges (lists of jumps).
363   std::vector<std::pair<ChainT *, ChainEdge *>> Edges;
364 };
365 
366 /// An edge in the graph representing jumps between two chains.
367 /// When nodes are merged into chains, the edges are combined too so that
368 /// there is always at most one edge between a pair of chains.
369 struct ChainEdge {
370   ChainEdge(const ChainEdge &) = delete;
371   ChainEdge(ChainEdge &&) = default;
372   ChainEdge &operator=(const ChainEdge &) = delete;
373   ChainEdge &operator=(ChainEdge &&) = delete;
374 
375   explicit ChainEdge(JumpT *Jump)
376       : SrcChain(Jump->Source->CurChain), DstChain(Jump->Target->CurChain),
377         Jumps(1, Jump) {}
378 
379   ChainT *srcChain() const { return SrcChain; }
380 
381   ChainT *dstChain() const { return DstChain; }
382 
383   bool isSelfEdge() const { return SrcChain == DstChain; }
384 
385   const std::vector<JumpT *> &jumps() const { return Jumps; }
386 
387   void appendJump(JumpT *Jump) { Jumps.push_back(Jump); }
388 
389   void moveJumps(ChainEdge *Other) {
390     Jumps.insert(Jumps.end(), Other->Jumps.begin(), Other->Jumps.end());
391     Other->Jumps.clear();
392     Other->Jumps.shrink_to_fit();
393   }
394 
395   void changeEndpoint(ChainT *From, ChainT *To) {
396     if (From == SrcChain)
397       SrcChain = To;
398     if (From == DstChain)
399       DstChain = To;
400   }
401 
402   bool hasCachedMergeGain(ChainT *Src, ChainT *Dst) const {
403     return Src == SrcChain ? CacheValidForward : CacheValidBackward;
404   }
405 
406   MergeGainT getCachedMergeGain(ChainT *Src, ChainT *Dst) const {
407     return Src == SrcChain ? CachedGainForward : CachedGainBackward;
408   }
409 
410   void setCachedMergeGain(ChainT *Src, ChainT *Dst, MergeGainT MergeGain) {
411     if (Src == SrcChain) {
412       CachedGainForward = MergeGain;
413       CacheValidForward = true;
414     } else {
415       CachedGainBackward = MergeGain;
416       CacheValidBackward = true;
417     }
418   }
419 
420   void invalidateCache() {
421     CacheValidForward = false;
422     CacheValidBackward = false;
423   }
424 
425   void setMergeGain(MergeGainT Gain) { CachedGain = Gain; }
426 
427   MergeGainT getMergeGain() const { return CachedGain; }
428 
429   double gain() const { return CachedGain.score(); }
430 
431 private:
432   // Source chain.
433   ChainT *SrcChain{nullptr};
434   // Destination chain.
435   ChainT *DstChain{nullptr};
436   // Original jumps in the binary with corresponding execution counts.
437   std::vector<JumpT *> Jumps;
438   // Cached gain value for merging the pair of chains.
439   MergeGainT CachedGain;
440 
441   // Cached gain values for merging the pair of chains. Since the gain of
442   // merging (Src, Dst) and (Dst, Src) might be different, we store both values
443   // here and a flag indicating which of the options results in a higher gain.
444   // Cached gain values.
445   MergeGainT CachedGainForward;
446   MergeGainT CachedGainBackward;
447   // Whether the cached value must be recomputed.
448   bool CacheValidForward{false};
449   bool CacheValidBackward{false};
450 };
451 
452 bool NodeT::isSuccessor(const NodeT *Other) const {
453   for (JumpT *Jump : OutJumps)
454     if (Jump->Target == Other)
455       return true;
456   return false;
457 }
458 
459 uint64_t NodeT::outCount() const {
460   uint64_t Count = 0;
461   for (JumpT *Jump : OutJumps)
462     Count += Jump->ExecutionCount;
463   return Count;
464 }
465 
466 uint64_t NodeT::inCount() const {
467   uint64_t Count = 0;
468   for (JumpT *Jump : InJumps)
469     Count += Jump->ExecutionCount;
470   return Count;
471 }
472 
473 void ChainT::mergeEdges(ChainT *Other) {
474   // Update edges adjacent to chain Other.
475   for (const auto &[DstChain, DstEdge] : Other->Edges) {
476     ChainT *TargetChain = DstChain == Other ? this : DstChain;
477     ChainEdge *CurEdge = getEdge(TargetChain);
478     if (CurEdge == nullptr) {
479       DstEdge->changeEndpoint(Other, this);
480       this->addEdge(TargetChain, DstEdge);
481       if (DstChain != this && DstChain != Other)
482         DstChain->addEdge(this, DstEdge);
483     } else {
484       CurEdge->moveJumps(DstEdge);
485     }
486     // Cleanup leftover edge.
487     if (DstChain != Other)
488       DstChain->removeEdge(Other);
489   }
490 }
491 
492 using NodeIter = std::vector<NodeT *>::const_iterator;
493 static std::vector<NodeT *> EmptyList;
494 
495 /// A wrapper around three concatenated vectors (chains) of nodes; it is used
496 /// to avoid extra instantiation of the vectors.
497 struct MergedNodesT {
498   MergedNodesT(NodeIter Begin1, NodeIter End1,
499                NodeIter Begin2 = EmptyList.begin(),
500                NodeIter End2 = EmptyList.end(),
501                NodeIter Begin3 = EmptyList.begin(),
502                NodeIter End3 = EmptyList.end())
503       : Begin1(Begin1), End1(End1), Begin2(Begin2), End2(End2), Begin3(Begin3),
504         End3(End3) {}
505 
506   template <typename F> void forEach(const F &Func) const {
507     for (auto It = Begin1; It != End1; It++)
508       Func(*It);
509     for (auto It = Begin2; It != End2; It++)
510       Func(*It);
511     for (auto It = Begin3; It != End3; It++)
512       Func(*It);
513   }
514 
515   std::vector<NodeT *> getNodes() const {
516     std::vector<NodeT *> Result;
517     Result.reserve(std::distance(Begin1, End1) + std::distance(Begin2, End2) +
518                    std::distance(Begin3, End3));
519     Result.insert(Result.end(), Begin1, End1);
520     Result.insert(Result.end(), Begin2, End2);
521     Result.insert(Result.end(), Begin3, End3);
522     return Result;
523   }
524 
525   const NodeT *getFirstNode() const { return *Begin1; }
526 
527 private:
528   NodeIter Begin1;
529   NodeIter End1;
530   NodeIter Begin2;
531   NodeIter End2;
532   NodeIter Begin3;
533   NodeIter End3;
534 };
535 
536 /// A wrapper around two concatenated vectors (chains) of jumps.
537 struct MergedJumpsT {
538   MergedJumpsT(const std::vector<JumpT *> *Jumps1,
539                const std::vector<JumpT *> *Jumps2 = nullptr) {
540     assert(!Jumps1->empty() && "cannot merge empty jump list");
541     JumpArray[0] = Jumps1;
542     JumpArray[1] = Jumps2;
543   }
544 
545   template <typename F> void forEach(const F &Func) const {
546     for (auto Jumps : JumpArray)
547       if (Jumps != nullptr)
548         for (JumpT *Jump : *Jumps)
549           Func(Jump);
550   }
551 
552 private:
553   std::array<const std::vector<JumpT *> *, 2> JumpArray{nullptr, nullptr};
554 };
555 
556 /// Merge two chains of nodes respecting a given 'type' and 'offset'.
557 ///
558 /// If MergeType == 0, then the result is a concatenation of two chains.
559 /// Otherwise, the first chain is cut into two sub-chains at the offset,
560 /// and merged using all possible ways of concatenating three chains.
561 MergedNodesT mergeNodes(const std::vector<NodeT *> &X,
562                         const std::vector<NodeT *> &Y, size_t MergeOffset,
563                         MergeTypeT MergeType) {
564   // Split the first chain, X, into X1 and X2.
565   NodeIter BeginX1 = X.begin();
566   NodeIter EndX1 = X.begin() + MergeOffset;
567   NodeIter BeginX2 = X.begin() + MergeOffset;
568   NodeIter EndX2 = X.end();
569   NodeIter BeginY = Y.begin();
570   NodeIter EndY = Y.end();
571 
572   // Construct a new chain from the three existing ones.
573   switch (MergeType) {
574   case MergeTypeT::X_Y:
575     return MergedNodesT(BeginX1, EndX2, BeginY, EndY);
576   case MergeTypeT::Y_X:
577     return MergedNodesT(BeginY, EndY, BeginX1, EndX2);
578   case MergeTypeT::X1_Y_X2:
579     return MergedNodesT(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
580   case MergeTypeT::Y_X2_X1:
581     return MergedNodesT(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
582   case MergeTypeT::X2_X1_Y:
583     return MergedNodesT(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
584   }
585   llvm_unreachable("unexpected chain merge type");
586 }
587 
588 /// The implementation of the ExtTSP algorithm.
589 class ExtTSPImpl {
590 public:
591   ExtTSPImpl(ArrayRef<uint64_t> NodeSizes, ArrayRef<uint64_t> NodeCounts,
592              ArrayRef<EdgeCount> EdgeCounts)
593       : NumNodes(NodeSizes.size()) {
594     initialize(NodeSizes, NodeCounts, EdgeCounts);
595   }
596 
597   /// Run the algorithm and return an optimized ordering of nodes.
598   std::vector<uint64_t> run() {
599     // Pass 1: Merge nodes with their mutually forced successors
600     mergeForcedPairs();
601 
602     // Pass 2: Merge pairs of chains while improving the ExtTSP objective
603     mergeChainPairs();
604 
605     // Pass 3: Merge cold nodes to reduce code size
606     mergeColdChains();
607 
608     // Collect nodes from all chains
609     return concatChains();
610   }
611 
612 private:
613   /// Initialize the algorithm's data structures.
614   void initialize(const ArrayRef<uint64_t> &NodeSizes,
615                   const ArrayRef<uint64_t> &NodeCounts,
616                   const ArrayRef<EdgeCount> &EdgeCounts) {
617     // Initialize nodes.
618     AllNodes.reserve(NumNodes);
619     for (uint64_t Idx = 0; Idx < NumNodes; Idx++) {
620       uint64_t Size = std::max<uint64_t>(NodeSizes[Idx], 1ULL);
621       uint64_t ExecutionCount = NodeCounts[Idx];
622       // The execution count of the entry node is set to at least one.
623       if (Idx == 0 && ExecutionCount == 0)
624         ExecutionCount = 1;
625       AllNodes.emplace_back(Idx, Size, ExecutionCount);
626     }
627 
628     // Initialize jumps between the nodes.
629     SuccNodes.resize(NumNodes);
630     PredNodes.resize(NumNodes);
631     std::vector<uint64_t> OutDegree(NumNodes, 0);
632     AllJumps.reserve(EdgeCounts.size());
633     for (auto Edge : EdgeCounts) {
634       ++OutDegree[Edge.src];
635       // Ignore self-edges.
636       if (Edge.src == Edge.dst)
637         continue;
638 
639       SuccNodes[Edge.src].push_back(Edge.dst);
640       PredNodes[Edge.dst].push_back(Edge.src);
641       if (Edge.count > 0) {
642         NodeT &PredNode = AllNodes[Edge.src];
643         NodeT &SuccNode = AllNodes[Edge.dst];
644         AllJumps.emplace_back(&PredNode, &SuccNode, Edge.count);
645         SuccNode.InJumps.push_back(&AllJumps.back());
646         PredNode.OutJumps.push_back(&AllJumps.back());
647         // Adjust execution counts.
648         PredNode.ExecutionCount = std::max(PredNode.ExecutionCount, Edge.count);
649         SuccNode.ExecutionCount = std::max(SuccNode.ExecutionCount, Edge.count);
650       }
651     }
652     for (JumpT &Jump : AllJumps) {
653       assert(OutDegree[Jump.Source->Index] > 0 &&
654              "incorrectly computed out-degree of the block");
655       Jump.IsConditional = OutDegree[Jump.Source->Index] > 1;
656     }
657 
658     // Initialize chains.
659     AllChains.reserve(NumNodes);
660     HotChains.reserve(NumNodes);
661     for (NodeT &Node : AllNodes) {
662       // Create a chain.
663       AllChains.emplace_back(Node.Index, &Node);
664       Node.CurChain = &AllChains.back();
665       if (Node.ExecutionCount > 0)
666         HotChains.push_back(&AllChains.back());
667     }
668 
669     // Initialize chain edges.
670     AllEdges.reserve(AllJumps.size());
671     for (NodeT &PredNode : AllNodes) {
672       for (JumpT *Jump : PredNode.OutJumps) {
673         assert(Jump->ExecutionCount > 0 && "incorrectly initialized jump");
674         NodeT *SuccNode = Jump->Target;
675         ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain);
676         // This edge is already present in the graph.
677         if (CurEdge != nullptr) {
678           assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr);
679           CurEdge->appendJump(Jump);
680           continue;
681         }
682         // This is a new edge.
683         AllEdges.emplace_back(Jump);
684         PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back());
685         SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back());
686       }
687     }
688   }
689 
690   /// For a pair of nodes, A and B, node B is the forced successor of A,
691   /// if (i) all jumps (based on profile) from A goes to B and (ii) all jumps
692   /// to B are from A. Such nodes should be adjacent in the optimal ordering;
693   /// the method finds and merges such pairs of nodes.
694   void mergeForcedPairs() {
695     // Find forced pairs of blocks.
696     for (NodeT &Node : AllNodes) {
697       if (SuccNodes[Node.Index].size() == 1 &&
698           PredNodes[SuccNodes[Node.Index][0]].size() == 1 &&
699           SuccNodes[Node.Index][0] != 0) {
700         size_t SuccIndex = SuccNodes[Node.Index][0];
701         Node.ForcedSucc = &AllNodes[SuccIndex];
702         AllNodes[SuccIndex].ForcedPred = &Node;
703       }
704     }
705 
706     // There might be 'cycles' in the forced dependencies, since profile
707     // data isn't 100% accurate. Typically this is observed in loops, when the
708     // loop edges are the hottest successors for the basic blocks of the loop.
709     // Break the cycles by choosing the node with the smallest index as the
710     // head. This helps to keep the original order of the loops, which likely
711     // have already been rotated in the optimized manner.
712     for (NodeT &Node : AllNodes) {
713       if (Node.ForcedSucc == nullptr || Node.ForcedPred == nullptr)
714         continue;
715 
716       NodeT *SuccNode = Node.ForcedSucc;
717       while (SuccNode != nullptr && SuccNode != &Node) {
718         SuccNode = SuccNode->ForcedSucc;
719       }
720       if (SuccNode == nullptr)
721         continue;
722       // Break the cycle.
723       AllNodes[Node.ForcedPred->Index].ForcedSucc = nullptr;
724       Node.ForcedPred = nullptr;
725     }
726 
727     // Merge nodes with their fallthrough successors.
728     for (NodeT &Node : AllNodes) {
729       if (Node.ForcedPred == nullptr && Node.ForcedSucc != nullptr) {
730         const NodeT *CurBlock = &Node;
731         while (CurBlock->ForcedSucc != nullptr) {
732           const NodeT *NextBlock = CurBlock->ForcedSucc;
733           mergeChains(Node.CurChain, NextBlock->CurChain, 0, MergeTypeT::X_Y);
734           CurBlock = NextBlock;
735         }
736       }
737     }
738   }
739 
740   /// Merge pairs of chains while improving the ExtTSP objective.
741   void mergeChainPairs() {
742     /// Deterministically compare pairs of chains.
743     auto compareChainPairs = [](const ChainT *A1, const ChainT *B1,
744                                 const ChainT *A2, const ChainT *B2) {
745       return std::make_tuple(A1->Id, B1->Id) < std::make_tuple(A2->Id, B2->Id);
746     };
747 
748     while (HotChains.size() > 1) {
749       ChainT *BestChainPred = nullptr;
750       ChainT *BestChainSucc = nullptr;
751       MergeGainT BestGain;
752       // Iterate over all pairs of chains.
753       for (ChainT *ChainPred : HotChains) {
754         // Get candidates for merging with the current chain.
755         for (const auto &[ChainSucc, Edge] : ChainPred->Edges) {
756           // Ignore loop edges.
757           if (Edge->isSelfEdge())
758             continue;
759           // Skip the merge if the combined chain violates the maximum specified
760           // size.
761           if (ChainPred->numBlocks() + ChainSucc->numBlocks() >= MaxChainSize)
762             continue;
763           // Don't merge the chains if they have vastly different densities.
764           // Skip the merge if the ratio between the densities exceeds
765           // MaxMergeDensityRatio. Smaller values of the option result in fewer
766           // merges, and hence, more chains.
767           const double ChainPredDensity = ChainPred->density();
768           const double ChainSuccDensity = ChainSucc->density();
769           assert(ChainPredDensity > 0.0 && ChainSuccDensity > 0.0 &&
770                  "incorrectly computed chain densities");
771           auto [MinDensity, MaxDensity] =
772               std::minmax(ChainPredDensity, ChainSuccDensity);
773           const double Ratio = MaxDensity / MinDensity;
774           if (Ratio > MaxMergeDensityRatio)
775             continue;
776 
777           // Compute the gain of merging the two chains.
778           MergeGainT CurGain = getBestMergeGain(ChainPred, ChainSucc, Edge);
779           if (CurGain.score() <= EPS)
780             continue;
781 
782           if (BestGain < CurGain ||
783               (std::abs(CurGain.score() - BestGain.score()) < EPS &&
784                compareChainPairs(ChainPred, ChainSucc, BestChainPred,
785                                  BestChainSucc))) {
786             BestGain = CurGain;
787             BestChainPred = ChainPred;
788             BestChainSucc = ChainSucc;
789           }
790         }
791       }
792 
793       // Stop merging when there is no improvement.
794       if (BestGain.score() <= EPS)
795         break;
796 
797       // Merge the best pair of chains.
798       mergeChains(BestChainPred, BestChainSucc, BestGain.mergeOffset(),
799                   BestGain.mergeType());
800     }
801   }
802 
803   /// Merge remaining nodes into chains w/o taking jump counts into
804   /// consideration. This allows to maintain the original node order in the
805   /// absence of profile data.
806   void mergeColdChains() {
807     for (size_t SrcBB = 0; SrcBB < NumNodes; SrcBB++) {
808       // Iterating in reverse order to make sure original fallthrough jumps are
809       // merged first; this might be beneficial for code size.
810       size_t NumSuccs = SuccNodes[SrcBB].size();
811       for (size_t Idx = 0; Idx < NumSuccs; Idx++) {
812         size_t DstBB = SuccNodes[SrcBB][NumSuccs - Idx - 1];
813         ChainT *SrcChain = AllNodes[SrcBB].CurChain;
814         ChainT *DstChain = AllNodes[DstBB].CurChain;
815         if (SrcChain != DstChain && !DstChain->isEntry() &&
816             SrcChain->Nodes.back()->Index == SrcBB &&
817             DstChain->Nodes.front()->Index == DstBB &&
818             SrcChain->isCold() == DstChain->isCold()) {
819           mergeChains(SrcChain, DstChain, 0, MergeTypeT::X_Y);
820         }
821       }
822     }
823   }
824 
825   /// Compute the Ext-TSP score for a given node order and a list of jumps.
826   double extTSPScore(const MergedNodesT &Nodes,
827                      const MergedJumpsT &Jumps) const {
828     uint64_t CurAddr = 0;
829     Nodes.forEach([&](const NodeT *Node) {
830       Node->EstimatedAddr = CurAddr;
831       CurAddr += Node->Size;
832     });
833 
834     double Score = 0;
835     Jumps.forEach([&](const JumpT *Jump) {
836       const NodeT *SrcBlock = Jump->Source;
837       const NodeT *DstBlock = Jump->Target;
838       Score += ::extTSPScore(SrcBlock->EstimatedAddr, SrcBlock->Size,
839                              DstBlock->EstimatedAddr, Jump->ExecutionCount,
840                              Jump->IsConditional);
841     });
842     return Score;
843   }
844 
845   /// Compute the gain of merging two chains.
846   ///
847   /// The function considers all possible ways of merging two chains and
848   /// computes the one having the largest increase in ExtTSP objective. The
849   /// result is a pair with the first element being the gain and the second
850   /// element being the corresponding merging type.
851   MergeGainT getBestMergeGain(ChainT *ChainPred, ChainT *ChainSucc,
852                               ChainEdge *Edge) const {
853     if (Edge->hasCachedMergeGain(ChainPred, ChainSucc))
854       return Edge->getCachedMergeGain(ChainPred, ChainSucc);
855 
856     assert(!Edge->jumps().empty() && "trying to merge chains w/o jumps");
857     // Precompute jumps between ChainPred and ChainSucc.
858     ChainEdge *EdgePP = ChainPred->getEdge(ChainPred);
859     MergedJumpsT Jumps(&Edge->jumps(), EdgePP ? &EdgePP->jumps() : nullptr);
860 
861     // This object holds the best chosen gain of merging two chains.
862     MergeGainT Gain = MergeGainT();
863 
864     /// Given a merge offset and a list of merge types, try to merge two chains
865     /// and update Gain with a better alternative.
866     auto tryChainMerging = [&](size_t Offset,
867                                const std::vector<MergeTypeT> &MergeTypes) {
868       // Skip merging corresponding to concatenation w/o splitting.
869       if (Offset == 0 || Offset == ChainPred->Nodes.size())
870         return;
871       // Skip merging if it breaks Forced successors.
872       NodeT *Node = ChainPred->Nodes[Offset - 1];
873       if (Node->ForcedSucc != nullptr)
874         return;
875       // Apply the merge, compute the corresponding gain, and update the best
876       // value, if the merge is beneficial.
877       for (const MergeTypeT &MergeType : MergeTypes) {
878         Gain.updateIfLessThan(
879             computeMergeGain(ChainPred, ChainSucc, Jumps, Offset, MergeType));
880       }
881     };
882 
883     // Try to concatenate two chains w/o splitting.
884     Gain.updateIfLessThan(
885         computeMergeGain(ChainPred, ChainSucc, Jumps, 0, MergeTypeT::X_Y));
886 
887     // Attach (a part of) ChainPred before the first node of ChainSucc.
888     for (JumpT *Jump : ChainSucc->Nodes.front()->InJumps) {
889       const NodeT *SrcBlock = Jump->Source;
890       if (SrcBlock->CurChain != ChainPred)
891         continue;
892       size_t Offset = SrcBlock->CurIndex + 1;
893       tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::X2_X1_Y});
894     }
895 
896     // Attach (a part of) ChainPred after the last node of ChainSucc.
897     for (JumpT *Jump : ChainSucc->Nodes.back()->OutJumps) {
898       const NodeT *DstBlock = Jump->Target;
899       if (DstBlock->CurChain != ChainPred)
900         continue;
901       size_t Offset = DstBlock->CurIndex;
902       tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1});
903     }
904 
905     // Try to break ChainPred in various ways and concatenate with ChainSucc.
906     if (ChainPred->Nodes.size() <= ChainSplitThreshold) {
907       for (size_t Offset = 1; Offset < ChainPred->Nodes.size(); Offset++) {
908         // Do not split the chain along a fall-through jump. One of the two
909         // loops above may still "break" such a jump whenever it results in a
910         // new fall-through.
911         const NodeT *BB = ChainPred->Nodes[Offset - 1];
912         const NodeT *BB2 = ChainPred->Nodes[Offset];
913         if (BB->isSuccessor(BB2))
914           continue;
915 
916         // In practice, applying X2_Y_X1 merging almost never provides benefits;
917         // thus, we exclude it from consideration to reduce the search space.
918         tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1,
919                                  MergeTypeT::X2_X1_Y});
920       }
921     }
922 
923     Edge->setCachedMergeGain(ChainPred, ChainSucc, Gain);
924     return Gain;
925   }
926 
927   /// Compute the score gain of merging two chains, respecting a given
928   /// merge 'type' and 'offset'.
929   ///
930   /// The two chains are not modified in the method.
931   MergeGainT computeMergeGain(const ChainT *ChainPred, const ChainT *ChainSucc,
932                               const MergedJumpsT &Jumps, size_t MergeOffset,
933                               MergeTypeT MergeType) const {
934     MergedNodesT MergedNodes =
935         mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType);
936 
937     // Do not allow a merge that does not preserve the original entry point.
938     if ((ChainPred->isEntry() || ChainSucc->isEntry()) &&
939         !MergedNodes.getFirstNode()->isEntry())
940       return MergeGainT();
941 
942     // The gain for the new chain.
943     double NewScore = extTSPScore(MergedNodes, Jumps);
944     double CurScore = ChainPred->Score;
945     return MergeGainT(NewScore - CurScore, MergeOffset, MergeType);
946   }
947 
948   /// Merge chain From into chain Into, update the list of active chains,
949   /// adjacency information, and the corresponding cached values.
950   void mergeChains(ChainT *Into, ChainT *From, size_t MergeOffset,
951                    MergeTypeT MergeType) {
952     assert(Into != From && "a chain cannot be merged with itself");
953 
954     // Merge the nodes.
955     MergedNodesT MergedNodes =
956         mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType);
957     Into->merge(From, MergedNodes.getNodes());
958 
959     // Merge the edges.
960     Into->mergeEdges(From);
961     From->clear();
962 
963     // Update cached ext-tsp score for the new chain.
964     ChainEdge *SelfEdge = Into->getEdge(Into);
965     if (SelfEdge != nullptr) {
966       MergedNodes = MergedNodesT(Into->Nodes.begin(), Into->Nodes.end());
967       MergedJumpsT MergedJumps(&SelfEdge->jumps());
968       Into->Score = extTSPScore(MergedNodes, MergedJumps);
969     }
970 
971     // Remove the chain from the list of active chains.
972     llvm::erase(HotChains, From);
973 
974     // Invalidate caches.
975     for (auto EdgeIt : Into->Edges)
976       EdgeIt.second->invalidateCache();
977   }
978 
979   /// Concatenate all chains into the final order.
980   std::vector<uint64_t> concatChains() {
981     // Collect non-empty chains.
982     std::vector<const ChainT *> SortedChains;
983     for (ChainT &Chain : AllChains) {
984       if (!Chain.Nodes.empty())
985         SortedChains.push_back(&Chain);
986     }
987 
988     // Sorting chains by density in the decreasing order.
989     llvm::sort(SortedChains, [&](const ChainT *L, const ChainT *R) {
990       // Place the entry point at the beginning of the order.
991       if (L->isEntry() != R->isEntry())
992         return L->isEntry();
993 
994       // Compare by density and break ties by chain identifiers.
995       return std::make_tuple(-L->density(), L->Id) <
996              std::make_tuple(-R->density(), R->Id);
997     });
998 
999     // Collect the nodes in the order specified by their chains.
1000     std::vector<uint64_t> Order;
1001     Order.reserve(NumNodes);
1002     for (const ChainT *Chain : SortedChains)
1003       for (NodeT *Node : Chain->Nodes)
1004         Order.push_back(Node->Index);
1005     return Order;
1006   }
1007 
1008 private:
1009   /// The number of nodes in the graph.
1010   const size_t NumNodes;
1011 
1012   /// Successors of each node.
1013   std::vector<std::vector<uint64_t>> SuccNodes;
1014 
1015   /// Predecessors of each node.
1016   std::vector<std::vector<uint64_t>> PredNodes;
1017 
1018   /// All nodes (basic blocks) in the graph.
1019   std::vector<NodeT> AllNodes;
1020 
1021   /// All jumps between the nodes.
1022   std::vector<JumpT> AllJumps;
1023 
1024   /// All chains of nodes.
1025   std::vector<ChainT> AllChains;
1026 
1027   /// All edges between the chains.
1028   std::vector<ChainEdge> AllEdges;
1029 
1030   /// Active chains. The vector gets updated at runtime when chains are merged.
1031   std::vector<ChainT *> HotChains;
1032 };
1033 
1034 /// The implementation of the Cache-Directed Sort (CDSort) algorithm for
1035 /// ordering functions represented by a call graph.
1036 class CDSortImpl {
1037 public:
1038   CDSortImpl(const CDSortConfig &Config, ArrayRef<uint64_t> NodeSizes,
1039              ArrayRef<uint64_t> NodeCounts, ArrayRef<EdgeCount> EdgeCounts,
1040              ArrayRef<uint64_t> EdgeOffsets)
1041       : Config(Config), NumNodes(NodeSizes.size()) {
1042     initialize(NodeSizes, NodeCounts, EdgeCounts, EdgeOffsets);
1043   }
1044 
1045   /// Run the algorithm and return an ordered set of function clusters.
1046   std::vector<uint64_t> run() {
1047     // Merge pairs of chains while improving the objective.
1048     mergeChainPairs();
1049 
1050     // Collect nodes from all the chains.
1051     return concatChains();
1052   }
1053 
1054 private:
1055   /// Initialize the algorithm's data structures.
1056   void initialize(const ArrayRef<uint64_t> &NodeSizes,
1057                   const ArrayRef<uint64_t> &NodeCounts,
1058                   const ArrayRef<EdgeCount> &EdgeCounts,
1059                   const ArrayRef<uint64_t> &EdgeOffsets) {
1060     // Initialize nodes.
1061     AllNodes.reserve(NumNodes);
1062     for (uint64_t Node = 0; Node < NumNodes; Node++) {
1063       uint64_t Size = std::max<uint64_t>(NodeSizes[Node], 1ULL);
1064       uint64_t ExecutionCount = NodeCounts[Node];
1065       AllNodes.emplace_back(Node, Size, ExecutionCount);
1066       TotalSamples += ExecutionCount;
1067       if (ExecutionCount > 0)
1068         TotalSize += Size;
1069     }
1070 
1071     // Initialize jumps between the nodes.
1072     SuccNodes.resize(NumNodes);
1073     PredNodes.resize(NumNodes);
1074     AllJumps.reserve(EdgeCounts.size());
1075     for (size_t I = 0; I < EdgeCounts.size(); I++) {
1076       auto [Pred, Succ, Count] = EdgeCounts[I];
1077       // Ignore recursive calls.
1078       if (Pred == Succ)
1079         continue;
1080 
1081       SuccNodes[Pred].push_back(Succ);
1082       PredNodes[Succ].push_back(Pred);
1083       if (Count > 0) {
1084         NodeT &PredNode = AllNodes[Pred];
1085         NodeT &SuccNode = AllNodes[Succ];
1086         AllJumps.emplace_back(&PredNode, &SuccNode, Count);
1087         AllJumps.back().Offset = EdgeOffsets[I];
1088         SuccNode.InJumps.push_back(&AllJumps.back());
1089         PredNode.OutJumps.push_back(&AllJumps.back());
1090         // Adjust execution counts.
1091         PredNode.ExecutionCount = std::max(PredNode.ExecutionCount, Count);
1092         SuccNode.ExecutionCount = std::max(SuccNode.ExecutionCount, Count);
1093       }
1094     }
1095 
1096     // Initialize chains.
1097     AllChains.reserve(NumNodes);
1098     for (NodeT &Node : AllNodes) {
1099       // Adjust execution counts.
1100       Node.ExecutionCount = std::max(Node.ExecutionCount, Node.inCount());
1101       Node.ExecutionCount = std::max(Node.ExecutionCount, Node.outCount());
1102       // Create chain.
1103       AllChains.emplace_back(Node.Index, &Node);
1104       Node.CurChain = &AllChains.back();
1105     }
1106 
1107     // Initialize chain edges.
1108     AllEdges.reserve(AllJumps.size());
1109     for (NodeT &PredNode : AllNodes) {
1110       for (JumpT *Jump : PredNode.OutJumps) {
1111         NodeT *SuccNode = Jump->Target;
1112         ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain);
1113         // This edge is already present in the graph.
1114         if (CurEdge != nullptr) {
1115           assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr);
1116           CurEdge->appendJump(Jump);
1117           continue;
1118         }
1119         // This is a new edge.
1120         AllEdges.emplace_back(Jump);
1121         PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back());
1122         SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back());
1123       }
1124     }
1125   }
1126 
1127   /// Merge pairs of chains while there is an improvement in the objective.
1128   void mergeChainPairs() {
1129     // Create a priority queue containing all edges ordered by the merge gain.
1130     auto GainComparator = [](ChainEdge *L, ChainEdge *R) {
1131       return std::make_tuple(-L->gain(), L->srcChain()->Id, L->dstChain()->Id) <
1132              std::make_tuple(-R->gain(), R->srcChain()->Id, R->dstChain()->Id);
1133     };
1134     std::set<ChainEdge *, decltype(GainComparator)> Queue(GainComparator);
1135 
1136     // Insert the edges into the queue.
1137     [[maybe_unused]] size_t NumActiveChains = 0;
1138     for (NodeT &Node : AllNodes) {
1139       if (Node.ExecutionCount == 0)
1140         continue;
1141       ++NumActiveChains;
1142       for (const auto &[_, Edge] : Node.CurChain->Edges) {
1143         // Ignore self-edges.
1144         if (Edge->isSelfEdge())
1145           continue;
1146         // Ignore already processed edges.
1147         if (Edge->gain() != -1.0)
1148           continue;
1149 
1150         // Compute the gain of merging the two chains.
1151         MergeGainT Gain = getBestMergeGain(Edge);
1152         Edge->setMergeGain(Gain);
1153 
1154         if (Edge->gain() > EPS)
1155           Queue.insert(Edge);
1156       }
1157     }
1158 
1159     // Merge the chains while the gain of merging is positive.
1160     while (!Queue.empty()) {
1161       // Extract the best (top) edge for merging.
1162       ChainEdge *BestEdge = *Queue.begin();
1163       Queue.erase(Queue.begin());
1164       ChainT *BestSrcChain = BestEdge->srcChain();
1165       ChainT *BestDstChain = BestEdge->dstChain();
1166 
1167       // Remove outdated edges from the queue.
1168       for (const auto &[_, ChainEdge] : BestSrcChain->Edges)
1169         Queue.erase(ChainEdge);
1170       for (const auto &[_, ChainEdge] : BestDstChain->Edges)
1171         Queue.erase(ChainEdge);
1172 
1173       // Merge the best pair of chains.
1174       MergeGainT BestGain = BestEdge->getMergeGain();
1175       mergeChains(BestSrcChain, BestDstChain, BestGain.mergeOffset(),
1176                   BestGain.mergeType());
1177       --NumActiveChains;
1178 
1179       // Insert newly created edges into the queue.
1180       for (const auto &[_, Edge] : BestSrcChain->Edges) {
1181         // Ignore loop edges.
1182         if (Edge->isSelfEdge())
1183           continue;
1184         if (Edge->srcChain()->numBlocks() + Edge->dstChain()->numBlocks() >
1185             Config.MaxChainSize)
1186           continue;
1187 
1188         // Compute the gain of merging the two chains.
1189         MergeGainT Gain = getBestMergeGain(Edge);
1190         Edge->setMergeGain(Gain);
1191 
1192         if (Edge->gain() > EPS)
1193           Queue.insert(Edge);
1194       }
1195     }
1196 
1197     LLVM_DEBUG(dbgs() << "Cache-directed function sorting reduced the number"
1198                       << " of chains from " << NumNodes << " to "
1199                       << NumActiveChains << "\n");
1200   }
1201 
1202   /// Compute the gain of merging two chains.
1203   ///
1204   /// The function considers all possible ways of merging two chains and
1205   /// computes the one having the largest increase in ExtTSP objective. The
1206   /// result is a pair with the first element being the gain and the second
1207   /// element being the corresponding merging type.
1208   MergeGainT getBestMergeGain(ChainEdge *Edge) const {
1209     assert(!Edge->jumps().empty() && "trying to merge chains w/o jumps");
1210     // Precompute jumps between ChainPred and ChainSucc.
1211     MergedJumpsT Jumps(&Edge->jumps());
1212     ChainT *SrcChain = Edge->srcChain();
1213     ChainT *DstChain = Edge->dstChain();
1214 
1215     // This object holds the best currently chosen gain of merging two chains.
1216     MergeGainT Gain = MergeGainT();
1217 
1218     /// Given a list of merge types, try to merge two chains and update Gain
1219     /// with a better alternative.
1220     auto tryChainMerging = [&](const std::vector<MergeTypeT> &MergeTypes) {
1221       // Apply the merge, compute the corresponding gain, and update the best
1222       // value, if the merge is beneficial.
1223       for (const MergeTypeT &MergeType : MergeTypes) {
1224         MergeGainT NewGain =
1225             computeMergeGain(SrcChain, DstChain, Jumps, MergeType);
1226 
1227         // When forward and backward gains are the same, prioritize merging that
1228         // preserves the original order of the functions in the binary.
1229         if (std::abs(Gain.score() - NewGain.score()) < EPS) {
1230           if ((MergeType == MergeTypeT::X_Y && SrcChain->Id < DstChain->Id) ||
1231               (MergeType == MergeTypeT::Y_X && SrcChain->Id > DstChain->Id)) {
1232             Gain = NewGain;
1233           }
1234         } else if (NewGain.score() > Gain.score() + EPS) {
1235           Gain = NewGain;
1236         }
1237       }
1238     };
1239 
1240     // Try to concatenate two chains w/o splitting.
1241     tryChainMerging({MergeTypeT::X_Y, MergeTypeT::Y_X});
1242 
1243     return Gain;
1244   }
1245 
1246   /// Compute the score gain of merging two chains, respecting a given type.
1247   ///
1248   /// The two chains are not modified in the method.
1249   MergeGainT computeMergeGain(ChainT *ChainPred, ChainT *ChainSucc,
1250                               const MergedJumpsT &Jumps,
1251                               MergeTypeT MergeType) const {
1252     // This doesn't depend on the ordering of the nodes
1253     double FreqGain = freqBasedLocalityGain(ChainPred, ChainSucc);
1254 
1255     // Merge offset is always 0, as the chains are not split.
1256     size_t MergeOffset = 0;
1257     auto MergedBlocks =
1258         mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType);
1259     double DistGain = distBasedLocalityGain(MergedBlocks, Jumps);
1260 
1261     double GainScore = DistGain + Config.FrequencyScale * FreqGain;
1262     // Scale the result to increase the importance of merging short chains.
1263     if (GainScore >= 0.0)
1264       GainScore /= std::min(ChainPred->Size, ChainSucc->Size);
1265 
1266     return MergeGainT(GainScore, MergeOffset, MergeType);
1267   }
1268 
1269   /// Compute the change of the frequency locality after merging the chains.
1270   double freqBasedLocalityGain(ChainT *ChainPred, ChainT *ChainSucc) const {
1271     auto missProbability = [&](double ChainDensity) {
1272       double PageSamples = ChainDensity * Config.CacheSize;
1273       if (PageSamples >= TotalSamples)
1274         return 0.0;
1275       double P = PageSamples / TotalSamples;
1276       return pow(1.0 - P, static_cast<double>(Config.CacheEntries));
1277     };
1278 
1279     // Cache misses on the chains before merging.
1280     double CurScore =
1281         ChainPred->ExecutionCount * missProbability(ChainPred->density()) +
1282         ChainSucc->ExecutionCount * missProbability(ChainSucc->density());
1283 
1284     // Cache misses on the merged chain
1285     double MergedCounts = ChainPred->ExecutionCount + ChainSucc->ExecutionCount;
1286     double MergedSize = ChainPred->Size + ChainSucc->Size;
1287     double MergedDensity = static_cast<double>(MergedCounts) / MergedSize;
1288     double NewScore = MergedCounts * missProbability(MergedDensity);
1289 
1290     return CurScore - NewScore;
1291   }
1292 
1293   /// Compute the distance locality for a jump / call.
1294   double distScore(uint64_t SrcAddr, uint64_t DstAddr, uint64_t Count) const {
1295     uint64_t Dist = SrcAddr <= DstAddr ? DstAddr - SrcAddr : SrcAddr - DstAddr;
1296     double D = Dist == 0 ? 0.1 : static_cast<double>(Dist);
1297     return static_cast<double>(Count) * std::pow(D, -Config.DistancePower);
1298   }
1299 
1300   /// Compute the change of the distance locality after merging the chains.
1301   double distBasedLocalityGain(const MergedNodesT &Nodes,
1302                                const MergedJumpsT &Jumps) const {
1303     uint64_t CurAddr = 0;
1304     Nodes.forEach([&](const NodeT *Node) {
1305       Node->EstimatedAddr = CurAddr;
1306       CurAddr += Node->Size;
1307     });
1308 
1309     double CurScore = 0;
1310     double NewScore = 0;
1311     Jumps.forEach([&](const JumpT *Jump) {
1312       uint64_t SrcAddr = Jump->Source->EstimatedAddr + Jump->Offset;
1313       uint64_t DstAddr = Jump->Target->EstimatedAddr;
1314       NewScore += distScore(SrcAddr, DstAddr, Jump->ExecutionCount);
1315       CurScore += distScore(0, TotalSize, Jump->ExecutionCount);
1316     });
1317     return NewScore - CurScore;
1318   }
1319 
1320   /// Merge chain From into chain Into, update the list of active chains,
1321   /// adjacency information, and the corresponding cached values.
1322   void mergeChains(ChainT *Into, ChainT *From, size_t MergeOffset,
1323                    MergeTypeT MergeType) {
1324     assert(Into != From && "a chain cannot be merged with itself");
1325 
1326     // Merge the nodes.
1327     MergedNodesT MergedNodes =
1328         mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType);
1329     Into->merge(From, MergedNodes.getNodes());
1330 
1331     // Merge the edges.
1332     Into->mergeEdges(From);
1333     From->clear();
1334   }
1335 
1336   /// Concatenate all chains into the final order.
1337   std::vector<uint64_t> concatChains() {
1338     // Collect chains and calculate density stats for their sorting.
1339     std::vector<const ChainT *> SortedChains;
1340     DenseMap<const ChainT *, double> ChainDensity;
1341     for (ChainT &Chain : AllChains) {
1342       if (!Chain.Nodes.empty()) {
1343         SortedChains.push_back(&Chain);
1344         // Using doubles to avoid overflow of ExecutionCounts.
1345         double Size = 0;
1346         double ExecutionCount = 0;
1347         for (NodeT *Node : Chain.Nodes) {
1348           Size += static_cast<double>(Node->Size);
1349           ExecutionCount += static_cast<double>(Node->ExecutionCount);
1350         }
1351         assert(Size > 0 && "a chain of zero size");
1352         ChainDensity[&Chain] = ExecutionCount / Size;
1353       }
1354     }
1355 
1356     // Sort chains by density in the decreasing order.
1357     llvm::sort(SortedChains, [&](const ChainT *L, const ChainT *R) {
1358       const double DL = ChainDensity[L];
1359       const double DR = ChainDensity[R];
1360       // Compare by density and break ties by chain identifiers.
1361       return std::make_tuple(-DL, L->Id) < std::make_tuple(-DR, R->Id);
1362     });
1363 
1364     // Collect the nodes in the order specified by their chains.
1365     std::vector<uint64_t> Order;
1366     Order.reserve(NumNodes);
1367     for (const ChainT *Chain : SortedChains)
1368       for (NodeT *Node : Chain->Nodes)
1369         Order.push_back(Node->Index);
1370     return Order;
1371   }
1372 
1373 private:
1374   /// Config for the algorithm.
1375   const CDSortConfig Config;
1376 
1377   /// The number of nodes in the graph.
1378   const size_t NumNodes;
1379 
1380   /// Successors of each node.
1381   std::vector<std::vector<uint64_t>> SuccNodes;
1382 
1383   /// Predecessors of each node.
1384   std::vector<std::vector<uint64_t>> PredNodes;
1385 
1386   /// All nodes (functions) in the graph.
1387   std::vector<NodeT> AllNodes;
1388 
1389   /// All jumps (function calls) between the nodes.
1390   std::vector<JumpT> AllJumps;
1391 
1392   /// All chains of nodes.
1393   std::vector<ChainT> AllChains;
1394 
1395   /// All edges between the chains.
1396   std::vector<ChainEdge> AllEdges;
1397 
1398   /// The total number of samples in the graph.
1399   uint64_t TotalSamples{0};
1400 
1401   /// The total size of the nodes in the graph.
1402   uint64_t TotalSize{0};
1403 };
1404 
1405 } // end of anonymous namespace
1406 
1407 std::vector<uint64_t>
1408 codelayout::computeExtTspLayout(ArrayRef<uint64_t> NodeSizes,
1409                                 ArrayRef<uint64_t> NodeCounts,
1410                                 ArrayRef<EdgeCount> EdgeCounts) {
1411   // Verify correctness of the input data.
1412   assert(NodeCounts.size() == NodeSizes.size() && "Incorrect input");
1413   assert(NodeSizes.size() > 2 && "Incorrect input");
1414 
1415   // Apply the reordering algorithm.
1416   ExtTSPImpl Alg(NodeSizes, NodeCounts, EdgeCounts);
1417   std::vector<uint64_t> Result = Alg.run();
1418 
1419   // Verify correctness of the output.
1420   assert(Result.front() == 0 && "Original entry point is not preserved");
1421   assert(Result.size() == NodeSizes.size() && "Incorrect size of layout");
1422   return Result;
1423 }
1424 
1425 double codelayout::calcExtTspScore(ArrayRef<uint64_t> Order,
1426                                    ArrayRef<uint64_t> NodeSizes,
1427                                    ArrayRef<uint64_t> NodeCounts,
1428                                    ArrayRef<EdgeCount> EdgeCounts) {
1429   // Estimate addresses of the blocks in memory.
1430   std::vector<uint64_t> Addr(NodeSizes.size(), 0);
1431   for (size_t Idx = 1; Idx < Order.size(); Idx++) {
1432     Addr[Order[Idx]] = Addr[Order[Idx - 1]] + NodeSizes[Order[Idx - 1]];
1433   }
1434   std::vector<uint64_t> OutDegree(NodeSizes.size(), 0);
1435   for (auto Edge : EdgeCounts)
1436     ++OutDegree[Edge.src];
1437 
1438   // Increase the score for each jump.
1439   double Score = 0;
1440   for (auto Edge : EdgeCounts) {
1441     bool IsConditional = OutDegree[Edge.src] > 1;
1442     Score += ::extTSPScore(Addr[Edge.src], NodeSizes[Edge.src], Addr[Edge.dst],
1443                            Edge.count, IsConditional);
1444   }
1445   return Score;
1446 }
1447 
1448 double codelayout::calcExtTspScore(ArrayRef<uint64_t> NodeSizes,
1449                                    ArrayRef<uint64_t> NodeCounts,
1450                                    ArrayRef<EdgeCount> EdgeCounts) {
1451   std::vector<uint64_t> Order(NodeSizes.size());
1452   for (size_t Idx = 0; Idx < NodeSizes.size(); Idx++) {
1453     Order[Idx] = Idx;
1454   }
1455   return calcExtTspScore(Order, NodeSizes, NodeCounts, EdgeCounts);
1456 }
1457 
1458 std::vector<uint64_t> codelayout::computeCacheDirectedLayout(
1459     const CDSortConfig &Config, ArrayRef<uint64_t> FuncSizes,
1460     ArrayRef<uint64_t> FuncCounts, ArrayRef<EdgeCount> CallCounts,
1461     ArrayRef<uint64_t> CallOffsets) {
1462   // Verify correctness of the input data.
1463   assert(FuncCounts.size() == FuncSizes.size() && "Incorrect input");
1464 
1465   // Apply the reordering algorithm.
1466   CDSortImpl Alg(Config, FuncSizes, FuncCounts, CallCounts, CallOffsets);
1467   std::vector<uint64_t> Result = Alg.run();
1468   assert(Result.size() == FuncSizes.size() && "Incorrect size of layout");
1469   return Result;
1470 }
1471 
1472 std::vector<uint64_t> codelayout::computeCacheDirectedLayout(
1473     ArrayRef<uint64_t> FuncSizes, ArrayRef<uint64_t> FuncCounts,
1474     ArrayRef<EdgeCount> CallCounts, ArrayRef<uint64_t> CallOffsets) {
1475   CDSortConfig Config;
1476   // Populate the config from the command-line options.
1477   if (CacheEntries.getNumOccurrences() > 0)
1478     Config.CacheEntries = CacheEntries;
1479   if (CacheSize.getNumOccurrences() > 0)
1480     Config.CacheSize = CacheSize;
1481   if (CDMaxChainSize.getNumOccurrences() > 0)
1482     Config.MaxChainSize = CDMaxChainSize;
1483   if (DistancePower.getNumOccurrences() > 0)
1484     Config.DistancePower = DistancePower;
1485   if (FrequencyScale.getNumOccurrences() > 0)
1486     Config.FrequencyScale = FrequencyScale;
1487   return computeCacheDirectedLayout(Config, FuncSizes, FuncCounts, CallCounts,
1488                                     CallOffsets);
1489 }
1490