xref: /llvm-project/llvm/lib/Transforms/Utils/CodeLayout.cpp (revision fbec1c2a08ce2ae9750ddf3cecc86c5dd2bbc9d8)
1f573f686Sspupyrev //===- CodeLayout.cpp - Implementation of code layout algorithms ----------===//
2f573f686Sspupyrev //
3f573f686Sspupyrev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f573f686Sspupyrev // See https://llvm.org/LICENSE.txt for license information.
5f573f686Sspupyrev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f573f686Sspupyrev //
7f573f686Sspupyrev //===----------------------------------------------------------------------===//
8f573f686Sspupyrev //
9a7e13a99Sspupyrev // The file implements "cache-aware" layout algorithms of basic blocks and
10a7e13a99Sspupyrev // functions in a binary.
11f573f686Sspupyrev //
// The algorithm tries to find a layout of nodes (basic blocks) of a given CFG
// optimizing jump locality and thus processor I-cache utilization. This is
// achieved via increasing the number of fall-through jumps and co-locating
// frequently executed nodes together. The name follows the underlying
// optimization problem, Extended-TSP, which is a generalization of the
// classical (maximum) Traveling Salesman Problem.
18f573f686Sspupyrev //
19f573f686Sspupyrev // The algorithm is a greedy heuristic that works with chains (ordered lists)
20f573f686Sspupyrev // of basic blocks. Initially all chains are isolated basic blocks. On every
21f573f686Sspupyrev // iteration, we pick a pair of chains whose merging yields the biggest increase
22f573f686Sspupyrev // in the ExtTSP score, which models how i-cache "friendly" a specific chain is.
23f573f686Sspupyrev // A pair of chains giving the maximum gain is merged into a new chain. The
24f573f686Sspupyrev // procedure stops when there is only one chain left, or when merging does not
25f573f686Sspupyrev // increase ExtTSP. In the latter case, the remaining chains are sorted by
26f573f686Sspupyrev // density in the decreasing order.
27f573f686Sspupyrev //
28f573f686Sspupyrev // An important aspect is the way two chains are merged. Unlike earlier
29f573f686Sspupyrev // algorithms (e.g., based on the approach of Pettis-Hansen), two
30f573f686Sspupyrev // chains, X and Y, are first split into three, X1, X2, and Y. Then we
31f573f686Sspupyrev // consider all possible ways of gluing the three chains (e.g., X1YX2, X1X2Y,
32f573f686Sspupyrev // X2X1Y, X2YX1, YX1X2, YX2X1) and choose the one producing the largest score.
33f573f686Sspupyrev // This improves the quality of the final result (the search space is larger)
34f573f686Sspupyrev // while keeping the implementation sufficiently fast.
35f573f686Sspupyrev //
36f573f686Sspupyrev // Reference:
37f573f686Sspupyrev //   * A. Newell and S. Pupyrev, Improved Basic Block Reordering,
38f573f686Sspupyrev //     IEEE Transactions on Computers, 2020
398d5b694dSspupyrev //     https://arxiv.org/abs/1809.04676
40f573f686Sspupyrev //
41f573f686Sspupyrev //===----------------------------------------------------------------------===//
42f573f686Sspupyrev 
43f573f686Sspupyrev #include "llvm/Transforms/Utils/CodeLayout.h"
44f573f686Sspupyrev #include "llvm/Support/CommandLine.h"
45a7e13a99Sspupyrev #include "llvm/Support/Debug.h"
46f573f686Sspupyrev 
47e3518210SYevgeny Rouban #include <cmath>
48bc59faa8Sspupyrev #include <set>
49e3518210SYevgeny Rouban 
50f573f686Sspupyrev using namespace llvm;
516b8d04c2SFangrui Song using namespace llvm::codelayout;
526b8d04c2SFangrui Song 
53f573f686Sspupyrev #define DEBUG_TYPE "code-layout"
54f573f686Sspupyrev 
551e692113SFangrui Song namespace llvm {
56dee058c6SHongtao Yu cl::opt<bool> EnableExtTspBlockPlacement(
57dee058c6SHongtao Yu     "enable-ext-tsp-block-placement", cl::Hidden, cl::init(false),
58dee058c6SHongtao Yu     cl::desc("Enable machine block placement based on the ext-tsp model, "
59dee058c6SHongtao Yu              "optimizing I-cache utilization."));
60dee058c6SHongtao Yu 
61bcdc0477Sspupyrev cl::opt<bool> ApplyExtTspWithoutProfile(
62bcdc0477Sspupyrev     "ext-tsp-apply-without-profile",
63bcdc0477Sspupyrev     cl::desc("Whether to apply ext-tsp placement for instances w/o profile"),
6436c7d79dSFangrui Song     cl::init(true), cl::Hidden);
651e692113SFangrui Song } // namespace llvm
66bcdc0477Sspupyrev 
67bc59faa8Sspupyrev // Algorithm-specific params for Ext-TSP. The values are tuned for the best
68bc59faa8Sspupyrev // performance of large-scale front-end bound binaries.
698d5b694dSspupyrev static cl::opt<double> ForwardWeightCond(
708d5b694dSspupyrev     "ext-tsp-forward-weight-cond", cl::ReallyHidden, cl::init(0.1),
718d5b694dSspupyrev     cl::desc("The weight of conditional forward jumps for ExtTSP value"));
72f573f686Sspupyrev 
738d5b694dSspupyrev static cl::opt<double> ForwardWeightUncond(
748d5b694dSspupyrev     "ext-tsp-forward-weight-uncond", cl::ReallyHidden, cl::init(0.1),
758d5b694dSspupyrev     cl::desc("The weight of unconditional forward jumps for ExtTSP value"));
768d5b694dSspupyrev 
778d5b694dSspupyrev static cl::opt<double> BackwardWeightCond(
788d5b694dSspupyrev     "ext-tsp-backward-weight-cond", cl::ReallyHidden, cl::init(0.1),
79a7e13a99Sspupyrev     cl::desc("The weight of conditional backward jumps for ExtTSP value"));
808d5b694dSspupyrev 
818d5b694dSspupyrev static cl::opt<double> BackwardWeightUncond(
828d5b694dSspupyrev     "ext-tsp-backward-weight-uncond", cl::ReallyHidden, cl::init(0.1),
83a7e13a99Sspupyrev     cl::desc("The weight of unconditional backward jumps for ExtTSP value"));
848d5b694dSspupyrev 
858d5b694dSspupyrev static cl::opt<double> FallthroughWeightCond(
868d5b694dSspupyrev     "ext-tsp-fallthrough-weight-cond", cl::ReallyHidden, cl::init(1.0),
878d5b694dSspupyrev     cl::desc("The weight of conditional fallthrough jumps for ExtTSP value"));
888d5b694dSspupyrev 
898d5b694dSspupyrev static cl::opt<double> FallthroughWeightUncond(
908d5b694dSspupyrev     "ext-tsp-fallthrough-weight-uncond", cl::ReallyHidden, cl::init(1.05),
918d5b694dSspupyrev     cl::desc("The weight of unconditional fallthrough jumps for ExtTSP value"));
92f573f686Sspupyrev 
93f573f686Sspupyrev static cl::opt<unsigned> ForwardDistance(
948d5b694dSspupyrev     "ext-tsp-forward-distance", cl::ReallyHidden, cl::init(1024),
95f573f686Sspupyrev     cl::desc("The maximum distance (in bytes) of a forward jump for ExtTSP"));
96f573f686Sspupyrev 
97f573f686Sspupyrev static cl::opt<unsigned> BackwardDistance(
988d5b694dSspupyrev     "ext-tsp-backward-distance", cl::ReallyHidden, cl::init(640),
99f573f686Sspupyrev     cl::desc("The maximum distance (in bytes) of a backward jump for ExtTSP"));
100f573f686Sspupyrev 
101bcdc0477Sspupyrev // The maximum size of a chain created by the algorithm. The size is bounded
102b90fcafcSspupyrev // so that the algorithm can efficiently process extremely large instances.
103bcdc0477Sspupyrev static cl::opt<unsigned>
104cc2fbc64Sspupyrev     MaxChainSize("ext-tsp-max-chain-size", cl::ReallyHidden, cl::init(512),
105cc2fbc64Sspupyrev                  cl::desc("The maximum size of a chain to create"));
106bcdc0477Sspupyrev 
1075b39d8d3Sspupyrev // The maximum size of a chain for splitting. Larger values of the threshold
1085b39d8d3Sspupyrev // may yield better quality at the cost of worsen run-time.
1095b39d8d3Sspupyrev static cl::opt<unsigned> ChainSplitThreshold(
1105b39d8d3Sspupyrev     "ext-tsp-chain-split-threshold", cl::ReallyHidden, cl::init(128),
1115b39d8d3Sspupyrev     cl::desc("The maximum size of a chain to apply splitting"));
1125b39d8d3Sspupyrev 
113cc2fbc64Sspupyrev // The maximum ratio between densities of two chains for merging.
114cc2fbc64Sspupyrev static cl::opt<double> MaxMergeDensityRatio(
115cc2fbc64Sspupyrev     "ext-tsp-max-merge-density-ratio", cl::ReallyHidden, cl::init(100),
116cc2fbc64Sspupyrev     cl::desc("The maximum ratio between densities of two chains for merging"));
117f573f686Sspupyrev 
118f61179f8Sspupyrev // Algorithm-specific options for CDSort.
119f61179f8Sspupyrev static cl::opt<unsigned> CacheEntries("cdsort-cache-entries", cl::ReallyHidden,
120bc59faa8Sspupyrev                                       cl::desc("The size of the cache"));
121bc59faa8Sspupyrev 
122f61179f8Sspupyrev static cl::opt<unsigned> CacheSize("cdsort-cache-size", cl::ReallyHidden,
123bc59faa8Sspupyrev                                    cl::desc("The size of a line in the cache"));
124bc59faa8Sspupyrev 
125a2441837SFangrui Song static cl::opt<unsigned>
126a2441837SFangrui Song     CDMaxChainSize("cdsort-max-chain-size", cl::ReallyHidden,
127a2441837SFangrui Song                    cl::desc("The maximum size of a chain to create"));
128a2441837SFangrui Song 
129bc59faa8Sspupyrev static cl::opt<double> DistancePower(
130f61179f8Sspupyrev     "cdsort-distance-power", cl::ReallyHidden,
131bc59faa8Sspupyrev     cl::desc("The power exponent for the distance-based locality"));
132bc59faa8Sspupyrev 
133bc59faa8Sspupyrev static cl::opt<double> FrequencyScale(
134f61179f8Sspupyrev     "cdsort-frequency-scale", cl::ReallyHidden,
135bc59faa8Sspupyrev     cl::desc("The scale factor for the frequency-based locality"));
136bc59faa8Sspupyrev 
137f573f686Sspupyrev namespace {
138f573f686Sspupyrev 
// Tolerance used when comparing doubles.
constexpr double EPS{1e-8};
141f573f686Sspupyrev 
// Compute the Ext-TSP score for a given jump.
//
// The score is Weight * Count scaled by a factor that decreases linearly from
// 1 (at JumpDist == 0) to 0 (at JumpDist == JumpMaxDist). Jumps at or beyond
// JumpMaxDist contribute nothing.
double jumpExtTSPScore(uint64_t JumpDist, uint64_t JumpMaxDist, uint64_t Count,
                       double Weight) {
  // '>=' (rather than '>') also covers JumpDist == JumpMaxDist == 0, which
  // would otherwise evaluate 0.0 / 0 below and return NaN. For a non-zero
  // JumpMaxDist the factor at JumpDist == JumpMaxDist is exactly 0 anyway.
  if (JumpDist >= JumpMaxDist)
    return 0;
  const double Prob = 1.0 - static_cast<double>(JumpDist) / JumpMaxDist;
  return Weight * Prob * Count;
}
1508d5b694dSspupyrev 
151f573f686Sspupyrev // Compute the Ext-TSP score for a jump between a given pair of blocks,
152f573f686Sspupyrev // using their sizes, (estimated) addresses and the jump execution count.
153f573f686Sspupyrev double extTSPScore(uint64_t SrcAddr, uint64_t SrcSize, uint64_t DstAddr,
1548d5b694dSspupyrev                    uint64_t Count, bool IsConditional) {
155f573f686Sspupyrev   // Fallthrough
156f573f686Sspupyrev   if (SrcAddr + SrcSize == DstAddr) {
1578d5b694dSspupyrev     return jumpExtTSPScore(0, 1, Count,
1588d5b694dSspupyrev                            IsConditional ? FallthroughWeightCond
1598d5b694dSspupyrev                                          : FallthroughWeightUncond);
160f573f686Sspupyrev   }
161f573f686Sspupyrev   // Forward
162f573f686Sspupyrev   if (SrcAddr + SrcSize < DstAddr) {
1638d5b694dSspupyrev     const uint64_t Dist = DstAddr - (SrcAddr + SrcSize);
1648d5b694dSspupyrev     return jumpExtTSPScore(Dist, ForwardDistance, Count,
1658d5b694dSspupyrev                            IsConditional ? ForwardWeightCond
1668d5b694dSspupyrev                                          : ForwardWeightUncond);
167f573f686Sspupyrev   }
168f573f686Sspupyrev   // Backward
1698d5b694dSspupyrev   const uint64_t Dist = SrcAddr + SrcSize - DstAddr;
1708d5b694dSspupyrev   return jumpExtTSPScore(Dist, BackwardDistance, Count,
1718d5b694dSspupyrev                          IsConditional ? BackwardWeightCond
1728d5b694dSspupyrev                                        : BackwardWeightUncond);
173f573f686Sspupyrev }
174f573f686Sspupyrev 
/// The ways of merging two chains, X and Y. Chain X is split into X1 and X2,
/// and the pieces are concatenated with Y in the order the enumerator names.
enum class MergeTypeT : int {
  X_Y,
  Y_X,
  X1_Y_X2,
  Y_X2_X1,
  X2_X1_Y,
};
178f573f686Sspupyrev 
179f573f686Sspupyrev /// The gain of merging two chains, that is, the Ext-TSP score of the merge
180a7e13a99Sspupyrev /// together with the corresponding merge 'type' and 'offset'.
181a7e13a99Sspupyrev struct MergeGainT {
182a7e13a99Sspupyrev   explicit MergeGainT() = default;
183a7e13a99Sspupyrev   explicit MergeGainT(double Score, size_t MergeOffset, MergeTypeT MergeType)
184f573f686Sspupyrev       : Score(Score), MergeOffset(MergeOffset), MergeType(MergeType) {}
185f573f686Sspupyrev 
186f573f686Sspupyrev   double score() const { return Score; }
187f573f686Sspupyrev 
188f573f686Sspupyrev   size_t mergeOffset() const { return MergeOffset; }
189f573f686Sspupyrev 
190a7e13a99Sspupyrev   MergeTypeT mergeType() const { return MergeType; }
191a7e13a99Sspupyrev 
192a7e13a99Sspupyrev   void setMergeType(MergeTypeT Ty) { MergeType = Ty; }
193f573f686Sspupyrev 
194f573f686Sspupyrev   // Returns 'true' iff Other is preferred over this.
195a7e13a99Sspupyrev   bool operator<(const MergeGainT &Other) const {
196f573f686Sspupyrev     return (Other.Score > EPS && Other.Score > Score + EPS);
197f573f686Sspupyrev   }
198f573f686Sspupyrev 
199f573f686Sspupyrev   // Update the current gain if Other is preferred over this.
200a7e13a99Sspupyrev   void updateIfLessThan(const MergeGainT &Other) {
201f573f686Sspupyrev     if (*this < Other)
202f573f686Sspupyrev       *this = Other;
203f573f686Sspupyrev   }
204f573f686Sspupyrev 
205f573f686Sspupyrev private:
206f573f686Sspupyrev   double Score{-1.0};
207f573f686Sspupyrev   size_t MergeOffset{0};
208a7e13a99Sspupyrev   MergeTypeT MergeType{MergeTypeT::X_Y};
209f573f686Sspupyrev };
210f573f686Sspupyrev 
211a7e13a99Sspupyrev struct JumpT;
212a7e13a99Sspupyrev struct ChainT;
213a7e13a99Sspupyrev struct ChainEdge;
214f573f686Sspupyrev 
215a7e13a99Sspupyrev /// A node in the graph, typically corresponding to a basic block in the CFG or
216a7e13a99Sspupyrev /// a function in the call graph.
217a7e13a99Sspupyrev struct NodeT {
218a7e13a99Sspupyrev   NodeT(const NodeT &) = delete;
219a7e13a99Sspupyrev   NodeT(NodeT &&) = default;
220a7e13a99Sspupyrev   NodeT &operator=(const NodeT &) = delete;
221a7e13a99Sspupyrev   NodeT &operator=(NodeT &&) = default;
222f573f686Sspupyrev 
223b90fcafcSspupyrev   explicit NodeT(size_t Index, uint64_t Size, uint64_t Count)
224b90fcafcSspupyrev       : Index(Index), Size(Size), ExecutionCount(Count) {}
225a7e13a99Sspupyrev 
226f573f686Sspupyrev   bool isEntry() const { return Index == 0; }
227a7e13a99Sspupyrev 
228cc2fbc64Sspupyrev   // Check if Other is a successor of the node.
229cc2fbc64Sspupyrev   bool isSuccessor(const NodeT *Other) const;
230cc2fbc64Sspupyrev 
231a7e13a99Sspupyrev   // The total execution count of outgoing jumps.
232a7e13a99Sspupyrev   uint64_t outCount() const;
233a7e13a99Sspupyrev 
234a7e13a99Sspupyrev   // The total execution count of incoming jumps.
235a7e13a99Sspupyrev   uint64_t inCount() const;
236a7e13a99Sspupyrev 
237a7e13a99Sspupyrev   // The original index of the node in graph.
238a7e13a99Sspupyrev   size_t Index{0};
239a7e13a99Sspupyrev   // The index of the node in the current chain.
240a7e13a99Sspupyrev   size_t CurIndex{0};
241a7e13a99Sspupyrev   // The size of the node in the binary.
242a7e13a99Sspupyrev   uint64_t Size{0};
243a7e13a99Sspupyrev   // The execution count of the node in the profile data.
244a7e13a99Sspupyrev   uint64_t ExecutionCount{0};
245a7e13a99Sspupyrev   // The current chain of the node.
246a7e13a99Sspupyrev   ChainT *CurChain{nullptr};
247a7e13a99Sspupyrev   // The offset of the node in the current chain.
248a7e13a99Sspupyrev   mutable uint64_t EstimatedAddr{0};
249a7e13a99Sspupyrev   // Forced successor of the node in the graph.
250a7e13a99Sspupyrev   NodeT *ForcedSucc{nullptr};
251a7e13a99Sspupyrev   // Forced predecessor of the node in the graph.
252a7e13a99Sspupyrev   NodeT *ForcedPred{nullptr};
253a7e13a99Sspupyrev   // Outgoing jumps from the node.
254a7e13a99Sspupyrev   std::vector<JumpT *> OutJumps;
255a7e13a99Sspupyrev   // Incoming jumps to the node.
256a7e13a99Sspupyrev   std::vector<JumpT *> InJumps;
257f573f686Sspupyrev };
258f573f686Sspupyrev 
259a7e13a99Sspupyrev /// An arc in the graph, typically corresponding to a jump between two nodes.
260a7e13a99Sspupyrev struct JumpT {
261a7e13a99Sspupyrev   JumpT(const JumpT &) = delete;
262a7e13a99Sspupyrev   JumpT(JumpT &&) = default;
263a7e13a99Sspupyrev   JumpT &operator=(const JumpT &) = delete;
264a7e13a99Sspupyrev   JumpT &operator=(JumpT &&) = default;
265f573f686Sspupyrev 
266a7e13a99Sspupyrev   explicit JumpT(NodeT *Source, NodeT *Target, uint64_t ExecutionCount)
267a7e13a99Sspupyrev       : Source(Source), Target(Target), ExecutionCount(ExecutionCount) {}
268a7e13a99Sspupyrev 
269a7e13a99Sspupyrev   // Source node of the jump.
270a7e13a99Sspupyrev   NodeT *Source;
271a7e13a99Sspupyrev   // Target node of the jump.
272a7e13a99Sspupyrev   NodeT *Target;
273f573f686Sspupyrev   // Execution count of the arc in the profile data.
274f573f686Sspupyrev   uint64_t ExecutionCount{0};
2758d5b694dSspupyrev   // Whether the jump corresponds to a conditional branch.
2768d5b694dSspupyrev   bool IsConditional{false};
277a7e13a99Sspupyrev   // The offset of the jump from the source node.
278a7e13a99Sspupyrev   uint64_t Offset{0};
279f573f686Sspupyrev };
280f573f686Sspupyrev 
281a7e13a99Sspupyrev /// A chain (ordered sequence) of nodes in the graph.
282a7e13a99Sspupyrev struct ChainT {
283a7e13a99Sspupyrev   ChainT(const ChainT &) = delete;
284a7e13a99Sspupyrev   ChainT(ChainT &&) = default;
285a7e13a99Sspupyrev   ChainT &operator=(const ChainT &) = delete;
286a7e13a99Sspupyrev   ChainT &operator=(ChainT &&) = default;
287f573f686Sspupyrev 
288a7e13a99Sspupyrev   explicit ChainT(uint64_t Id, NodeT *Node)
289a7e13a99Sspupyrev       : Id(Id), ExecutionCount(Node->ExecutionCount), Size(Node->Size),
290a7e13a99Sspupyrev         Nodes(1, Node) {}
291f573f686Sspupyrev 
292a7e13a99Sspupyrev   size_t numBlocks() const { return Nodes.size(); }
293f573f686Sspupyrev 
294cc2fbc64Sspupyrev   double density() const { return ExecutionCount / Size; }
295a7e13a99Sspupyrev 
296a7e13a99Sspupyrev   bool isEntry() const { return Nodes[0]->Index == 0; }
297f573f686Sspupyrev 
2988d5b694dSspupyrev   bool isCold() const {
299a7e13a99Sspupyrev     for (NodeT *Node : Nodes) {
300a7e13a99Sspupyrev       if (Node->ExecutionCount > 0)
3018d5b694dSspupyrev         return false;
3028d5b694dSspupyrev     }
3038d5b694dSspupyrev     return true;
3048d5b694dSspupyrev   }
3058d5b694dSspupyrev 
306a7e13a99Sspupyrev   ChainEdge *getEdge(ChainT *Other) const {
307bc59faa8Sspupyrev     for (const auto &[Chain, ChainEdge] : Edges) {
308bc59faa8Sspupyrev       if (Chain == Other)
309bc59faa8Sspupyrev         return ChainEdge;
310f573f686Sspupyrev     }
311f573f686Sspupyrev     return nullptr;
312f573f686Sspupyrev   }
313f573f686Sspupyrev 
314a7e13a99Sspupyrev   void removeEdge(ChainT *Other) {
315f573f686Sspupyrev     auto It = Edges.begin();
316f573f686Sspupyrev     while (It != Edges.end()) {
317f573f686Sspupyrev       if (It->first == Other) {
318f573f686Sspupyrev         Edges.erase(It);
319f573f686Sspupyrev         return;
320f573f686Sspupyrev       }
321f573f686Sspupyrev       It++;
322f573f686Sspupyrev     }
323f573f686Sspupyrev   }
324f573f686Sspupyrev 
325a7e13a99Sspupyrev   void addEdge(ChainT *Other, ChainEdge *Edge) {
326f573f686Sspupyrev     Edges.push_back(std::make_pair(Other, Edge));
327f573f686Sspupyrev   }
328f573f686Sspupyrev 
3296b8d04c2SFangrui Song   void merge(ChainT *Other, std::vector<NodeT *> MergedBlocks) {
3306b8d04c2SFangrui Song     Nodes = std::move(MergedBlocks);
331bc59faa8Sspupyrev     // Update the chain's data.
332a7e13a99Sspupyrev     ExecutionCount += Other->ExecutionCount;
333a7e13a99Sspupyrev     Size += Other->Size;
334a7e13a99Sspupyrev     Id = Nodes[0]->Index;
335bc59faa8Sspupyrev     // Update the node's data.
336a7e13a99Sspupyrev     for (size_t Idx = 0; Idx < Nodes.size(); Idx++) {
337a7e13a99Sspupyrev       Nodes[Idx]->CurChain = this;
338a7e13a99Sspupyrev       Nodes[Idx]->CurIndex = Idx;
339f573f686Sspupyrev     }
340f573f686Sspupyrev   }
341f573f686Sspupyrev 
342a7e13a99Sspupyrev   void mergeEdges(ChainT *Other);
343f573f686Sspupyrev 
344f573f686Sspupyrev   void clear() {
345a7e13a99Sspupyrev     Nodes.clear();
346a7e13a99Sspupyrev     Nodes.shrink_to_fit();
347f573f686Sspupyrev     Edges.clear();
348f573f686Sspupyrev     Edges.shrink_to_fit();
349f573f686Sspupyrev   }
350f573f686Sspupyrev 
351f573f686Sspupyrev   // Unique chain identifier.
352f573f686Sspupyrev   uint64_t Id;
353f573f686Sspupyrev   // Cached ext-tsp score for the chain.
354a7e13a99Sspupyrev   double Score{0};
355cc2fbc64Sspupyrev   // The total execution count of the chain. Since the execution count of
356cc2fbc64Sspupyrev   // a basic block is uint64_t, using doubles here to avoid overflow.
357cc2fbc64Sspupyrev   double ExecutionCount{0};
358a7e13a99Sspupyrev   // The total size of the chain.
359a7e13a99Sspupyrev   uint64_t Size{0};
360a7e13a99Sspupyrev   // Nodes of the chain.
361a7e13a99Sspupyrev   std::vector<NodeT *> Nodes;
362f573f686Sspupyrev   // Adjacent chains and corresponding edges (lists of jumps).
363a7e13a99Sspupyrev   std::vector<std::pair<ChainT *, ChainEdge *>> Edges;
364f573f686Sspupyrev };
365f573f686Sspupyrev 
366a7e13a99Sspupyrev /// An edge in the graph representing jumps between two chains.
367a7e13a99Sspupyrev /// When nodes are merged into chains, the edges are combined too so that
368bc59faa8Sspupyrev /// there is always at most one edge between a pair of chains.
369a7e13a99Sspupyrev struct ChainEdge {
370f573f686Sspupyrev   ChainEdge(const ChainEdge &) = delete;
371f573f686Sspupyrev   ChainEdge(ChainEdge &&) = default;
372f573f686Sspupyrev   ChainEdge &operator=(const ChainEdge &) = delete;
373a7e13a99Sspupyrev   ChainEdge &operator=(ChainEdge &&) = delete;
374f573f686Sspupyrev 
375a7e13a99Sspupyrev   explicit ChainEdge(JumpT *Jump)
376f573f686Sspupyrev       : SrcChain(Jump->Source->CurChain), DstChain(Jump->Target->CurChain),
377f573f686Sspupyrev         Jumps(1, Jump) {}
378f573f686Sspupyrev 
379a7e13a99Sspupyrev   ChainT *srcChain() const { return SrcChain; }
380f573f686Sspupyrev 
381a7e13a99Sspupyrev   ChainT *dstChain() const { return DstChain; }
382f573f686Sspupyrev 
383a7e13a99Sspupyrev   bool isSelfEdge() const { return SrcChain == DstChain; }
384a7e13a99Sspupyrev 
385a7e13a99Sspupyrev   const std::vector<JumpT *> &jumps() const { return Jumps; }
386a7e13a99Sspupyrev 
387a7e13a99Sspupyrev   void appendJump(JumpT *Jump) { Jumps.push_back(Jump); }
388f573f686Sspupyrev 
389f573f686Sspupyrev   void moveJumps(ChainEdge *Other) {
390f573f686Sspupyrev     Jumps.insert(Jumps.end(), Other->Jumps.begin(), Other->Jumps.end());
391f573f686Sspupyrev     Other->Jumps.clear();
392f573f686Sspupyrev     Other->Jumps.shrink_to_fit();
393f573f686Sspupyrev   }
394f573f686Sspupyrev 
395a7e13a99Sspupyrev   void changeEndpoint(ChainT *From, ChainT *To) {
396a7e13a99Sspupyrev     if (From == SrcChain)
397a7e13a99Sspupyrev       SrcChain = To;
398a7e13a99Sspupyrev     if (From == DstChain)
399a7e13a99Sspupyrev       DstChain = To;
400a7e13a99Sspupyrev   }
401a7e13a99Sspupyrev 
402a7e13a99Sspupyrev   bool hasCachedMergeGain(ChainT *Src, ChainT *Dst) const {
403f573f686Sspupyrev     return Src == SrcChain ? CacheValidForward : CacheValidBackward;
404f573f686Sspupyrev   }
405f573f686Sspupyrev 
406a7e13a99Sspupyrev   MergeGainT getCachedMergeGain(ChainT *Src, ChainT *Dst) const {
407f573f686Sspupyrev     return Src == SrcChain ? CachedGainForward : CachedGainBackward;
408f573f686Sspupyrev   }
409f573f686Sspupyrev 
410a7e13a99Sspupyrev   void setCachedMergeGain(ChainT *Src, ChainT *Dst, MergeGainT MergeGain) {
411f573f686Sspupyrev     if (Src == SrcChain) {
412f573f686Sspupyrev       CachedGainForward = MergeGain;
413f573f686Sspupyrev       CacheValidForward = true;
414f573f686Sspupyrev     } else {
415f573f686Sspupyrev       CachedGainBackward = MergeGain;
416f573f686Sspupyrev       CacheValidBackward = true;
417f573f686Sspupyrev     }
418f573f686Sspupyrev   }
419f573f686Sspupyrev 
420f573f686Sspupyrev   void invalidateCache() {
421f573f686Sspupyrev     CacheValidForward = false;
422f573f686Sspupyrev     CacheValidBackward = false;
423f573f686Sspupyrev   }
424f573f686Sspupyrev 
425a7e13a99Sspupyrev   void setMergeGain(MergeGainT Gain) { CachedGain = Gain; }
426a7e13a99Sspupyrev 
427a7e13a99Sspupyrev   MergeGainT getMergeGain() const { return CachedGain; }
428a7e13a99Sspupyrev 
429a7e13a99Sspupyrev   double gain() const { return CachedGain.score(); }
430a7e13a99Sspupyrev 
431f573f686Sspupyrev private:
432f573f686Sspupyrev   // Source chain.
433a7e13a99Sspupyrev   ChainT *SrcChain{nullptr};
434f573f686Sspupyrev   // Destination chain.
435a7e13a99Sspupyrev   ChainT *DstChain{nullptr};
436a7e13a99Sspupyrev   // Original jumps in the binary with corresponding execution counts.
437a7e13a99Sspupyrev   std::vector<JumpT *> Jumps;
438a7e13a99Sspupyrev   // Cached gain value for merging the pair of chains.
439a7e13a99Sspupyrev   MergeGainT CachedGain;
440a7e13a99Sspupyrev 
441a7e13a99Sspupyrev   // Cached gain values for merging the pair of chains. Since the gain of
442a7e13a99Sspupyrev   // merging (Src, Dst) and (Dst, Src) might be different, we store both values
443a7e13a99Sspupyrev   // here and a flag indicating which of the options results in a higher gain.
444a7e13a99Sspupyrev   // Cached gain values.
445a7e13a99Sspupyrev   MergeGainT CachedGainForward;
446a7e13a99Sspupyrev   MergeGainT CachedGainBackward;
447f573f686Sspupyrev   // Whether the cached value must be recomputed.
448f573f686Sspupyrev   bool CacheValidForward{false};
449f573f686Sspupyrev   bool CacheValidBackward{false};
450f573f686Sspupyrev };
451f573f686Sspupyrev 
452cc2fbc64Sspupyrev bool NodeT::isSuccessor(const NodeT *Other) const {
453cc2fbc64Sspupyrev   for (JumpT *Jump : OutJumps)
454cc2fbc64Sspupyrev     if (Jump->Target == Other)
455cc2fbc64Sspupyrev       return true;
456cc2fbc64Sspupyrev   return false;
457cc2fbc64Sspupyrev }
458cc2fbc64Sspupyrev 
459a7e13a99Sspupyrev uint64_t NodeT::outCount() const {
460a7e13a99Sspupyrev   uint64_t Count = 0;
461bc59faa8Sspupyrev   for (JumpT *Jump : OutJumps)
462a7e13a99Sspupyrev     Count += Jump->ExecutionCount;
463a7e13a99Sspupyrev   return Count;
464a7e13a99Sspupyrev }
465f573f686Sspupyrev 
466a7e13a99Sspupyrev uint64_t NodeT::inCount() const {
467a7e13a99Sspupyrev   uint64_t Count = 0;
468bc59faa8Sspupyrev   for (JumpT *Jump : InJumps)
469a7e13a99Sspupyrev     Count += Jump->ExecutionCount;
470a7e13a99Sspupyrev   return Count;
471a7e13a99Sspupyrev }
472a7e13a99Sspupyrev 
473a7e13a99Sspupyrev void ChainT::mergeEdges(ChainT *Other) {
474bc59faa8Sspupyrev   // Update edges adjacent to chain Other.
475bc59faa8Sspupyrev   for (const auto &[DstChain, DstEdge] : Other->Edges) {
476a7e13a99Sspupyrev     ChainT *TargetChain = DstChain == Other ? this : DstChain;
4778d5b694dSspupyrev     ChainEdge *CurEdge = getEdge(TargetChain);
478f573f686Sspupyrev     if (CurEdge == nullptr) {
479f573f686Sspupyrev       DstEdge->changeEndpoint(Other, this);
480f573f686Sspupyrev       this->addEdge(TargetChain, DstEdge);
481bc59faa8Sspupyrev       if (DstChain != this && DstChain != Other)
482f573f686Sspupyrev         DstChain->addEdge(this, DstEdge);
483f573f686Sspupyrev     } else {
484f573f686Sspupyrev       CurEdge->moveJumps(DstEdge);
485f573f686Sspupyrev     }
486bc59faa8Sspupyrev     // Cleanup leftover edge.
487bc59faa8Sspupyrev     if (DstChain != Other)
488f573f686Sspupyrev       DstChain->removeEdge(Other);
489f573f686Sspupyrev   }
490f573f686Sspupyrev }
491f573f686Sspupyrev 
4925b39d8d3Sspupyrev using NodeIter = std::vector<NodeT *>::const_iterator;
4933d75c7c1Sspupyrev static std::vector<NodeT *> EmptyList;
494f573f686Sspupyrev 
495b90fcafcSspupyrev /// A wrapper around three concatenated vectors (chains) of nodes; it is used
496b90fcafcSspupyrev /// to avoid extra instantiation of the vectors.
497b90fcafcSspupyrev struct MergedNodesT {
4983d75c7c1Sspupyrev   MergedNodesT(NodeIter Begin1, NodeIter End1,
4993d75c7c1Sspupyrev                NodeIter Begin2 = EmptyList.begin(),
5003d75c7c1Sspupyrev                NodeIter End2 = EmptyList.end(),
5013d75c7c1Sspupyrev                NodeIter Begin3 = EmptyList.begin(),
5023d75c7c1Sspupyrev                NodeIter End3 = EmptyList.end())
503f573f686Sspupyrev       : Begin1(Begin1), End1(End1), Begin2(Begin2), End2(End2), Begin3(Begin3),
504f573f686Sspupyrev         End3(End3) {}
505f573f686Sspupyrev 
506f573f686Sspupyrev   template <typename F> void forEach(const F &Func) const {
507f573f686Sspupyrev     for (auto It = Begin1; It != End1; It++)
508f573f686Sspupyrev       Func(*It);
509f573f686Sspupyrev     for (auto It = Begin2; It != End2; It++)
510f573f686Sspupyrev       Func(*It);
511f573f686Sspupyrev     for (auto It = Begin3; It != End3; It++)
512f573f686Sspupyrev       Func(*It);
513f573f686Sspupyrev   }
514f573f686Sspupyrev 
515a7e13a99Sspupyrev   std::vector<NodeT *> getNodes() const {
516a7e13a99Sspupyrev     std::vector<NodeT *> Result;
517f573f686Sspupyrev     Result.reserve(std::distance(Begin1, End1) + std::distance(Begin2, End2) +
518f573f686Sspupyrev                    std::distance(Begin3, End3));
519f573f686Sspupyrev     Result.insert(Result.end(), Begin1, End1);
520f573f686Sspupyrev     Result.insert(Result.end(), Begin2, End2);
521f573f686Sspupyrev     Result.insert(Result.end(), Begin3, End3);
522f573f686Sspupyrev     return Result;
523f573f686Sspupyrev   }
524f573f686Sspupyrev 
525a7e13a99Sspupyrev   const NodeT *getFirstNode() const { return *Begin1; }
526f573f686Sspupyrev 
527f573f686Sspupyrev private:
5285b39d8d3Sspupyrev   NodeIter Begin1;
5295b39d8d3Sspupyrev   NodeIter End1;
5305b39d8d3Sspupyrev   NodeIter Begin2;
5315b39d8d3Sspupyrev   NodeIter End2;
5325b39d8d3Sspupyrev   NodeIter Begin3;
5335b39d8d3Sspupyrev   NodeIter End3;
534f573f686Sspupyrev };
535f573f686Sspupyrev 
536b90fcafcSspupyrev /// A wrapper around two concatenated vectors (chains) of jumps.
537b90fcafcSspupyrev struct MergedJumpsT {
538b90fcafcSspupyrev   MergedJumpsT(const std::vector<JumpT *> *Jumps1,
539b90fcafcSspupyrev                const std::vector<JumpT *> *Jumps2 = nullptr) {
540b90fcafcSspupyrev     assert(!Jumps1->empty() && "cannot merge empty jump list");
541b90fcafcSspupyrev     JumpArray[0] = Jumps1;
542b90fcafcSspupyrev     JumpArray[1] = Jumps2;
543b90fcafcSspupyrev   }
544b90fcafcSspupyrev 
545b90fcafcSspupyrev   template <typename F> void forEach(const F &Func) const {
546b90fcafcSspupyrev     for (auto Jumps : JumpArray)
547b90fcafcSspupyrev       if (Jumps != nullptr)
548b90fcafcSspupyrev         for (JumpT *Jump : *Jumps)
549b90fcafcSspupyrev           Func(Jump);
550b90fcafcSspupyrev   }
551b90fcafcSspupyrev 
552b90fcafcSspupyrev private:
553b90fcafcSspupyrev   std::array<const std::vector<JumpT *> *, 2> JumpArray{nullptr, nullptr};
554b90fcafcSspupyrev };
555b90fcafcSspupyrev 
556a7e13a99Sspupyrev /// Merge two chains of nodes respecting a given 'type' and 'offset'.
557a7e13a99Sspupyrev ///
558a7e13a99Sspupyrev /// If MergeType == 0, then the result is a concatenation of two chains.
559a7e13a99Sspupyrev /// Otherwise, the first chain is cut into two sub-chains at the offset,
560a7e13a99Sspupyrev /// and merged using all possible ways of concatenating three chains.
561b90fcafcSspupyrev MergedNodesT mergeNodes(const std::vector<NodeT *> &X,
5625b39d8d3Sspupyrev                         const std::vector<NodeT *> &Y, size_t MergeOffset,
5635b39d8d3Sspupyrev                         MergeTypeT MergeType) {
564bc59faa8Sspupyrev   // Split the first chain, X, into X1 and X2.
5655b39d8d3Sspupyrev   NodeIter BeginX1 = X.begin();
5665b39d8d3Sspupyrev   NodeIter EndX1 = X.begin() + MergeOffset;
5675b39d8d3Sspupyrev   NodeIter BeginX2 = X.begin() + MergeOffset;
5685b39d8d3Sspupyrev   NodeIter EndX2 = X.end();
5695b39d8d3Sspupyrev   NodeIter BeginY = Y.begin();
5705b39d8d3Sspupyrev   NodeIter EndY = Y.end();
571a7e13a99Sspupyrev 
572bc59faa8Sspupyrev   // Construct a new chain from the three existing ones.
573a7e13a99Sspupyrev   switch (MergeType) {
574a7e13a99Sspupyrev   case MergeTypeT::X_Y:
575b90fcafcSspupyrev     return MergedNodesT(BeginX1, EndX2, BeginY, EndY);
576a7e13a99Sspupyrev   case MergeTypeT::Y_X:
577b90fcafcSspupyrev     return MergedNodesT(BeginY, EndY, BeginX1, EndX2);
578a7e13a99Sspupyrev   case MergeTypeT::X1_Y_X2:
579b90fcafcSspupyrev     return MergedNodesT(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
580a7e13a99Sspupyrev   case MergeTypeT::Y_X2_X1:
581b90fcafcSspupyrev     return MergedNodesT(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
582a7e13a99Sspupyrev   case MergeTypeT::X2_X1_Y:
583b90fcafcSspupyrev     return MergedNodesT(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
584a7e13a99Sspupyrev   }
585a7e13a99Sspupyrev   llvm_unreachable("unexpected chain merge type");
586a7e13a99Sspupyrev }
587a7e13a99Sspupyrev 
588f573f686Sspupyrev /// The implementation of the ExtTSP algorithm.
589f573f686Sspupyrev class ExtTSPImpl {
590f573f686Sspupyrev public:
5916b8d04c2SFangrui Song   ExtTSPImpl(ArrayRef<uint64_t> NodeSizes, ArrayRef<uint64_t> NodeCounts,
5926b8d04c2SFangrui Song              ArrayRef<EdgeCount> EdgeCounts)
593a7e13a99Sspupyrev       : NumNodes(NodeSizes.size()) {
594f573f686Sspupyrev     initialize(NodeSizes, NodeCounts, EdgeCounts);
595f573f686Sspupyrev   }
596f573f686Sspupyrev 
597a7e13a99Sspupyrev   /// Run the algorithm and return an optimized ordering of nodes.
5986b8d04c2SFangrui Song   std::vector<uint64_t> run() {
599a7e13a99Sspupyrev     // Pass 1: Merge nodes with their mutually forced successors
600f573f686Sspupyrev     mergeForcedPairs();
601f573f686Sspupyrev 
602f573f686Sspupyrev     // Pass 2: Merge pairs of chains while improving the ExtTSP objective
603f573f686Sspupyrev     mergeChainPairs();
604f573f686Sspupyrev 
605a7e13a99Sspupyrev     // Pass 3: Merge cold nodes to reduce code size
606f573f686Sspupyrev     mergeColdChains();
607f573f686Sspupyrev 
608a7e13a99Sspupyrev     // Collect nodes from all chains
6096b8d04c2SFangrui Song     return concatChains();
610f573f686Sspupyrev   }
611f573f686Sspupyrev 
612f573f686Sspupyrev private:
613f573f686Sspupyrev   /// Initialize the algorithm's data structures.
6146b8d04c2SFangrui Song   void initialize(const ArrayRef<uint64_t> &NodeSizes,
6156b8d04c2SFangrui Song                   const ArrayRef<uint64_t> &NodeCounts,
6166b8d04c2SFangrui Song                   const ArrayRef<EdgeCount> &EdgeCounts) {
617cebc8379Sspupyrev     // Initialize nodes.
618a7e13a99Sspupyrev     AllNodes.reserve(NumNodes);
619a7e13a99Sspupyrev     for (uint64_t Idx = 0; Idx < NumNodes; Idx++) {
620a7e13a99Sspupyrev       uint64_t Size = std::max<uint64_t>(NodeSizes[Idx], 1ULL);
621a7e13a99Sspupyrev       uint64_t ExecutionCount = NodeCounts[Idx];
622bc59faa8Sspupyrev       // The execution count of the entry node is set to at least one.
623a7e13a99Sspupyrev       if (Idx == 0 && ExecutionCount == 0)
624f573f686Sspupyrev         ExecutionCount = 1;
625a7e13a99Sspupyrev       AllNodes.emplace_back(Idx, Size, ExecutionCount);
626f573f686Sspupyrev     }
627f573f686Sspupyrev 
628cebc8379Sspupyrev     // Initialize jumps between the nodes.
6298d5b694dSspupyrev     SuccNodes.resize(NumNodes);
6308d5b694dSspupyrev     PredNodes.resize(NumNodes);
6318d5b694dSspupyrev     std::vector<uint64_t> OutDegree(NumNodes, 0);
632f573f686Sspupyrev     AllJumps.reserve(EdgeCounts.size());
6336b8d04c2SFangrui Song     for (auto Edge : EdgeCounts) {
6346b8d04c2SFangrui Song       ++OutDegree[Edge.src];
635bc59faa8Sspupyrev       // Ignore self-edges.
6366b8d04c2SFangrui Song       if (Edge.src == Edge.dst)
637f573f686Sspupyrev         continue;
638f573f686Sspupyrev 
6396b8d04c2SFangrui Song       SuccNodes[Edge.src].push_back(Edge.dst);
6406b8d04c2SFangrui Song       PredNodes[Edge.dst].push_back(Edge.src);
6416b8d04c2SFangrui Song       if (Edge.count > 0) {
6426b8d04c2SFangrui Song         NodeT &PredNode = AllNodes[Edge.src];
6436b8d04c2SFangrui Song         NodeT &SuccNode = AllNodes[Edge.dst];
6446b8d04c2SFangrui Song         AllJumps.emplace_back(&PredNode, &SuccNode, Edge.count);
645a7e13a99Sspupyrev         SuccNode.InJumps.push_back(&AllJumps.back());
646a7e13a99Sspupyrev         PredNode.OutJumps.push_back(&AllJumps.back());
647cebc8379Sspupyrev         // Adjust execution counts.
648cebc8379Sspupyrev         PredNode.ExecutionCount = std::max(PredNode.ExecutionCount, Edge.count);
649cebc8379Sspupyrev         SuccNode.ExecutionCount = std::max(SuccNode.ExecutionCount, Edge.count);
650f573f686Sspupyrev       }
651f573f686Sspupyrev     }
652a7e13a99Sspupyrev     for (JumpT &Jump : AllJumps) {
653cc2fbc64Sspupyrev       assert(OutDegree[Jump.Source->Index] > 0 &&
654cc2fbc64Sspupyrev              "incorrectly computed out-degree of the block");
6558d5b694dSspupyrev       Jump.IsConditional = OutDegree[Jump.Source->Index] > 1;
6568d5b694dSspupyrev     }
657f573f686Sspupyrev 
658bc59faa8Sspupyrev     // Initialize chains.
659f573f686Sspupyrev     AllChains.reserve(NumNodes);
660f573f686Sspupyrev     HotChains.reserve(NumNodes);
661a7e13a99Sspupyrev     for (NodeT &Node : AllNodes) {
662b90fcafcSspupyrev       // Create a chain.
663a7e13a99Sspupyrev       AllChains.emplace_back(Node.Index, &Node);
664a7e13a99Sspupyrev       Node.CurChain = &AllChains.back();
665bc59faa8Sspupyrev       if (Node.ExecutionCount > 0)
666f573f686Sspupyrev         HotChains.push_back(&AllChains.back());
667f573f686Sspupyrev     }
668f573f686Sspupyrev 
669bc59faa8Sspupyrev     // Initialize chain edges.
670f573f686Sspupyrev     AllEdges.reserve(AllJumps.size());
671a7e13a99Sspupyrev     for (NodeT &PredNode : AllNodes) {
672a7e13a99Sspupyrev       for (JumpT *Jump : PredNode.OutJumps) {
673cebc8379Sspupyrev         assert(Jump->ExecutionCount > 0 && "incorrectly initialized jump");
674a7e13a99Sspupyrev         NodeT *SuccNode = Jump->Target;
675a7e13a99Sspupyrev         ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain);
676b90fcafcSspupyrev         // This edge is already present in the graph.
677f573f686Sspupyrev         if (CurEdge != nullptr) {
678a7e13a99Sspupyrev           assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr);
679f573f686Sspupyrev           CurEdge->appendJump(Jump);
680f573f686Sspupyrev           continue;
681f573f686Sspupyrev         }
682b90fcafcSspupyrev         // This is a new edge.
683f573f686Sspupyrev         AllEdges.emplace_back(Jump);
684a7e13a99Sspupyrev         PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back());
685a7e13a99Sspupyrev         SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back());
686f573f686Sspupyrev       }
687f573f686Sspupyrev     }
688f573f686Sspupyrev   }
689f573f686Sspupyrev 
690a7e13a99Sspupyrev   /// For a pair of nodes, A and B, node B is the forced successor of A,
691f573f686Sspupyrev   /// if (i) all jumps (based on profile) from A goes to B and (ii) all jumps
692a7e13a99Sspupyrev   /// to B are from A. Such nodes should be adjacent in the optimal ordering;
693a7e13a99Sspupyrev   /// the method finds and merges such pairs of nodes.
694f573f686Sspupyrev   void mergeForcedPairs() {
695b90fcafcSspupyrev     // Find forced pairs of blocks.
696a7e13a99Sspupyrev     for (NodeT &Node : AllNodes) {
697a7e13a99Sspupyrev       if (SuccNodes[Node.Index].size() == 1 &&
698a7e13a99Sspupyrev           PredNodes[SuccNodes[Node.Index][0]].size() == 1 &&
699a7e13a99Sspupyrev           SuccNodes[Node.Index][0] != 0) {
700a7e13a99Sspupyrev         size_t SuccIndex = SuccNodes[Node.Index][0];
701a7e13a99Sspupyrev         Node.ForcedSucc = &AllNodes[SuccIndex];
702a7e13a99Sspupyrev         AllNodes[SuccIndex].ForcedPred = &Node;
703f573f686Sspupyrev       }
704f573f686Sspupyrev     }
705f573f686Sspupyrev 
706f573f686Sspupyrev     // There might be 'cycles' in the forced dependencies, since profile
707f573f686Sspupyrev     // data isn't 100% accurate. Typically this is observed in loops, when the
708f573f686Sspupyrev     // loop edges are the hottest successors for the basic blocks of the loop.
709a7e13a99Sspupyrev     // Break the cycles by choosing the node with the smallest index as the
710f573f686Sspupyrev     // head. This helps to keep the original order of the loops, which likely
711f573f686Sspupyrev     // have already been rotated in the optimized manner.
712a7e13a99Sspupyrev     for (NodeT &Node : AllNodes) {
713a7e13a99Sspupyrev       if (Node.ForcedSucc == nullptr || Node.ForcedPred == nullptr)
714f573f686Sspupyrev         continue;
715f573f686Sspupyrev 
716a7e13a99Sspupyrev       NodeT *SuccNode = Node.ForcedSucc;
717a7e13a99Sspupyrev       while (SuccNode != nullptr && SuccNode != &Node) {
718a7e13a99Sspupyrev         SuccNode = SuccNode->ForcedSucc;
719f573f686Sspupyrev       }
720a7e13a99Sspupyrev       if (SuccNode == nullptr)
721f573f686Sspupyrev         continue;
722bc59faa8Sspupyrev       // Break the cycle.
723a7e13a99Sspupyrev       AllNodes[Node.ForcedPred->Index].ForcedSucc = nullptr;
724a7e13a99Sspupyrev       Node.ForcedPred = nullptr;
725f573f686Sspupyrev     }
726f573f686Sspupyrev 
727bc59faa8Sspupyrev     // Merge nodes with their fallthrough successors.
728a7e13a99Sspupyrev     for (NodeT &Node : AllNodes) {
729a7e13a99Sspupyrev       if (Node.ForcedPred == nullptr && Node.ForcedSucc != nullptr) {
730a7e13a99Sspupyrev         const NodeT *CurBlock = &Node;
731f573f686Sspupyrev         while (CurBlock->ForcedSucc != nullptr) {
732a7e13a99Sspupyrev           const NodeT *NextBlock = CurBlock->ForcedSucc;
733a7e13a99Sspupyrev           mergeChains(Node.CurChain, NextBlock->CurChain, 0, MergeTypeT::X_Y);
734f573f686Sspupyrev           CurBlock = NextBlock;
735f573f686Sspupyrev         }
736f573f686Sspupyrev       }
737f573f686Sspupyrev     }
738f573f686Sspupyrev   }
739f573f686Sspupyrev 
740f573f686Sspupyrev   /// Merge pairs of chains while improving the ExtTSP objective.
741f573f686Sspupyrev   void mergeChainPairs() {
742bc59faa8Sspupyrev     /// Deterministically compare pairs of chains.
743a7e13a99Sspupyrev     auto compareChainPairs = [](const ChainT *A1, const ChainT *B1,
744a7e13a99Sspupyrev                                 const ChainT *A2, const ChainT *B2) {
745b90fcafcSspupyrev       return std::make_tuple(A1->Id, B1->Id) < std::make_tuple(A2->Id, B2->Id);
746f573f686Sspupyrev     };
747f573f686Sspupyrev 
748f573f686Sspupyrev     while (HotChains.size() > 1) {
749a7e13a99Sspupyrev       ChainT *BestChainPred = nullptr;
750a7e13a99Sspupyrev       ChainT *BestChainSucc = nullptr;
751a7e13a99Sspupyrev       MergeGainT BestGain;
752bc59faa8Sspupyrev       // Iterate over all pairs of chains.
753a7e13a99Sspupyrev       for (ChainT *ChainPred : HotChains) {
754bc59faa8Sspupyrev         // Get candidates for merging with the current chain.
755bc59faa8Sspupyrev         for (const auto &[ChainSucc, Edge] : ChainPred->Edges) {
756bc59faa8Sspupyrev           // Ignore loop edges.
757cc2fbc64Sspupyrev           if (Edge->isSelfEdge())
758f573f686Sspupyrev             continue;
759cc2fbc64Sspupyrev           // Skip the merge if the combined chain violates the maximum specified
760cc2fbc64Sspupyrev           // size.
761bcdc0477Sspupyrev           if (ChainPred->numBlocks() + ChainSucc->numBlocks() >= MaxChainSize)
762bcdc0477Sspupyrev             continue;
763cc2fbc64Sspupyrev           // Don't merge the chains if they have vastly different densities.
764cc2fbc64Sspupyrev           // Skip the merge if the ratio between the densities exceeds
765cc2fbc64Sspupyrev           // MaxMergeDensityRatio. Smaller values of the option result in fewer
766cc2fbc64Sspupyrev           // merges, and hence, more chains.
767cebc8379Sspupyrev           const double ChainPredDensity = ChainPred->density();
768cebc8379Sspupyrev           const double ChainSuccDensity = ChainSucc->density();
769cebc8379Sspupyrev           assert(ChainPredDensity > 0.0 && ChainSuccDensity > 0.0 &&
770cc2fbc64Sspupyrev                  "incorrectly computed chain densities");
771cebc8379Sspupyrev           auto [MinDensity, MaxDensity] =
772cebc8379Sspupyrev               std::minmax(ChainPredDensity, ChainSuccDensity);
773cebc8379Sspupyrev           const double Ratio = MaxDensity / MinDensity;
774cc2fbc64Sspupyrev           if (Ratio > MaxMergeDensityRatio)
775cc2fbc64Sspupyrev             continue;
776bcdc0477Sspupyrev 
7775b39d8d3Sspupyrev           // Compute the gain of merging the two chains.
778a7e13a99Sspupyrev           MergeGainT CurGain = getBestMergeGain(ChainPred, ChainSucc, Edge);
779f573f686Sspupyrev           if (CurGain.score() <= EPS)
780f573f686Sspupyrev             continue;
781f573f686Sspupyrev 
782f573f686Sspupyrev           if (BestGain < CurGain ||
783f573f686Sspupyrev               (std::abs(CurGain.score() - BestGain.score()) < EPS &&
784f573f686Sspupyrev                compareChainPairs(ChainPred, ChainSucc, BestChainPred,
785f573f686Sspupyrev                                  BestChainSucc))) {
786f573f686Sspupyrev             BestGain = CurGain;
787f573f686Sspupyrev             BestChainPred = ChainPred;
788f573f686Sspupyrev             BestChainSucc = ChainSucc;
789f573f686Sspupyrev           }
790f573f686Sspupyrev         }
791f573f686Sspupyrev       }
792f573f686Sspupyrev 
793bc59faa8Sspupyrev       // Stop merging when there is no improvement.
794f573f686Sspupyrev       if (BestGain.score() <= EPS)
795f573f686Sspupyrev         break;
796f573f686Sspupyrev 
797bc59faa8Sspupyrev       // Merge the best pair of chains.
798f573f686Sspupyrev       mergeChains(BestChainPred, BestChainSucc, BestGain.mergeOffset(),
799f573f686Sspupyrev                   BestGain.mergeType());
800f573f686Sspupyrev     }
801f573f686Sspupyrev   }
802f573f686Sspupyrev 
803a7e13a99Sspupyrev   /// Merge remaining nodes into chains w/o taking jump counts into
804a7e13a99Sspupyrev   /// consideration. This allows to maintain the original node order in the
805bc59faa8Sspupyrev   /// absence of profile data.
806f573f686Sspupyrev   void mergeColdChains() {
807f573f686Sspupyrev     for (size_t SrcBB = 0; SrcBB < NumNodes; SrcBB++) {
8088d5b694dSspupyrev       // Iterating in reverse order to make sure original fallthrough jumps are
8098d5b694dSspupyrev       // merged first; this might be beneficial for code size.
810f573f686Sspupyrev       size_t NumSuccs = SuccNodes[SrcBB].size();
811f573f686Sspupyrev       for (size_t Idx = 0; Idx < NumSuccs; Idx++) {
812a7e13a99Sspupyrev         size_t DstBB = SuccNodes[SrcBB][NumSuccs - Idx - 1];
813a7e13a99Sspupyrev         ChainT *SrcChain = AllNodes[SrcBB].CurChain;
814a7e13a99Sspupyrev         ChainT *DstChain = AllNodes[DstBB].CurChain;
815f573f686Sspupyrev         if (SrcChain != DstChain && !DstChain->isEntry() &&
816a7e13a99Sspupyrev             SrcChain->Nodes.back()->Index == SrcBB &&
817a7e13a99Sspupyrev             DstChain->Nodes.front()->Index == DstBB &&
8188d5b694dSspupyrev             SrcChain->isCold() == DstChain->isCold()) {
819a7e13a99Sspupyrev           mergeChains(SrcChain, DstChain, 0, MergeTypeT::X_Y);
820f573f686Sspupyrev         }
821f573f686Sspupyrev       }
822f573f686Sspupyrev     }
823f573f686Sspupyrev   }
824f573f686Sspupyrev 
825a7e13a99Sspupyrev   /// Compute the Ext-TSP score for a given node order and a list of jumps.
826b90fcafcSspupyrev   double extTSPScore(const MergedNodesT &Nodes,
827b90fcafcSspupyrev                      const MergedJumpsT &Jumps) const {
828f573f686Sspupyrev     uint64_t CurAddr = 0;
829b90fcafcSspupyrev     Nodes.forEach([&](const NodeT *Node) {
830a7e13a99Sspupyrev       Node->EstimatedAddr = CurAddr;
831a7e13a99Sspupyrev       CurAddr += Node->Size;
832f573f686Sspupyrev     });
833f573f686Sspupyrev 
834f573f686Sspupyrev     double Score = 0;
835b90fcafcSspupyrev     Jumps.forEach([&](const JumpT *Jump) {
836a7e13a99Sspupyrev       const NodeT *SrcBlock = Jump->Source;
837a7e13a99Sspupyrev       const NodeT *DstBlock = Jump->Target;
838f573f686Sspupyrev       Score += ::extTSPScore(SrcBlock->EstimatedAddr, SrcBlock->Size,
8398d5b694dSspupyrev                              DstBlock->EstimatedAddr, Jump->ExecutionCount,
8408d5b694dSspupyrev                              Jump->IsConditional);
841b90fcafcSspupyrev     });
842f573f686Sspupyrev     return Score;
843f573f686Sspupyrev   }
844f573f686Sspupyrev 
845f573f686Sspupyrev   /// Compute the gain of merging two chains.
846f573f686Sspupyrev   ///
847f573f686Sspupyrev   /// The function considers all possible ways of merging two chains and
848f573f686Sspupyrev   /// computes the one having the largest increase in ExtTSP objective. The
849f573f686Sspupyrev   /// result is a pair with the first element being the gain and the second
850f573f686Sspupyrev   /// element being the corresponding merging type.
851a7e13a99Sspupyrev   MergeGainT getBestMergeGain(ChainT *ChainPred, ChainT *ChainSucc,
852f573f686Sspupyrev                               ChainEdge *Edge) const {
853b90fcafcSspupyrev     if (Edge->hasCachedMergeGain(ChainPred, ChainSucc))
854f573f686Sspupyrev       return Edge->getCachedMergeGain(ChainPred, ChainSucc);
855f573f686Sspupyrev 
856b90fcafcSspupyrev     assert(!Edge->jumps().empty() && "trying to merge chains w/o jumps");
857bc59faa8Sspupyrev     // Precompute jumps between ChainPred and ChainSucc.
8580a7bf3aaSspupyrev     ChainEdge *EdgePP = ChainPred->getEdge(ChainPred);
859b90fcafcSspupyrev     MergedJumpsT Jumps(&Edge->jumps(), EdgePP ? &EdgePP->jumps() : nullptr);
860f573f686Sspupyrev 
861bc59faa8Sspupyrev     // This object holds the best chosen gain of merging two chains.
862a7e13a99Sspupyrev     MergeGainT Gain = MergeGainT();
863f573f686Sspupyrev 
864f573f686Sspupyrev     /// Given a merge offset and a list of merge types, try to merge two chains
865bc59faa8Sspupyrev     /// and update Gain with a better alternative.
866f573f686Sspupyrev     auto tryChainMerging = [&](size_t Offset,
867a7e13a99Sspupyrev                                const std::vector<MergeTypeT> &MergeTypes) {
868bc59faa8Sspupyrev       // Skip merging corresponding to concatenation w/o splitting.
869a7e13a99Sspupyrev       if (Offset == 0 || Offset == ChainPred->Nodes.size())
870f573f686Sspupyrev         return;
871bc59faa8Sspupyrev       // Skip merging if it breaks Forced successors.
872a7e13a99Sspupyrev       NodeT *Node = ChainPred->Nodes[Offset - 1];
873a7e13a99Sspupyrev       if (Node->ForcedSucc != nullptr)
874f573f686Sspupyrev         return;
875f573f686Sspupyrev       // Apply the merge, compute the corresponding gain, and update the best
876bc59faa8Sspupyrev       // value, if the merge is beneficial.
877a7e13a99Sspupyrev       for (const MergeTypeT &MergeType : MergeTypes) {
878f573f686Sspupyrev         Gain.updateIfLessThan(
879f573f686Sspupyrev             computeMergeGain(ChainPred, ChainSucc, Jumps, Offset, MergeType));
880f573f686Sspupyrev       }
881f573f686Sspupyrev     };
882f573f686Sspupyrev 
883bc59faa8Sspupyrev     // Try to concatenate two chains w/o splitting.
884f573f686Sspupyrev     Gain.updateIfLessThan(
885a7e13a99Sspupyrev         computeMergeGain(ChainPred, ChainSucc, Jumps, 0, MergeTypeT::X_Y));
886f573f686Sspupyrev 
887bc59faa8Sspupyrev     // Attach (a part of) ChainPred before the first node of ChainSucc.
888a7e13a99Sspupyrev     for (JumpT *Jump : ChainSucc->Nodes.front()->InJumps) {
889a7e13a99Sspupyrev       const NodeT *SrcBlock = Jump->Source;
890f573f686Sspupyrev       if (SrcBlock->CurChain != ChainPred)
891f573f686Sspupyrev         continue;
892f573f686Sspupyrev       size_t Offset = SrcBlock->CurIndex + 1;
893a7e13a99Sspupyrev       tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::X2_X1_Y});
894f573f686Sspupyrev     }
895f573f686Sspupyrev 
896bc59faa8Sspupyrev     // Attach (a part of) ChainPred after the last node of ChainSucc.
897a7e13a99Sspupyrev     for (JumpT *Jump : ChainSucc->Nodes.back()->OutJumps) {
898af935cf0SFangrui Song       const NodeT *DstBlock = Jump->Target;
899f573f686Sspupyrev       if (DstBlock->CurChain != ChainPred)
900f573f686Sspupyrev         continue;
901f573f686Sspupyrev       size_t Offset = DstBlock->CurIndex;
902a7e13a99Sspupyrev       tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1});
903f573f686Sspupyrev     }
904f573f686Sspupyrev 
905bc59faa8Sspupyrev     // Try to break ChainPred in various ways and concatenate with ChainSucc.
9065b39d8d3Sspupyrev     if (ChainPred->Nodes.size() <= ChainSplitThreshold) {
907a7e13a99Sspupyrev       for (size_t Offset = 1; Offset < ChainPred->Nodes.size(); Offset++) {
908cc2fbc64Sspupyrev         // Do not split the chain along a fall-through jump. One of the two
909cc2fbc64Sspupyrev         // loops above may still "break" such a jump whenever it results in a
910cc2fbc64Sspupyrev         // new fall-through.
911cc2fbc64Sspupyrev         const NodeT *BB = ChainPred->Nodes[Offset - 1];
912cc2fbc64Sspupyrev         const NodeT *BB2 = ChainPred->Nodes[Offset];
913cc2fbc64Sspupyrev         if (BB->isSuccessor(BB2))
914cc2fbc64Sspupyrev           continue;
915cc2fbc64Sspupyrev 
916cc2fbc64Sspupyrev         // In practice, applying X2_Y_X1 merging almost never provides benefits;
917cc2fbc64Sspupyrev         // thus, we exclude it from consideration to reduce the search space.
918a7e13a99Sspupyrev         tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1,
919a7e13a99Sspupyrev                                  MergeTypeT::X2_X1_Y});
920f573f686Sspupyrev       }
9215b39d8d3Sspupyrev     }
922cc2fbc64Sspupyrev 
923f573f686Sspupyrev     Edge->setCachedMergeGain(ChainPred, ChainSucc, Gain);
924f573f686Sspupyrev     return Gain;
925f573f686Sspupyrev   }
926f573f686Sspupyrev 
927f573f686Sspupyrev   /// Compute the score gain of merging two chains, respecting a given
928f573f686Sspupyrev   /// merge 'type' and 'offset'.
929f573f686Sspupyrev   ///
930f573f686Sspupyrev   /// The two chains are not modified in the method.
931a7e13a99Sspupyrev   MergeGainT computeMergeGain(const ChainT *ChainPred, const ChainT *ChainSucc,
932b90fcafcSspupyrev                               const MergedJumpsT &Jumps, size_t MergeOffset,
933b90fcafcSspupyrev                               MergeTypeT MergeType) const {
934b90fcafcSspupyrev     MergedNodesT MergedNodes =
935a7e13a99Sspupyrev         mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType);
936f573f686Sspupyrev 
937bc59faa8Sspupyrev     // Do not allow a merge that does not preserve the original entry point.
938f573f686Sspupyrev     if ((ChainPred->isEntry() || ChainSucc->isEntry()) &&
939b90fcafcSspupyrev         !MergedNodes.getFirstNode()->isEntry())
940a7e13a99Sspupyrev       return MergeGainT();
941f573f686Sspupyrev 
942bc59faa8Sspupyrev     // The gain for the new chain.
943b90fcafcSspupyrev     double NewScore = extTSPScore(MergedNodes, Jumps);
944b90fcafcSspupyrev     double CurScore = ChainPred->Score;
945b90fcafcSspupyrev     return MergeGainT(NewScore - CurScore, MergeOffset, MergeType);
946f573f686Sspupyrev   }
947f573f686Sspupyrev 
948f573f686Sspupyrev   /// Merge chain From into chain Into, update the list of active chains,
949f573f686Sspupyrev   /// adjacency information, and the corresponding cached values.
950a7e13a99Sspupyrev   void mergeChains(ChainT *Into, ChainT *From, size_t MergeOffset,
951a7e13a99Sspupyrev                    MergeTypeT MergeType) {
952f573f686Sspupyrev     assert(Into != From && "a chain cannot be merged with itself");
953f573f686Sspupyrev 
954bc59faa8Sspupyrev     // Merge the nodes.
955b90fcafcSspupyrev     MergedNodesT MergedNodes =
956a7e13a99Sspupyrev         mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType);
957a7e13a99Sspupyrev     Into->merge(From, MergedNodes.getNodes());
958a7e13a99Sspupyrev 
959bc59faa8Sspupyrev     // Merge the edges.
960f573f686Sspupyrev     Into->mergeEdges(From);
961f573f686Sspupyrev     From->clear();
962f573f686Sspupyrev 
963bc59faa8Sspupyrev     // Update cached ext-tsp score for the new chain.
9648d5b694dSspupyrev     ChainEdge *SelfEdge = Into->getEdge(Into);
965f573f686Sspupyrev     if (SelfEdge != nullptr) {
966b90fcafcSspupyrev       MergedNodes = MergedNodesT(Into->Nodes.begin(), Into->Nodes.end());
967b90fcafcSspupyrev       MergedJumpsT MergedJumps(&SelfEdge->jumps());
968b90fcafcSspupyrev       Into->Score = extTSPScore(MergedNodes, MergedJumps);
969f573f686Sspupyrev     }
970f573f686Sspupyrev 
971bc59faa8Sspupyrev     // Remove the chain from the list of active chains.
972f9306f6dSKazu Hirata     llvm::erase(HotChains, From);
973f573f686Sspupyrev 
974bc59faa8Sspupyrev     // Invalidate caches.
975a7e13a99Sspupyrev     for (auto EdgeIt : Into->Edges)
976a7e13a99Sspupyrev       EdgeIt.second->invalidateCache();
977f573f686Sspupyrev   }
978f573f686Sspupyrev 
979a7e13a99Sspupyrev   /// Concatenate all chains into the final order.
9806b8d04c2SFangrui Song   std::vector<uint64_t> concatChains() {
981cc2fbc64Sspupyrev     // Collect non-empty chains.
982a7e13a99Sspupyrev     std::vector<const ChainT *> SortedChains;
983a7e13a99Sspupyrev     for (ChainT &Chain : AllChains) {
984cc2fbc64Sspupyrev       if (!Chain.Nodes.empty())
985f573f686Sspupyrev         SortedChains.push_back(&Chain);
986f573f686Sspupyrev     }
987f573f686Sspupyrev 
988bc59faa8Sspupyrev     // Sorting chains by density in the decreasing order.
989fef144ceSKazu Hirata     std::sort(SortedChains.begin(), SortedChains.end(),
990fef144ceSKazu Hirata               [&](const ChainT *L, const ChainT *R) {
991b90fcafcSspupyrev                 // Place the entry point at the beginning of the order.
992a7e13a99Sspupyrev                 if (L->isEntry() != R->isEntry())
993a7e13a99Sspupyrev                   return L->isEntry();
994f573f686Sspupyrev 
995bc59faa8Sspupyrev                 // Compare by density and break ties by chain identifiers.
996cc2fbc64Sspupyrev                 return std::make_tuple(-L->density(), L->Id) <
997cc2fbc64Sspupyrev                        std::make_tuple(-R->density(), R->Id);
998f573f686Sspupyrev               });
999f573f686Sspupyrev 
1000bc59faa8Sspupyrev     // Collect the nodes in the order specified by their chains.
10016b8d04c2SFangrui Song     std::vector<uint64_t> Order;
1002f573f686Sspupyrev     Order.reserve(NumNodes);
10036b8d04c2SFangrui Song     for (const ChainT *Chain : SortedChains)
10046b8d04c2SFangrui Song       for (NodeT *Node : Chain->Nodes)
1005a7e13a99Sspupyrev         Order.push_back(Node->Index);
10066b8d04c2SFangrui Song     return Order;
1007f573f686Sspupyrev   }
1008f573f686Sspupyrev 
1009f573f686Sspupyrev private:
1010f573f686Sspupyrev   /// The number of nodes in the graph.
1011f573f686Sspupyrev   const size_t NumNodes;
1012f573f686Sspupyrev 
1013f573f686Sspupyrev   /// Successors of each node.
1014f573f686Sspupyrev   std::vector<std::vector<uint64_t>> SuccNodes;
1015f573f686Sspupyrev 
1016f573f686Sspupyrev   /// Predecessors of each node.
1017f573f686Sspupyrev   std::vector<std::vector<uint64_t>> PredNodes;
1018f573f686Sspupyrev 
1019a7e13a99Sspupyrev   /// All nodes (basic blocks) in the graph.
1020a7e13a99Sspupyrev   std::vector<NodeT> AllNodes;
1021f573f686Sspupyrev 
1022a7e13a99Sspupyrev   /// All jumps between the nodes.
1023a7e13a99Sspupyrev   std::vector<JumpT> AllJumps;
1024f573f686Sspupyrev 
1025a7e13a99Sspupyrev   /// All chains of nodes.
1026a7e13a99Sspupyrev   std::vector<ChainT> AllChains;
1027f573f686Sspupyrev 
1028a7e13a99Sspupyrev   /// All edges between the chains.
1029f573f686Sspupyrev   std::vector<ChainEdge> AllEdges;
1030f573f686Sspupyrev 
1031f573f686Sspupyrev   /// Active chains. The vector gets updated at runtime when chains are merged.
1032a7e13a99Sspupyrev   std::vector<ChainT *> HotChains;
1033f573f686Sspupyrev };
1034f573f686Sspupyrev 
1035f61179f8Sspupyrev /// The implementation of the Cache-Directed Sort (CDSort) algorithm for
1036f61179f8Sspupyrev /// ordering functions represented by a call graph.
1037bc59faa8Sspupyrev class CDSortImpl {
1038bc59faa8Sspupyrev public:
10396b8d04c2SFangrui Song   CDSortImpl(const CDSortConfig &Config, ArrayRef<uint64_t> NodeSizes,
10406b8d04c2SFangrui Song              ArrayRef<uint64_t> NodeCounts, ArrayRef<EdgeCount> EdgeCounts,
10416b8d04c2SFangrui Song              ArrayRef<uint64_t> EdgeOffsets)
1042bc59faa8Sspupyrev       : Config(Config), NumNodes(NodeSizes.size()) {
1043bc59faa8Sspupyrev     initialize(NodeSizes, NodeCounts, EdgeCounts, EdgeOffsets);
1044bc59faa8Sspupyrev   }
1045bc59faa8Sspupyrev 
1046bc59faa8Sspupyrev   /// Run the algorithm and return an ordered set of function clusters.
10476b8d04c2SFangrui Song   std::vector<uint64_t> run() {
1048bc59faa8Sspupyrev     // Merge pairs of chains while improving the objective.
1049bc59faa8Sspupyrev     mergeChainPairs();
1050bc59faa8Sspupyrev 
1051bc59faa8Sspupyrev     // Collect nodes from all the chains.
10526b8d04c2SFangrui Song     return concatChains();
1053bc59faa8Sspupyrev   }
1054bc59faa8Sspupyrev 
1055bc59faa8Sspupyrev private:
1056bc59faa8Sspupyrev   /// Initialize the algorithm's data structures.
10576b8d04c2SFangrui Song   void initialize(const ArrayRef<uint64_t> &NodeSizes,
10586b8d04c2SFangrui Song                   const ArrayRef<uint64_t> &NodeCounts,
10596b8d04c2SFangrui Song                   const ArrayRef<EdgeCount> &EdgeCounts,
10606b8d04c2SFangrui Song                   const ArrayRef<uint64_t> &EdgeOffsets) {
1061bc59faa8Sspupyrev     // Initialize nodes.
1062bc59faa8Sspupyrev     AllNodes.reserve(NumNodes);
1063bc59faa8Sspupyrev     for (uint64_t Node = 0; Node < NumNodes; Node++) {
1064bc59faa8Sspupyrev       uint64_t Size = std::max<uint64_t>(NodeSizes[Node], 1ULL);
1065bc59faa8Sspupyrev       uint64_t ExecutionCount = NodeCounts[Node];
1066bc59faa8Sspupyrev       AllNodes.emplace_back(Node, Size, ExecutionCount);
1067bc59faa8Sspupyrev       TotalSamples += ExecutionCount;
1068bc59faa8Sspupyrev       if (ExecutionCount > 0)
1069bc59faa8Sspupyrev         TotalSize += Size;
1070bc59faa8Sspupyrev     }
1071bc59faa8Sspupyrev 
1072bc59faa8Sspupyrev     // Initialize jumps between the nodes.
1073bc59faa8Sspupyrev     SuccNodes.resize(NumNodes);
1074bc59faa8Sspupyrev     PredNodes.resize(NumNodes);
1075bc59faa8Sspupyrev     AllJumps.reserve(EdgeCounts.size());
1076bc59faa8Sspupyrev     for (size_t I = 0; I < EdgeCounts.size(); I++) {
10776b8d04c2SFangrui Song       auto [Pred, Succ, Count] = EdgeCounts[I];
1078bc59faa8Sspupyrev       // Ignore recursive calls.
1079bc59faa8Sspupyrev       if (Pred == Succ)
1080bc59faa8Sspupyrev         continue;
1081bc59faa8Sspupyrev 
1082bc59faa8Sspupyrev       SuccNodes[Pred].push_back(Succ);
1083bc59faa8Sspupyrev       PredNodes[Succ].push_back(Pred);
10846b8d04c2SFangrui Song       if (Count > 0) {
1085bc59faa8Sspupyrev         NodeT &PredNode = AllNodes[Pred];
1086bc59faa8Sspupyrev         NodeT &SuccNode = AllNodes[Succ];
10876b8d04c2SFangrui Song         AllJumps.emplace_back(&PredNode, &SuccNode, Count);
1088bc59faa8Sspupyrev         AllJumps.back().Offset = EdgeOffsets[I];
1089bc59faa8Sspupyrev         SuccNode.InJumps.push_back(&AllJumps.back());
1090bc59faa8Sspupyrev         PredNode.OutJumps.push_back(&AllJumps.back());
1091cebc8379Sspupyrev         // Adjust execution counts.
1092cebc8379Sspupyrev         PredNode.ExecutionCount = std::max(PredNode.ExecutionCount, Count);
1093cebc8379Sspupyrev         SuccNode.ExecutionCount = std::max(SuccNode.ExecutionCount, Count);
1094bc59faa8Sspupyrev       }
1095bc59faa8Sspupyrev     }
1096bc59faa8Sspupyrev 
1097bc59faa8Sspupyrev     // Initialize chains.
1098bc59faa8Sspupyrev     AllChains.reserve(NumNodes);
1099bc59faa8Sspupyrev     for (NodeT &Node : AllNodes) {
1100bc59faa8Sspupyrev       // Adjust execution counts.
1101bc59faa8Sspupyrev       Node.ExecutionCount = std::max(Node.ExecutionCount, Node.inCount());
1102bc59faa8Sspupyrev       Node.ExecutionCount = std::max(Node.ExecutionCount, Node.outCount());
1103bc59faa8Sspupyrev       // Create chain.
1104bc59faa8Sspupyrev       AllChains.emplace_back(Node.Index, &Node);
1105bc59faa8Sspupyrev       Node.CurChain = &AllChains.back();
1106bc59faa8Sspupyrev     }
1107bc59faa8Sspupyrev 
1108bc59faa8Sspupyrev     // Initialize chain edges.
1109bc59faa8Sspupyrev     AllEdges.reserve(AllJumps.size());
1110bc59faa8Sspupyrev     for (NodeT &PredNode : AllNodes) {
1111bc59faa8Sspupyrev       for (JumpT *Jump : PredNode.OutJumps) {
1112bc59faa8Sspupyrev         NodeT *SuccNode = Jump->Target;
1113bc59faa8Sspupyrev         ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain);
1114cebc8379Sspupyrev         // This edge is already present in the graph.
1115bc59faa8Sspupyrev         if (CurEdge != nullptr) {
1116bc59faa8Sspupyrev           assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr);
1117bc59faa8Sspupyrev           CurEdge->appendJump(Jump);
1118bc59faa8Sspupyrev           continue;
1119bc59faa8Sspupyrev         }
1120cebc8379Sspupyrev         // This is a new edge.
1121bc59faa8Sspupyrev         AllEdges.emplace_back(Jump);
1122bc59faa8Sspupyrev         PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back());
1123bc59faa8Sspupyrev         SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back());
1124bc59faa8Sspupyrev       }
1125bc59faa8Sspupyrev     }
1126bc59faa8Sspupyrev   }
1127bc59faa8Sspupyrev 
1128bc59faa8Sspupyrev   /// Merge pairs of chains while there is an improvement in the objective.
1129bc59faa8Sspupyrev   void mergeChainPairs() {
1130bc59faa8Sspupyrev     // Create a priority queue containing all edges ordered by the merge gain.
1131bc59faa8Sspupyrev     auto GainComparator = [](ChainEdge *L, ChainEdge *R) {
1132bc59faa8Sspupyrev       return std::make_tuple(-L->gain(), L->srcChain()->Id, L->dstChain()->Id) <
1133bc59faa8Sspupyrev              std::make_tuple(-R->gain(), R->srcChain()->Id, R->dstChain()->Id);
1134bc59faa8Sspupyrev     };
1135bc59faa8Sspupyrev     std::set<ChainEdge *, decltype(GainComparator)> Queue(GainComparator);
1136bc59faa8Sspupyrev 
1137bc59faa8Sspupyrev     // Insert the edges into the queue.
1138beffc821SFangrui Song     [[maybe_unused]] size_t NumActiveChains = 0;
1139beffc821SFangrui Song     for (NodeT &Node : AllNodes) {
1140beffc821SFangrui Song       if (Node.ExecutionCount == 0)
1141beffc821SFangrui Song         continue;
1142beffc821SFangrui Song       ++NumActiveChains;
1143beffc821SFangrui Song       for (const auto &[_, Edge] : Node.CurChain->Edges) {
1144bc59faa8Sspupyrev         // Ignore self-edges.
1145bc59faa8Sspupyrev         if (Edge->isSelfEdge())
1146bc59faa8Sspupyrev           continue;
1147bc59faa8Sspupyrev         // Ignore already processed edges.
1148bc59faa8Sspupyrev         if (Edge->gain() != -1.0)
1149bc59faa8Sspupyrev           continue;
1150bc59faa8Sspupyrev 
1151bc59faa8Sspupyrev         // Compute the gain of merging the two chains.
1152bc59faa8Sspupyrev         MergeGainT Gain = getBestMergeGain(Edge);
1153bc59faa8Sspupyrev         Edge->setMergeGain(Gain);
1154bc59faa8Sspupyrev 
1155bc59faa8Sspupyrev         if (Edge->gain() > EPS)
1156bc59faa8Sspupyrev           Queue.insert(Edge);
1157bc59faa8Sspupyrev       }
1158bc59faa8Sspupyrev     }
1159bc59faa8Sspupyrev 
1160bc59faa8Sspupyrev     // Merge the chains while the gain of merging is positive.
1161bc59faa8Sspupyrev     while (!Queue.empty()) {
1162bc59faa8Sspupyrev       // Extract the best (top) edge for merging.
1163bc59faa8Sspupyrev       ChainEdge *BestEdge = *Queue.begin();
1164bc59faa8Sspupyrev       Queue.erase(Queue.begin());
1165bc59faa8Sspupyrev       ChainT *BestSrcChain = BestEdge->srcChain();
1166bc59faa8Sspupyrev       ChainT *BestDstChain = BestEdge->dstChain();
1167bc59faa8Sspupyrev 
1168bc59faa8Sspupyrev       // Remove outdated edges from the queue.
1169b4b42bd6Sspupyrev       for (const auto &[_, ChainEdge] : BestSrcChain->Edges)
1170bc59faa8Sspupyrev         Queue.erase(ChainEdge);
1171b4b42bd6Sspupyrev       for (const auto &[_, ChainEdge] : BestDstChain->Edges)
1172bc59faa8Sspupyrev         Queue.erase(ChainEdge);
1173bc59faa8Sspupyrev 
1174bc59faa8Sspupyrev       // Merge the best pair of chains.
1175bc59faa8Sspupyrev       MergeGainT BestGain = BestEdge->getMergeGain();
1176bc59faa8Sspupyrev       mergeChains(BestSrcChain, BestDstChain, BestGain.mergeOffset(),
1177bc59faa8Sspupyrev                   BestGain.mergeType());
1178beffc821SFangrui Song       --NumActiveChains;
1179bc59faa8Sspupyrev 
1180bc59faa8Sspupyrev       // Insert newly created edges into the queue.
1181b4b42bd6Sspupyrev       for (const auto &[_, Edge] : BestSrcChain->Edges) {
1182bc59faa8Sspupyrev         // Ignore loop edges.
1183bc59faa8Sspupyrev         if (Edge->isSelfEdge())
1184bc59faa8Sspupyrev           continue;
1185a2441837SFangrui Song         if (Edge->srcChain()->numBlocks() + Edge->dstChain()->numBlocks() >
1186a2441837SFangrui Song             Config.MaxChainSize)
1187a2441837SFangrui Song           continue;
1188bc59faa8Sspupyrev 
1189bc59faa8Sspupyrev         // Compute the gain of merging the two chains.
1190bc59faa8Sspupyrev         MergeGainT Gain = getBestMergeGain(Edge);
1191bc59faa8Sspupyrev         Edge->setMergeGain(Gain);
1192bc59faa8Sspupyrev 
1193bc59faa8Sspupyrev         if (Edge->gain() > EPS)
1194bc59faa8Sspupyrev           Queue.insert(Edge);
1195bc59faa8Sspupyrev       }
1196bc59faa8Sspupyrev     }
1197beffc821SFangrui Song 
1198beffc821SFangrui Song     LLVM_DEBUG(dbgs() << "Cache-directed function sorting reduced the number"
1199beffc821SFangrui Song                       << " of chains from " << NumNodes << " to "
1200beffc821SFangrui Song                       << NumActiveChains << "\n");
1201bc59faa8Sspupyrev   }
1202bc59faa8Sspupyrev 
1203bc59faa8Sspupyrev   /// Compute the gain of merging two chains.
1204bc59faa8Sspupyrev   ///
1205bc59faa8Sspupyrev   /// The function considers all possible ways of merging two chains and
1206bc59faa8Sspupyrev   /// computes the one having the largest increase in ExtTSP objective. The
1207bc59faa8Sspupyrev   /// result is a pair with the first element being the gain and the second
1208bc59faa8Sspupyrev   /// element being the corresponding merging type.
1209bc59faa8Sspupyrev   MergeGainT getBestMergeGain(ChainEdge *Edge) const {
1210b90fcafcSspupyrev     assert(!Edge->jumps().empty() && "trying to merge chains w/o jumps");
1211bc59faa8Sspupyrev     // Precompute jumps between ChainPred and ChainSucc.
1212b90fcafcSspupyrev     MergedJumpsT Jumps(&Edge->jumps());
1213bc59faa8Sspupyrev     ChainT *SrcChain = Edge->srcChain();
1214bc59faa8Sspupyrev     ChainT *DstChain = Edge->dstChain();
1215bc59faa8Sspupyrev 
1216bc59faa8Sspupyrev     // This object holds the best currently chosen gain of merging two chains.
1217bc59faa8Sspupyrev     MergeGainT Gain = MergeGainT();
1218bc59faa8Sspupyrev 
1219bc59faa8Sspupyrev     /// Given a list of merge types, try to merge two chains and update Gain
1220bc59faa8Sspupyrev     /// with a better alternative.
1221bc59faa8Sspupyrev     auto tryChainMerging = [&](const std::vector<MergeTypeT> &MergeTypes) {
1222bc59faa8Sspupyrev       // Apply the merge, compute the corresponding gain, and update the best
1223bc59faa8Sspupyrev       // value, if the merge is beneficial.
1224bc59faa8Sspupyrev       for (const MergeTypeT &MergeType : MergeTypes) {
1225bc59faa8Sspupyrev         MergeGainT NewGain =
1226bc59faa8Sspupyrev             computeMergeGain(SrcChain, DstChain, Jumps, MergeType);
1227bc59faa8Sspupyrev 
1228bc59faa8Sspupyrev         // When forward and backward gains are the same, prioritize merging that
1229bc59faa8Sspupyrev         // preserves the original order of the functions in the binary.
1230bc59faa8Sspupyrev         if (std::abs(Gain.score() - NewGain.score()) < EPS) {
1231bc59faa8Sspupyrev           if ((MergeType == MergeTypeT::X_Y && SrcChain->Id < DstChain->Id) ||
1232bc59faa8Sspupyrev               (MergeType == MergeTypeT::Y_X && SrcChain->Id > DstChain->Id)) {
1233bc59faa8Sspupyrev             Gain = NewGain;
1234bc59faa8Sspupyrev           }
1235bc59faa8Sspupyrev         } else if (NewGain.score() > Gain.score() + EPS) {
1236bc59faa8Sspupyrev           Gain = NewGain;
1237bc59faa8Sspupyrev         }
1238bc59faa8Sspupyrev       }
1239bc59faa8Sspupyrev     };
1240bc59faa8Sspupyrev 
1241bc59faa8Sspupyrev     // Try to concatenate two chains w/o splitting.
1242bc59faa8Sspupyrev     tryChainMerging({MergeTypeT::X_Y, MergeTypeT::Y_X});
1243bc59faa8Sspupyrev 
1244bc59faa8Sspupyrev     return Gain;
1245bc59faa8Sspupyrev   }
1246bc59faa8Sspupyrev 
1247bc59faa8Sspupyrev   /// Compute the score gain of merging two chains, respecting a given type.
1248bc59faa8Sspupyrev   ///
1249bc59faa8Sspupyrev   /// The two chains are not modified in the method.
1250bc59faa8Sspupyrev   MergeGainT computeMergeGain(ChainT *ChainPred, ChainT *ChainSucc,
1251b90fcafcSspupyrev                               const MergedJumpsT &Jumps,
1252bc59faa8Sspupyrev                               MergeTypeT MergeType) const {
1253bc59faa8Sspupyrev     // This doesn't depend on the ordering of the nodes
1254bc59faa8Sspupyrev     double FreqGain = freqBasedLocalityGain(ChainPred, ChainSucc);
1255bc59faa8Sspupyrev 
1256bc59faa8Sspupyrev     // Merge offset is always 0, as the chains are not split.
1257bc59faa8Sspupyrev     size_t MergeOffset = 0;
1258bc59faa8Sspupyrev     auto MergedBlocks =
1259bc59faa8Sspupyrev         mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType);
1260bc59faa8Sspupyrev     double DistGain = distBasedLocalityGain(MergedBlocks, Jumps);
1261bc59faa8Sspupyrev 
1262bc59faa8Sspupyrev     double GainScore = DistGain + Config.FrequencyScale * FreqGain;
1263bc59faa8Sspupyrev     // Scale the result to increase the importance of merging short chains.
1264bc59faa8Sspupyrev     if (GainScore >= 0.0)
1265bc59faa8Sspupyrev       GainScore /= std::min(ChainPred->Size, ChainSucc->Size);
1266bc59faa8Sspupyrev 
1267bc59faa8Sspupyrev     return MergeGainT(GainScore, MergeOffset, MergeType);
1268bc59faa8Sspupyrev   }
1269bc59faa8Sspupyrev 
1270bc59faa8Sspupyrev   /// Compute the change of the frequency locality after merging the chains.
1271bc59faa8Sspupyrev   double freqBasedLocalityGain(ChainT *ChainPred, ChainT *ChainSucc) const {
1272bc59faa8Sspupyrev     auto missProbability = [&](double ChainDensity) {
1273bc59faa8Sspupyrev       double PageSamples = ChainDensity * Config.CacheSize;
1274bc59faa8Sspupyrev       if (PageSamples >= TotalSamples)
1275bc59faa8Sspupyrev         return 0.0;
1276bc59faa8Sspupyrev       double P = PageSamples / TotalSamples;
1277bc59faa8Sspupyrev       return pow(1.0 - P, static_cast<double>(Config.CacheEntries));
1278bc59faa8Sspupyrev     };
1279bc59faa8Sspupyrev 
1280bc59faa8Sspupyrev     // Cache misses on the chains before merging.
1281bc59faa8Sspupyrev     double CurScore =
1282bc59faa8Sspupyrev         ChainPred->ExecutionCount * missProbability(ChainPred->density()) +
1283bc59faa8Sspupyrev         ChainSucc->ExecutionCount * missProbability(ChainSucc->density());
1284bc59faa8Sspupyrev 
1285bc59faa8Sspupyrev     // Cache misses on the merged chain
1286bc59faa8Sspupyrev     double MergedCounts = ChainPred->ExecutionCount + ChainSucc->ExecutionCount;
1287bc59faa8Sspupyrev     double MergedSize = ChainPred->Size + ChainSucc->Size;
1288bc59faa8Sspupyrev     double MergedDensity = static_cast<double>(MergedCounts) / MergedSize;
1289bc59faa8Sspupyrev     double NewScore = MergedCounts * missProbability(MergedDensity);
1290bc59faa8Sspupyrev 
1291bc59faa8Sspupyrev     return CurScore - NewScore;
1292bc59faa8Sspupyrev   }
1293bc59faa8Sspupyrev 
1294bc59faa8Sspupyrev   /// Compute the distance locality for a jump / call.
1295bc59faa8Sspupyrev   double distScore(uint64_t SrcAddr, uint64_t DstAddr, uint64_t Count) const {
1296bc59faa8Sspupyrev     uint64_t Dist = SrcAddr <= DstAddr ? DstAddr - SrcAddr : SrcAddr - DstAddr;
1297bc59faa8Sspupyrev     double D = Dist == 0 ? 0.1 : static_cast<double>(Dist);
1298bc59faa8Sspupyrev     return static_cast<double>(Count) * std::pow(D, -Config.DistancePower);
1299bc59faa8Sspupyrev   }
1300bc59faa8Sspupyrev 
1301bc59faa8Sspupyrev   /// Compute the change of the distance locality after merging the chains.
1302b90fcafcSspupyrev   double distBasedLocalityGain(const MergedNodesT &Nodes,
1303b90fcafcSspupyrev                                const MergedJumpsT &Jumps) const {
1304bc59faa8Sspupyrev     uint64_t CurAddr = 0;
1305b90fcafcSspupyrev     Nodes.forEach([&](const NodeT *Node) {
1306bc59faa8Sspupyrev       Node->EstimatedAddr = CurAddr;
1307bc59faa8Sspupyrev       CurAddr += Node->Size;
1308bc59faa8Sspupyrev     });
1309bc59faa8Sspupyrev 
1310bc59faa8Sspupyrev     double CurScore = 0;
1311bc59faa8Sspupyrev     double NewScore = 0;
1312b90fcafcSspupyrev     Jumps.forEach([&](const JumpT *Jump) {
1313b90fcafcSspupyrev       uint64_t SrcAddr = Jump->Source->EstimatedAddr + Jump->Offset;
1314b90fcafcSspupyrev       uint64_t DstAddr = Jump->Target->EstimatedAddr;
1315b90fcafcSspupyrev       NewScore += distScore(SrcAddr, DstAddr, Jump->ExecutionCount);
1316b90fcafcSspupyrev       CurScore += distScore(0, TotalSize, Jump->ExecutionCount);
1317b90fcafcSspupyrev     });
1318bc59faa8Sspupyrev     return NewScore - CurScore;
1319bc59faa8Sspupyrev   }
1320bc59faa8Sspupyrev 
1321bc59faa8Sspupyrev   /// Merge chain From into chain Into, update the list of active chains,
1322bc59faa8Sspupyrev   /// adjacency information, and the corresponding cached values.
1323bc59faa8Sspupyrev   void mergeChains(ChainT *Into, ChainT *From, size_t MergeOffset,
1324bc59faa8Sspupyrev                    MergeTypeT MergeType) {
1325bc59faa8Sspupyrev     assert(Into != From && "a chain cannot be merged with itself");
1326bc59faa8Sspupyrev 
1327bc59faa8Sspupyrev     // Merge the nodes.
1328b90fcafcSspupyrev     MergedNodesT MergedNodes =
1329bc59faa8Sspupyrev         mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType);
1330bc59faa8Sspupyrev     Into->merge(From, MergedNodes.getNodes());
1331bc59faa8Sspupyrev 
1332bc59faa8Sspupyrev     // Merge the edges.
1333bc59faa8Sspupyrev     Into->mergeEdges(From);
1334bc59faa8Sspupyrev     From->clear();
1335bc59faa8Sspupyrev   }
1336bc59faa8Sspupyrev 
1337bc59faa8Sspupyrev   /// Concatenate all chains into the final order.
13386b8d04c2SFangrui Song   std::vector<uint64_t> concatChains() {
1339bc59faa8Sspupyrev     // Collect chains and calculate density stats for their sorting.
1340bc59faa8Sspupyrev     std::vector<const ChainT *> SortedChains;
1341bc59faa8Sspupyrev     DenseMap<const ChainT *, double> ChainDensity;
1342bc59faa8Sspupyrev     for (ChainT &Chain : AllChains) {
1343bc59faa8Sspupyrev       if (!Chain.Nodes.empty()) {
1344bc59faa8Sspupyrev         SortedChains.push_back(&Chain);
1345bc59faa8Sspupyrev         // Using doubles to avoid overflow of ExecutionCounts.
1346bc59faa8Sspupyrev         double Size = 0;
1347bc59faa8Sspupyrev         double ExecutionCount = 0;
1348bc59faa8Sspupyrev         for (NodeT *Node : Chain.Nodes) {
1349bc59faa8Sspupyrev           Size += static_cast<double>(Node->Size);
1350bc59faa8Sspupyrev           ExecutionCount += static_cast<double>(Node->ExecutionCount);
1351bc59faa8Sspupyrev         }
1352bc59faa8Sspupyrev         assert(Size > 0 && "a chain of zero size");
1353bc59faa8Sspupyrev         ChainDensity[&Chain] = ExecutionCount / Size;
1354bc59faa8Sspupyrev       }
1355bc59faa8Sspupyrev     }
1356bc59faa8Sspupyrev 
1357bc59faa8Sspupyrev     // Sort chains by density in the decreasing order.
1358fef144ceSKazu Hirata     std::sort(SortedChains.begin(), SortedChains.end(),
1359fef144ceSKazu Hirata               [&](const ChainT *L, const ChainT *R) {
1360bc59faa8Sspupyrev                 const double DL = ChainDensity[L];
1361bc59faa8Sspupyrev                 const double DR = ChainDensity[R];
1362bc59faa8Sspupyrev                 // Compare by density and break ties by chain identifiers.
1363fef144ceSKazu Hirata                 return std::make_tuple(-DL, L->Id) <
1364fef144ceSKazu Hirata                        std::make_tuple(-DR, R->Id);
1365bc59faa8Sspupyrev               });
1366bc59faa8Sspupyrev 
1367bc59faa8Sspupyrev     // Collect the nodes in the order specified by their chains.
13686b8d04c2SFangrui Song     std::vector<uint64_t> Order;
1369bc59faa8Sspupyrev     Order.reserve(NumNodes);
1370bc59faa8Sspupyrev     for (const ChainT *Chain : SortedChains)
1371bc59faa8Sspupyrev       for (NodeT *Node : Chain->Nodes)
1372bc59faa8Sspupyrev         Order.push_back(Node->Index);
13736b8d04c2SFangrui Song     return Order;
1374bc59faa8Sspupyrev   }
1375bc59faa8Sspupyrev 
private:
  /// Config for the algorithm. The methods of this class read MaxChainSize,
  /// FrequencyScale, CacheSize, CacheEntries and DistancePower from it.
  const CDSortConfig Config;

  /// The number of nodes in the graph. Used to size the adjacency lists and
  /// to reserve storage for the chains and for the final order.
  const size_t NumNodes;

  /// Successors of each node, indexed by node. Every non-self edge is
  /// recorded here, including edges with a zero count.
  std::vector<std::vector<uint64_t>> SuccNodes;

  /// Predecessors of each node, indexed by node. Filled in parallel with
  /// SuccNodes.
  std::vector<std::vector<uint64_t>> PredNodes;

  /// All nodes (functions) in the graph. Chains and jumps hold raw pointers
  /// to these elements.
  std::vector<NodeT> AllNodes;

  /// All jumps (function calls) between the nodes. Only edges with a positive
  /// count get a jump. Capacity is reserved before filling because nodes keep
  /// raw pointers to the elements, so the vector must not reallocate.
  std::vector<JumpT> AllJumps;

  /// All chains of nodes; initially one chain per node. Capacity is reserved
  /// up front because each node stores a raw pointer to its current chain.
  /// Chains merged into another one are cleared rather than removed.
  std::vector<ChainT> AllChains;

  /// All edges between the chains. Capacity is reserved up front because the
  /// chains store raw pointers to these elements.
  std::vector<ChainEdge> AllEdges;

  /// The total number of samples in the graph. Serves as the denominator of
  /// the miss-probability estimate in the frequency-based gain.
  uint64_t TotalSamples{0};

  /// The total size of the nodes in the graph. Used as the assumed jump
  /// distance before a merge in the distance-based gain.
  uint64_t TotalSize{0};
1406bc59faa8Sspupyrev };
1407bc59faa8Sspupyrev 
1408f573f686Sspupyrev } // end of anonymous namespace
1409f573f686Sspupyrev 
1410a7e13a99Sspupyrev std::vector<uint64_t>
14116b8d04c2SFangrui Song codelayout::computeExtTspLayout(ArrayRef<uint64_t> NodeSizes,
14126b8d04c2SFangrui Song                                 ArrayRef<uint64_t> NodeCounts,
14136b8d04c2SFangrui Song                                 ArrayRef<EdgeCount> EdgeCounts) {
1414bc59faa8Sspupyrev   // Verify correctness of the input data.
1415f573f686Sspupyrev   assert(NodeCounts.size() == NodeSizes.size() && "Incorrect input");
1416a7e13a99Sspupyrev   assert(NodeSizes.size() > 2 && "Incorrect input");
1417f573f686Sspupyrev 
1418bc59faa8Sspupyrev   // Apply the reordering algorithm.
1419a7e13a99Sspupyrev   ExtTSPImpl Alg(NodeSizes, NodeCounts, EdgeCounts);
14206b8d04c2SFangrui Song   std::vector<uint64_t> Result = Alg.run();
1421f573f686Sspupyrev 
1422bc59faa8Sspupyrev   // Verify correctness of the output.
1423f573f686Sspupyrev   assert(Result.front() == 0 && "Original entry point is not preserved");
1424a7e13a99Sspupyrev   assert(Result.size() == NodeSizes.size() && "Incorrect size of layout");
1425f573f686Sspupyrev   return Result;
1426f573f686Sspupyrev }
1427f573f686Sspupyrev 
14286b8d04c2SFangrui Song double codelayout::calcExtTspScore(ArrayRef<uint64_t> Order,
14296b8d04c2SFangrui Song                                    ArrayRef<uint64_t> NodeSizes,
14306b8d04c2SFangrui Song                                    ArrayRef<EdgeCount> EdgeCounts) {
1431bc59faa8Sspupyrev   // Estimate addresses of the blocks in memory.
1432*fbec1c2aSEllis Hoag   SmallVector<uint64_t> Addr(NodeSizes.size(), 0);
1433*fbec1c2aSEllis Hoag   for (uint64_t Idx = 1; Idx < Order.size(); Idx++)
1434f573f686Sspupyrev     Addr[Order[Idx]] = Addr[Order[Idx - 1]] + NodeSizes[Order[Idx - 1]];
1435*fbec1c2aSEllis Hoag   SmallVector<uint64_t> OutDegree(NodeSizes.size(), 0);
1436*fbec1c2aSEllis Hoag   for (auto &Edge : EdgeCounts)
14376b8d04c2SFangrui Song     ++OutDegree[Edge.src];
1438f573f686Sspupyrev 
1439bc59faa8Sspupyrev   // Increase the score for each jump.
1440f573f686Sspupyrev   double Score = 0;
1441*fbec1c2aSEllis Hoag   for (auto &Edge : EdgeCounts) {
14426b8d04c2SFangrui Song     bool IsConditional = OutDegree[Edge.src] > 1;
14436b8d04c2SFangrui Song     Score += ::extTSPScore(Addr[Edge.src], NodeSizes[Edge.src], Addr[Edge.dst],
14446b8d04c2SFangrui Song                            Edge.count, IsConditional);
1445f573f686Sspupyrev   }
1446f573f686Sspupyrev   return Score;
1447f573f686Sspupyrev }
1448f573f686Sspupyrev 
14496b8d04c2SFangrui Song double codelayout::calcExtTspScore(ArrayRef<uint64_t> NodeSizes,
14506b8d04c2SFangrui Song                                    ArrayRef<EdgeCount> EdgeCounts) {
1451*fbec1c2aSEllis Hoag   SmallVector<uint64_t> Order(NodeSizes.size());
1452*fbec1c2aSEllis Hoag   for (uint64_t Idx = 0; Idx < NodeSizes.size(); Idx++)
1453f573f686Sspupyrev     Order[Idx] = Idx;
1454*fbec1c2aSEllis Hoag   return calcExtTspScore(Order, NodeSizes, EdgeCounts);
1455f573f686Sspupyrev }
1456bc59faa8Sspupyrev 
14576b8d04c2SFangrui Song std::vector<uint64_t> codelayout::computeCacheDirectedLayout(
14586b8d04c2SFangrui Song     const CDSortConfig &Config, ArrayRef<uint64_t> FuncSizes,
14596b8d04c2SFangrui Song     ArrayRef<uint64_t> FuncCounts, ArrayRef<EdgeCount> CallCounts,
14606b8d04c2SFangrui Song     ArrayRef<uint64_t> CallOffsets) {
1461bc59faa8Sspupyrev   // Verify correctness of the input data.
1462bc59faa8Sspupyrev   assert(FuncCounts.size() == FuncSizes.size() && "Incorrect input");
1463bc59faa8Sspupyrev 
1464bc59faa8Sspupyrev   // Apply the reordering algorithm.
1465bc59faa8Sspupyrev   CDSortImpl Alg(Config, FuncSizes, FuncCounts, CallCounts, CallOffsets);
14666b8d04c2SFangrui Song   std::vector<uint64_t> Result = Alg.run();
1467bc59faa8Sspupyrev   assert(Result.size() == FuncSizes.size() && "Incorrect size of layout");
1468bc59faa8Sspupyrev   return Result;
1469bc59faa8Sspupyrev }
1470bc59faa8Sspupyrev 
14716b8d04c2SFangrui Song std::vector<uint64_t> codelayout::computeCacheDirectedLayout(
14726b8d04c2SFangrui Song     ArrayRef<uint64_t> FuncSizes, ArrayRef<uint64_t> FuncCounts,
14736b8d04c2SFangrui Song     ArrayRef<EdgeCount> CallCounts, ArrayRef<uint64_t> CallOffsets) {
1474bc59faa8Sspupyrev   CDSortConfig Config;
1475bc59faa8Sspupyrev   // Populate the config from the command-line options.
1476bc59faa8Sspupyrev   if (CacheEntries.getNumOccurrences() > 0)
1477bc59faa8Sspupyrev     Config.CacheEntries = CacheEntries;
1478bc59faa8Sspupyrev   if (CacheSize.getNumOccurrences() > 0)
1479bc59faa8Sspupyrev     Config.CacheSize = CacheSize;
1480a2441837SFangrui Song   if (CDMaxChainSize.getNumOccurrences() > 0)
1481a2441837SFangrui Song     Config.MaxChainSize = CDMaxChainSize;
1482bc59faa8Sspupyrev   if (DistancePower.getNumOccurrences() > 0)
1483bc59faa8Sspupyrev     Config.DistancePower = DistancePower;
1484bc59faa8Sspupyrev   if (FrequencyScale.getNumOccurrences() > 0)
1485bc59faa8Sspupyrev     Config.FrequencyScale = FrequencyScale;
14866b8d04c2SFangrui Song   return computeCacheDirectedLayout(Config, FuncSizes, FuncCounts, CallCounts,
14876b8d04c2SFangrui Song                                     CallOffsets);
1488bc59faa8Sspupyrev }
1489