1bdf20261SEaswaran Raman //===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===//
2bdf20261SEaswaran Raman //
32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information.
52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdf20261SEaswaran Raman //
7bdf20261SEaswaran Raman //===----------------------------------------------------------------------===//
8bdf20261SEaswaran Raman //
9bdf20261SEaswaran Raman // This file defines utilities for propagating synthetic counts.
10bdf20261SEaswaran Raman //
11bdf20261SEaswaran Raman //===----------------------------------------------------------------------===//
12bdf20261SEaswaran Raman
13bdf20261SEaswaran Raman #include "llvm/Analysis/SyntheticCountsUtils.h"
14bdf20261SEaswaran Raman #include "llvm/ADT/DenseSet.h"
15bdf20261SEaswaran Raman #include "llvm/ADT/SCCIterator.h"
16bdf20261SEaswaran Raman #include "llvm/Analysis/CallGraph.h"
175a7056faSEaswaran Raman #include "llvm/IR/ModuleSummaryIndex.h"
18bdf20261SEaswaran Raman
19bdf20261SEaswaran Raman using namespace llvm;
20bdf20261SEaswaran Raman
218410c374SEaswaran Raman // Given an SCC, propagate entry counts along the edge of the SCC nodes.
228410c374SEaswaran Raman template <typename CallGraphType>
propagateFromSCC(const SccTy & SCC,GetProfCountTy GetProfCount,AddCountTy AddCount)238410c374SEaswaran Raman void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
24b45994b8SEaswaran Raman const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) {
25bdf20261SEaswaran Raman
265a7056faSEaswaran Raman DenseSet<NodeRef> SCCNodes;
278410c374SEaswaran Raman SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges;
28bdf20261SEaswaran Raman
298410c374SEaswaran Raman for (auto &Node : SCC)
308410c374SEaswaran Raman SCCNodes.insert(Node);
318410c374SEaswaran Raman
328410c374SEaswaran Raman // Partition the edges coming out of the SCC into those whose destination is
338410c374SEaswaran Raman // in the SCC and the rest.
348410c374SEaswaran Raman for (const auto &Node : SCCNodes) {
3506a71533SEaswaran Raman for (auto &E : children_edges<CallGraphType>(Node)) {
368410c374SEaswaran Raman if (SCCNodes.count(CGT::edge_dest(E)))
378410c374SEaswaran Raman SCCEdges.emplace_back(Node, E);
388410c374SEaswaran Raman else
398410c374SEaswaran Raman NonSCCEdges.emplace_back(Node, E);
40bdf20261SEaswaran Raman }
41bdf20261SEaswaran Raman }
42bdf20261SEaswaran Raman
438410c374SEaswaran Raman // For nodes in the same SCC, update the counts in two steps:
448410c374SEaswaran Raman // 1. Compute the additional count for each node by propagating the counts
458410c374SEaswaran Raman // along all incoming edges to the node that originate from within the same
468410c374SEaswaran Raman // SCC and summing them up.
478410c374SEaswaran Raman // 2. Add the additional counts to the nodes in the SCC.
48bdf20261SEaswaran Raman // This ensures that the order of
498410c374SEaswaran Raman // traversal of nodes within the SCC doesn't affect the final result.
50bdf20261SEaswaran Raman
51b45994b8SEaswaran Raman DenseMap<NodeRef, Scaled64> AdditionalCounts;
528410c374SEaswaran Raman for (auto &E : SCCEdges) {
53b45994b8SEaswaran Raman auto OptProfCount = GetProfCount(E.first, E.second);
54b45994b8SEaswaran Raman if (!OptProfCount)
558410c374SEaswaran Raman continue;
568410c374SEaswaran Raman auto Callee = CGT::edge_dest(E.second);
57*7a47ee51SKazu Hirata AdditionalCounts[Callee] += *OptProfCount;
58bdf20261SEaswaran Raman }
59bdf20261SEaswaran Raman
608410c374SEaswaran Raman // Update the counts for the nodes in the SCC.
61bdf20261SEaswaran Raman for (auto &Entry : AdditionalCounts)
628410c374SEaswaran Raman AddCount(Entry.first, Entry.second);
63bdf20261SEaswaran Raman
648410c374SEaswaran Raman // Now update the counts for nodes outside the SCC.
658410c374SEaswaran Raman for (auto &E : NonSCCEdges) {
66b45994b8SEaswaran Raman auto OptProfCount = GetProfCount(E.first, E.second);
67b45994b8SEaswaran Raman if (!OptProfCount)
688410c374SEaswaran Raman continue;
698410c374SEaswaran Raman auto Callee = CGT::edge_dest(E.second);
70*7a47ee51SKazu Hirata AddCount(Callee, *OptProfCount);
71bdf20261SEaswaran Raman }
72bdf20261SEaswaran Raman }
73bdf20261SEaswaran Raman
748410c374SEaswaran Raman /// Propgate synthetic entry counts on a callgraph \p CG.
75bdf20261SEaswaran Raman ///
76bdf20261SEaswaran Raman /// This performs a reverse post-order traversal of the callgraph SCC. For each
778410c374SEaswaran Raman /// SCC, it first propagates the entry counts to the nodes within the SCC
78bdf20261SEaswaran Raman /// through call edges and updates them in one shot. Then the entry counts are
7906a71533SEaswaran Raman /// propagated to nodes outside the SCC. This requires \p GraphTraits
808410c374SEaswaran Raman /// to have a specialization for \p CallGraphType.
81bdf20261SEaswaran Raman
828410c374SEaswaran Raman template <typename CallGraphType>
propagate(const CallGraphType & CG,GetProfCountTy GetProfCount,AddCountTy AddCount)838410c374SEaswaran Raman void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
84b45994b8SEaswaran Raman GetProfCountTy GetProfCount,
858410c374SEaswaran Raman AddCountTy AddCount) {
868410c374SEaswaran Raman std::vector<SccTy> SCCs;
87bdf20261SEaswaran Raman
888410c374SEaswaran Raman // Collect all the SCCs.
898410c374SEaswaran Raman for (auto I = scc_begin(CG); !I.isAtEnd(); ++I)
908410c374SEaswaran Raman SCCs.push_back(*I);
918410c374SEaswaran Raman
928410c374SEaswaran Raman // The callgraph-scc needs to be visited in top-down order for propagation.
938410c374SEaswaran Raman // The scc iterator returns the scc in bottom-up order, so reverse the SCCs
948410c374SEaswaran Raman // and call propagateFromSCC.
958410c374SEaswaran Raman for (auto &SCC : reverse(SCCs))
96b45994b8SEaswaran Raman propagateFromSCC(SCC, GetProfCount, AddCount);
97bdf20261SEaswaran Raman }
98bdf20261SEaswaran Raman
998410c374SEaswaran Raman template class llvm::SyntheticCountsUtils<const CallGraph *>;
1005a7056faSEaswaran Raman template class llvm::SyntheticCountsUtils<ModuleSummaryIndex *>;
101