109467b48Spatrick //===-- SpeculateAnalyses.cpp --*- C++ -*-===//
209467b48Spatrick //
309467b48Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409467b48Spatrick // See https://llvm.org/LICENSE.txt for license information.
509467b48Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609467b48Spatrick //
709467b48Spatrick //===----------------------------------------------------------------------===//
809467b48Spatrick
909467b48Spatrick #include "llvm/ExecutionEngine/Orc/SpeculateAnalyses.h"
1009467b48Spatrick #include "llvm/ADT/ArrayRef.h"
1109467b48Spatrick #include "llvm/ADT/DenseMap.h"
1209467b48Spatrick #include "llvm/ADT/STLExtras.h"
1309467b48Spatrick #include "llvm/ADT/SmallPtrSet.h"
1409467b48Spatrick #include "llvm/ADT/SmallVector.h"
1509467b48Spatrick #include "llvm/Analysis/BlockFrequencyInfo.h"
1609467b48Spatrick #include "llvm/Analysis/BranchProbabilityInfo.h"
1709467b48Spatrick #include "llvm/Analysis/CFG.h"
1809467b48Spatrick #include "llvm/IR/PassManager.h"
1909467b48Spatrick #include "llvm/Passes/PassBuilder.h"
2009467b48Spatrick #include "llvm/Support/ErrorHandling.h"
2109467b48Spatrick
2209467b48Spatrick #include <algorithm>
2309467b48Spatrick
2409467b48Spatrick namespace {
2509467b48Spatrick using namespace llvm;
findBBwithCalls(const Function & F,bool IndirectCall=false)2609467b48Spatrick SmallVector<const BasicBlock *, 8> findBBwithCalls(const Function &F,
2709467b48Spatrick bool IndirectCall = false) {
2809467b48Spatrick SmallVector<const BasicBlock *, 8> BBs;
2909467b48Spatrick
3009467b48Spatrick auto findCallInst = [&IndirectCall](const Instruction &I) {
3109467b48Spatrick if (auto Call = dyn_cast<CallBase>(&I))
3209467b48Spatrick return Call->isIndirectCall() ? IndirectCall : true;
3309467b48Spatrick else
3409467b48Spatrick return false;
3509467b48Spatrick };
3609467b48Spatrick for (auto &BB : F)
3709467b48Spatrick if (findCallInst(*BB.getTerminator()) ||
3809467b48Spatrick llvm::any_of(BB.instructionsWithoutDebug(), findCallInst))
3909467b48Spatrick BBs.emplace_back(&BB);
4009467b48Spatrick
4109467b48Spatrick return BBs;
4209467b48Spatrick }
4309467b48Spatrick } // namespace
4409467b48Spatrick
// Implementations of Queries shouldn't need to lock resources such as the
// LLVMContext: each argument (function) has a non-shared LLVMContext.
// In addition, if a Query holds state, it must provide its own locking scheme.
4809467b48Spatrick namespace llvm {
4909467b48Spatrick namespace orc {
5009467b48Spatrick
5109467b48Spatrick // Collect direct calls only
findCalles(const BasicBlock * BB,DenseSet<StringRef> & CallesNames)5209467b48Spatrick void SpeculateQuery::findCalles(const BasicBlock *BB,
5309467b48Spatrick DenseSet<StringRef> &CallesNames) {
5409467b48Spatrick assert(BB != nullptr && "Traversing Null BB to find calls?");
5509467b48Spatrick
5609467b48Spatrick auto getCalledFunction = [&CallesNames](const CallBase *Call) {
5709467b48Spatrick auto CalledValue = Call->getCalledOperand()->stripPointerCasts();
5809467b48Spatrick if (auto DirectCall = dyn_cast<Function>(CalledValue))
5909467b48Spatrick CallesNames.insert(DirectCall->getName());
6009467b48Spatrick };
6109467b48Spatrick for (auto &I : BB->instructionsWithoutDebug())
6209467b48Spatrick if (auto CI = dyn_cast<CallInst>(&I))
6309467b48Spatrick getCalledFunction(CI);
6409467b48Spatrick
6509467b48Spatrick if (auto II = dyn_cast<InvokeInst>(BB->getTerminator()))
6609467b48Spatrick getCalledFunction(II);
6709467b48Spatrick }
6809467b48Spatrick
isStraightLine(const Function & F)6909467b48Spatrick bool SpeculateQuery::isStraightLine(const Function &F) {
70*d415bd75Srobert return llvm::all_of(F, [](const BasicBlock &BB) {
7109467b48Spatrick return BB.getSingleSuccessor() != nullptr;
7209467b48Spatrick });
7309467b48Spatrick }
7409467b48Spatrick
7509467b48Spatrick // BlockFreqQuery Implementations
7609467b48Spatrick
numBBToGet(size_t numBB)7709467b48Spatrick size_t BlockFreqQuery::numBBToGet(size_t numBB) {
7809467b48Spatrick // small CFG
7909467b48Spatrick if (numBB < 4)
8009467b48Spatrick return numBB;
8109467b48Spatrick // mid-size CFG
8209467b48Spatrick else if (numBB < 20)
8309467b48Spatrick return (numBB / 2);
8409467b48Spatrick else
8509467b48Spatrick return (numBB / 2) + (numBB / 4);
8609467b48Spatrick }
8709467b48Spatrick
operator ()(Function & F)8809467b48Spatrick BlockFreqQuery::ResultTy BlockFreqQuery::operator()(Function &F) {
8909467b48Spatrick DenseMap<StringRef, DenseSet<StringRef>> CallerAndCalles;
9009467b48Spatrick DenseSet<StringRef> Calles;
9109467b48Spatrick SmallVector<std::pair<const BasicBlock *, uint64_t>, 8> BBFreqs;
9209467b48Spatrick
9309467b48Spatrick PassBuilder PB;
9409467b48Spatrick FunctionAnalysisManager FAM;
9509467b48Spatrick PB.registerFunctionAnalyses(FAM);
9609467b48Spatrick
9709467b48Spatrick auto IBBs = findBBwithCalls(F);
9809467b48Spatrick
9909467b48Spatrick if (IBBs.empty())
100*d415bd75Srobert return std::nullopt;
10109467b48Spatrick
10209467b48Spatrick auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
10309467b48Spatrick
10409467b48Spatrick for (const auto I : IBBs)
10509467b48Spatrick BBFreqs.push_back({I, BFI.getBlockFreq(I).getFrequency()});
10609467b48Spatrick
10709467b48Spatrick assert(IBBs.size() == BBFreqs.size() && "BB Count Mismatch");
10809467b48Spatrick
10973471bf0Spatrick llvm::sort(BBFreqs, [](decltype(BBFreqs)::const_reference BBF,
11009467b48Spatrick decltype(BBFreqs)::const_reference BBS) {
11109467b48Spatrick return BBF.second > BBS.second ? true : false;
11209467b48Spatrick });
11309467b48Spatrick
11409467b48Spatrick // ignoring number of direct calls in a BB
11509467b48Spatrick auto Topk = numBBToGet(BBFreqs.size());
11609467b48Spatrick
11709467b48Spatrick for (size_t i = 0; i < Topk; i++)
11809467b48Spatrick findCalles(BBFreqs[i].first, Calles);
11909467b48Spatrick
12009467b48Spatrick assert(!Calles.empty() && "Running Analysis on Function with no calls?");
12109467b48Spatrick
12209467b48Spatrick CallerAndCalles.insert({F.getName(), std::move(Calles)});
12309467b48Spatrick
12409467b48Spatrick return CallerAndCalles;
12509467b48Spatrick }
12609467b48Spatrick
12709467b48Spatrick // SequenceBBQuery Implementation
getHottestBlocks(std::size_t TotalBlocks)12809467b48Spatrick std::size_t SequenceBBQuery::getHottestBlocks(std::size_t TotalBlocks) {
12909467b48Spatrick if (TotalBlocks == 1)
13009467b48Spatrick return TotalBlocks;
13109467b48Spatrick return TotalBlocks / 2;
13209467b48Spatrick }
13309467b48Spatrick
13409467b48Spatrick // FIXME : find good implementation.
13509467b48Spatrick SequenceBBQuery::BlockListTy
rearrangeBB(const Function & F,const BlockListTy & BBList)13609467b48Spatrick SequenceBBQuery::rearrangeBB(const Function &F, const BlockListTy &BBList) {
13709467b48Spatrick BlockListTy RearrangedBBSet;
13809467b48Spatrick
139*d415bd75Srobert for (auto &Block : F)
14009467b48Spatrick if (llvm::is_contained(BBList, &Block))
14109467b48Spatrick RearrangedBBSet.push_back(&Block);
14209467b48Spatrick
14309467b48Spatrick assert(RearrangedBBSet.size() == BBList.size() &&
14409467b48Spatrick "BasicBlock missing while rearranging?");
14509467b48Spatrick return RearrangedBBSet;
14609467b48Spatrick }
14709467b48Spatrick
traverseToEntryBlock(const BasicBlock * AtBB,const BlockListTy & CallerBlocks,const BackEdgesInfoTy & BackEdgesInfo,const BranchProbabilityInfo * BPI,VisitedBlocksInfoTy & VisitedBlocks)14809467b48Spatrick void SequenceBBQuery::traverseToEntryBlock(const BasicBlock *AtBB,
14909467b48Spatrick const BlockListTy &CallerBlocks,
15009467b48Spatrick const BackEdgesInfoTy &BackEdgesInfo,
15109467b48Spatrick const BranchProbabilityInfo *BPI,
15209467b48Spatrick VisitedBlocksInfoTy &VisitedBlocks) {
15309467b48Spatrick auto Itr = VisitedBlocks.find(AtBB);
15409467b48Spatrick if (Itr != VisitedBlocks.end()) { // already visited.
15509467b48Spatrick if (!Itr->second.Upward)
15609467b48Spatrick return;
15709467b48Spatrick Itr->second.Upward = false;
15809467b48Spatrick } else {
15909467b48Spatrick // Create hint for newly discoverd blocks.
16009467b48Spatrick WalkDirection BlockHint;
16109467b48Spatrick BlockHint.Upward = false;
16209467b48Spatrick // FIXME: Expensive Check
16309467b48Spatrick if (llvm::is_contained(CallerBlocks, AtBB))
16409467b48Spatrick BlockHint.CallerBlock = true;
16509467b48Spatrick VisitedBlocks.insert(std::make_pair(AtBB, BlockHint));
16609467b48Spatrick }
16709467b48Spatrick
16809467b48Spatrick const_pred_iterator PIt = pred_begin(AtBB), EIt = pred_end(AtBB);
16909467b48Spatrick // Move this check to top, when we have code setup to launch speculative
17009467b48Spatrick // compiles for function in entry BB, this triggers the speculative compiles
17109467b48Spatrick // before running the program.
17209467b48Spatrick if (PIt == EIt) // No Preds.
17309467b48Spatrick return;
17409467b48Spatrick
17509467b48Spatrick DenseSet<const BasicBlock *> PredSkipNodes;
17609467b48Spatrick
17709467b48Spatrick // Since we are checking for predecessor's backedges, this Block
17809467b48Spatrick // occurs in second position.
17909467b48Spatrick for (auto &I : BackEdgesInfo)
18009467b48Spatrick if (I.second == AtBB)
18109467b48Spatrick PredSkipNodes.insert(I.first);
18209467b48Spatrick
18309467b48Spatrick // Skip predecessors which source of back-edges.
18409467b48Spatrick for (; PIt != EIt; ++PIt)
18509467b48Spatrick // checking EdgeHotness is cheaper
18609467b48Spatrick if (BPI->isEdgeHot(*PIt, AtBB) && !PredSkipNodes.count(*PIt))
18709467b48Spatrick traverseToEntryBlock(*PIt, CallerBlocks, BackEdgesInfo, BPI,
18809467b48Spatrick VisitedBlocks);
18909467b48Spatrick }
19009467b48Spatrick
traverseToExitBlock(const BasicBlock * AtBB,const BlockListTy & CallerBlocks,const BackEdgesInfoTy & BackEdgesInfo,const BranchProbabilityInfo * BPI,VisitedBlocksInfoTy & VisitedBlocks)19109467b48Spatrick void SequenceBBQuery::traverseToExitBlock(const BasicBlock *AtBB,
19209467b48Spatrick const BlockListTy &CallerBlocks,
19309467b48Spatrick const BackEdgesInfoTy &BackEdgesInfo,
19409467b48Spatrick const BranchProbabilityInfo *BPI,
19509467b48Spatrick VisitedBlocksInfoTy &VisitedBlocks) {
19609467b48Spatrick auto Itr = VisitedBlocks.find(AtBB);
19709467b48Spatrick if (Itr != VisitedBlocks.end()) { // already visited.
19809467b48Spatrick if (!Itr->second.Downward)
19909467b48Spatrick return;
20009467b48Spatrick Itr->second.Downward = false;
20109467b48Spatrick } else {
20209467b48Spatrick // Create hint for newly discoverd blocks.
20309467b48Spatrick WalkDirection BlockHint;
20409467b48Spatrick BlockHint.Downward = false;
20509467b48Spatrick // FIXME: Expensive Check
20609467b48Spatrick if (llvm::is_contained(CallerBlocks, AtBB))
20709467b48Spatrick BlockHint.CallerBlock = true;
20809467b48Spatrick VisitedBlocks.insert(std::make_pair(AtBB, BlockHint));
20909467b48Spatrick }
21009467b48Spatrick
211097a140dSpatrick const_succ_iterator PIt = succ_begin(AtBB), EIt = succ_end(AtBB);
21209467b48Spatrick if (PIt == EIt) // No succs.
21309467b48Spatrick return;
21409467b48Spatrick
21509467b48Spatrick // If there are hot edges, then compute SuccSkipNodes.
21609467b48Spatrick DenseSet<const BasicBlock *> SuccSkipNodes;
21709467b48Spatrick
21809467b48Spatrick // Since we are checking for successor's backedges, this Block
21909467b48Spatrick // occurs in first position.
22009467b48Spatrick for (auto &I : BackEdgesInfo)
22109467b48Spatrick if (I.first == AtBB)
22209467b48Spatrick SuccSkipNodes.insert(I.second);
22309467b48Spatrick
22409467b48Spatrick for (; PIt != EIt; ++PIt)
22509467b48Spatrick if (BPI->isEdgeHot(AtBB, *PIt) && !SuccSkipNodes.count(*PIt))
22609467b48Spatrick traverseToExitBlock(*PIt, CallerBlocks, BackEdgesInfo, BPI,
22709467b48Spatrick VisitedBlocks);
22809467b48Spatrick }
22909467b48Spatrick
23009467b48Spatrick // Get Block frequencies for blocks and take most frquently executed block,
23109467b48Spatrick // walk towards the entry block from those blocks and discover the basic blocks
23209467b48Spatrick // with call.
23309467b48Spatrick SequenceBBQuery::BlockListTy
queryCFG(Function & F,const BlockListTy & CallerBlocks)23409467b48Spatrick SequenceBBQuery::queryCFG(Function &F, const BlockListTy &CallerBlocks) {
23509467b48Spatrick
23609467b48Spatrick BlockFreqInfoTy BBFreqs;
23709467b48Spatrick VisitedBlocksInfoTy VisitedBlocks;
23809467b48Spatrick BackEdgesInfoTy BackEdgesInfo;
23909467b48Spatrick
24009467b48Spatrick PassBuilder PB;
24109467b48Spatrick FunctionAnalysisManager FAM;
24209467b48Spatrick PB.registerFunctionAnalyses(FAM);
24309467b48Spatrick
24409467b48Spatrick auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
24509467b48Spatrick
24609467b48Spatrick llvm::FindFunctionBackedges(F, BackEdgesInfo);
24709467b48Spatrick
24809467b48Spatrick for (const auto I : CallerBlocks)
24909467b48Spatrick BBFreqs.push_back({I, BFI.getBlockFreq(I).getFrequency()});
25009467b48Spatrick
25109467b48Spatrick llvm::sort(BBFreqs, [](decltype(BBFreqs)::const_reference Bbf,
25209467b48Spatrick decltype(BBFreqs)::const_reference Bbs) {
25309467b48Spatrick return Bbf.second > Bbs.second;
25409467b48Spatrick });
25509467b48Spatrick
25609467b48Spatrick ArrayRef<std::pair<const BasicBlock *, uint64_t>> HotBlocksRef(BBFreqs);
25709467b48Spatrick HotBlocksRef =
25809467b48Spatrick HotBlocksRef.drop_back(BBFreqs.size() - getHottestBlocks(BBFreqs.size()));
25909467b48Spatrick
26009467b48Spatrick BranchProbabilityInfo *BPI =
26109467b48Spatrick FAM.getCachedResult<BranchProbabilityAnalysis>(F);
26209467b48Spatrick
26309467b48Spatrick // visit NHotBlocks,
26409467b48Spatrick // traverse upwards to entry
26509467b48Spatrick // traverse downwards to end.
26609467b48Spatrick
26709467b48Spatrick for (auto I : HotBlocksRef) {
26809467b48Spatrick traverseToEntryBlock(I.first, CallerBlocks, BackEdgesInfo, BPI,
26909467b48Spatrick VisitedBlocks);
27009467b48Spatrick traverseToExitBlock(I.first, CallerBlocks, BackEdgesInfo, BPI,
27109467b48Spatrick VisitedBlocks);
27209467b48Spatrick }
27309467b48Spatrick
27409467b48Spatrick BlockListTy MinCallerBlocks;
27509467b48Spatrick for (auto &I : VisitedBlocks)
27609467b48Spatrick if (I.second.CallerBlock)
27709467b48Spatrick MinCallerBlocks.push_back(std::move(I.first));
27809467b48Spatrick
27909467b48Spatrick return rearrangeBB(F, MinCallerBlocks);
28009467b48Spatrick }
28109467b48Spatrick
operator ()(Function & F)28209467b48Spatrick SpeculateQuery::ResultTy SequenceBBQuery::operator()(Function &F) {
28309467b48Spatrick // reduce the number of lists!
28409467b48Spatrick DenseMap<StringRef, DenseSet<StringRef>> CallerAndCalles;
28509467b48Spatrick DenseSet<StringRef> Calles;
28609467b48Spatrick BlockListTy SequencedBlocks;
28709467b48Spatrick BlockListTy CallerBlocks;
28809467b48Spatrick
28909467b48Spatrick CallerBlocks = findBBwithCalls(F);
29009467b48Spatrick if (CallerBlocks.empty())
291*d415bd75Srobert return std::nullopt;
29209467b48Spatrick
29309467b48Spatrick if (isStraightLine(F))
29409467b48Spatrick SequencedBlocks = rearrangeBB(F, CallerBlocks);
29509467b48Spatrick else
29609467b48Spatrick SequencedBlocks = queryCFG(F, CallerBlocks);
29709467b48Spatrick
298*d415bd75Srobert for (const auto *BB : SequencedBlocks)
29909467b48Spatrick findCalles(BB, Calles);
30009467b48Spatrick
30109467b48Spatrick CallerAndCalles.insert({F.getName(), std::move(Calles)});
30209467b48Spatrick return CallerAndCalles;
30309467b48Spatrick }
30409467b48Spatrick
30509467b48Spatrick } // namespace orc
30609467b48Spatrick } // namespace llvm
307