1*7330f729Sjoerg //===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===//
2*7330f729Sjoerg //
3*7330f729Sjoerg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*7330f729Sjoerg // See https://llvm.org/LICENSE.txt for license information.
5*7330f729Sjoerg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*7330f729Sjoerg //
7*7330f729Sjoerg //===----------------------------------------------------------------------===//
8*7330f729Sjoerg //
9*7330f729Sjoerg // This file defines utilities for propagating synthetic counts.
10*7330f729Sjoerg //
11*7330f729Sjoerg //===----------------------------------------------------------------------===//
12*7330f729Sjoerg
13*7330f729Sjoerg #include "llvm/Analysis/SyntheticCountsUtils.h"
14*7330f729Sjoerg #include "llvm/ADT/DenseSet.h"
15*7330f729Sjoerg #include "llvm/ADT/SCCIterator.h"
16*7330f729Sjoerg #include "llvm/Analysis/CallGraph.h"
17*7330f729Sjoerg #include "llvm/IR/Function.h"
18*7330f729Sjoerg #include "llvm/IR/InstIterator.h"
19*7330f729Sjoerg #include "llvm/IR/Instructions.h"
20*7330f729Sjoerg #include "llvm/IR/ModuleSummaryIndex.h"
21*7330f729Sjoerg
22*7330f729Sjoerg using namespace llvm;
23*7330f729Sjoerg
24*7330f729Sjoerg // Given an SCC, propagate entry counts along the edge of the SCC nodes.
25*7330f729Sjoerg template <typename CallGraphType>
propagateFromSCC(const SccTy & SCC,GetProfCountTy GetProfCount,AddCountTy AddCount)26*7330f729Sjoerg void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
27*7330f729Sjoerg const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) {
28*7330f729Sjoerg
29*7330f729Sjoerg DenseSet<NodeRef> SCCNodes;
30*7330f729Sjoerg SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges;
31*7330f729Sjoerg
32*7330f729Sjoerg for (auto &Node : SCC)
33*7330f729Sjoerg SCCNodes.insert(Node);
34*7330f729Sjoerg
35*7330f729Sjoerg // Partition the edges coming out of the SCC into those whose destination is
36*7330f729Sjoerg // in the SCC and the rest.
37*7330f729Sjoerg for (const auto &Node : SCCNodes) {
38*7330f729Sjoerg for (auto &E : children_edges<CallGraphType>(Node)) {
39*7330f729Sjoerg if (SCCNodes.count(CGT::edge_dest(E)))
40*7330f729Sjoerg SCCEdges.emplace_back(Node, E);
41*7330f729Sjoerg else
42*7330f729Sjoerg NonSCCEdges.emplace_back(Node, E);
43*7330f729Sjoerg }
44*7330f729Sjoerg }
45*7330f729Sjoerg
46*7330f729Sjoerg // For nodes in the same SCC, update the counts in two steps:
47*7330f729Sjoerg // 1. Compute the additional count for each node by propagating the counts
48*7330f729Sjoerg // along all incoming edges to the node that originate from within the same
49*7330f729Sjoerg // SCC and summing them up.
50*7330f729Sjoerg // 2. Add the additional counts to the nodes in the SCC.
51*7330f729Sjoerg // This ensures that the order of
52*7330f729Sjoerg // traversal of nodes within the SCC doesn't affect the final result.
53*7330f729Sjoerg
54*7330f729Sjoerg DenseMap<NodeRef, Scaled64> AdditionalCounts;
55*7330f729Sjoerg for (auto &E : SCCEdges) {
56*7330f729Sjoerg auto OptProfCount = GetProfCount(E.first, E.second);
57*7330f729Sjoerg if (!OptProfCount)
58*7330f729Sjoerg continue;
59*7330f729Sjoerg auto Callee = CGT::edge_dest(E.second);
60*7330f729Sjoerg AdditionalCounts[Callee] += OptProfCount.getValue();
61*7330f729Sjoerg }
62*7330f729Sjoerg
63*7330f729Sjoerg // Update the counts for the nodes in the SCC.
64*7330f729Sjoerg for (auto &Entry : AdditionalCounts)
65*7330f729Sjoerg AddCount(Entry.first, Entry.second);
66*7330f729Sjoerg
67*7330f729Sjoerg // Now update the counts for nodes outside the SCC.
68*7330f729Sjoerg for (auto &E : NonSCCEdges) {
69*7330f729Sjoerg auto OptProfCount = GetProfCount(E.first, E.second);
70*7330f729Sjoerg if (!OptProfCount)
71*7330f729Sjoerg continue;
72*7330f729Sjoerg auto Callee = CGT::edge_dest(E.second);
73*7330f729Sjoerg AddCount(Callee, OptProfCount.getValue());
74*7330f729Sjoerg }
75*7330f729Sjoerg }
76*7330f729Sjoerg
77*7330f729Sjoerg /// Propgate synthetic entry counts on a callgraph \p CG.
78*7330f729Sjoerg ///
79*7330f729Sjoerg /// This performs a reverse post-order traversal of the callgraph SCC. For each
80*7330f729Sjoerg /// SCC, it first propagates the entry counts to the nodes within the SCC
81*7330f729Sjoerg /// through call edges and updates them in one shot. Then the entry counts are
82*7330f729Sjoerg /// propagated to nodes outside the SCC. This requires \p GraphTraits
83*7330f729Sjoerg /// to have a specialization for \p CallGraphType.
84*7330f729Sjoerg
85*7330f729Sjoerg template <typename CallGraphType>
propagate(const CallGraphType & CG,GetProfCountTy GetProfCount,AddCountTy AddCount)86*7330f729Sjoerg void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
87*7330f729Sjoerg GetProfCountTy GetProfCount,
88*7330f729Sjoerg AddCountTy AddCount) {
89*7330f729Sjoerg std::vector<SccTy> SCCs;
90*7330f729Sjoerg
91*7330f729Sjoerg // Collect all the SCCs.
92*7330f729Sjoerg for (auto I = scc_begin(CG); !I.isAtEnd(); ++I)
93*7330f729Sjoerg SCCs.push_back(*I);
94*7330f729Sjoerg
95*7330f729Sjoerg // The callgraph-scc needs to be visited in top-down order for propagation.
96*7330f729Sjoerg // The scc iterator returns the scc in bottom-up order, so reverse the SCCs
97*7330f729Sjoerg // and call propagateFromSCC.
98*7330f729Sjoerg for (auto &SCC : reverse(SCCs))
99*7330f729Sjoerg propagateFromSCC(SCC, GetProfCount, AddCount);
100*7330f729Sjoerg }
101*7330f729Sjoerg
102*7330f729Sjoerg template class llvm::SyntheticCountsUtils<const CallGraph *>;
103*7330f729Sjoerg template class llvm::SyntheticCountsUtils<ModuleSummaryIndex *>;
104