xref: /openbsd-src/gnu/llvm/llvm/lib/ExecutionEngine/Orc/SpeculateAnalyses.cpp (revision d415bd752c734aee168c4ee86ff32e8cc249eb16)
109467b48Spatrick //===-- SpeculateAnalyses.cpp  --*- C++ -*-===//
209467b48Spatrick //
309467b48Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409467b48Spatrick // See https://llvm.org/LICENSE.txt for license information.
509467b48Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609467b48Spatrick //
709467b48Spatrick //===----------------------------------------------------------------------===//
809467b48Spatrick 
909467b48Spatrick #include "llvm/ExecutionEngine/Orc/SpeculateAnalyses.h"
1009467b48Spatrick #include "llvm/ADT/ArrayRef.h"
1109467b48Spatrick #include "llvm/ADT/DenseMap.h"
1209467b48Spatrick #include "llvm/ADT/STLExtras.h"
1309467b48Spatrick #include "llvm/ADT/SmallPtrSet.h"
1409467b48Spatrick #include "llvm/ADT/SmallVector.h"
1509467b48Spatrick #include "llvm/Analysis/BlockFrequencyInfo.h"
1609467b48Spatrick #include "llvm/Analysis/BranchProbabilityInfo.h"
1709467b48Spatrick #include "llvm/Analysis/CFG.h"
1809467b48Spatrick #include "llvm/IR/PassManager.h"
1909467b48Spatrick #include "llvm/Passes/PassBuilder.h"
2009467b48Spatrick #include "llvm/Support/ErrorHandling.h"
2109467b48Spatrick 
2209467b48Spatrick #include <algorithm>
2309467b48Spatrick 
2409467b48Spatrick namespace {
2509467b48Spatrick using namespace llvm;
findBBwithCalls(const Function & F,bool IndirectCall=false)2609467b48Spatrick SmallVector<const BasicBlock *, 8> findBBwithCalls(const Function &F,
2709467b48Spatrick                                                    bool IndirectCall = false) {
2809467b48Spatrick   SmallVector<const BasicBlock *, 8> BBs;
2909467b48Spatrick 
3009467b48Spatrick   auto findCallInst = [&IndirectCall](const Instruction &I) {
3109467b48Spatrick     if (auto Call = dyn_cast<CallBase>(&I))
3209467b48Spatrick       return Call->isIndirectCall() ? IndirectCall : true;
3309467b48Spatrick     else
3409467b48Spatrick       return false;
3509467b48Spatrick   };
3609467b48Spatrick   for (auto &BB : F)
3709467b48Spatrick     if (findCallInst(*BB.getTerminator()) ||
3809467b48Spatrick         llvm::any_of(BB.instructionsWithoutDebug(), findCallInst))
3909467b48Spatrick       BBs.emplace_back(&BB);
4009467b48Spatrick 
4109467b48Spatrick   return BBs;
4209467b48Spatrick }
4309467b48Spatrick } // namespace
4409467b48Spatrick 
// Implementations of Queries shouldn't need to lock resources such as the
// LLVMContext, because each argument (function) has its own, non-shared
// LLVMContext. However, if a Query keeps state, it must provide an appropriate
// locking scheme.
4809467b48Spatrick namespace llvm {
4909467b48Spatrick namespace orc {
5009467b48Spatrick 
5109467b48Spatrick // Collect direct calls only
findCalles(const BasicBlock * BB,DenseSet<StringRef> & CallesNames)5209467b48Spatrick void SpeculateQuery::findCalles(const BasicBlock *BB,
5309467b48Spatrick                                 DenseSet<StringRef> &CallesNames) {
5409467b48Spatrick   assert(BB != nullptr && "Traversing Null BB to find calls?");
5509467b48Spatrick 
5609467b48Spatrick   auto getCalledFunction = [&CallesNames](const CallBase *Call) {
5709467b48Spatrick     auto CalledValue = Call->getCalledOperand()->stripPointerCasts();
5809467b48Spatrick     if (auto DirectCall = dyn_cast<Function>(CalledValue))
5909467b48Spatrick       CallesNames.insert(DirectCall->getName());
6009467b48Spatrick   };
6109467b48Spatrick   for (auto &I : BB->instructionsWithoutDebug())
6209467b48Spatrick     if (auto CI = dyn_cast<CallInst>(&I))
6309467b48Spatrick       getCalledFunction(CI);
6409467b48Spatrick 
6509467b48Spatrick   if (auto II = dyn_cast<InvokeInst>(BB->getTerminator()))
6609467b48Spatrick     getCalledFunction(II);
6709467b48Spatrick }
6809467b48Spatrick 
isStraightLine(const Function & F)6909467b48Spatrick bool SpeculateQuery::isStraightLine(const Function &F) {
70*d415bd75Srobert   return llvm::all_of(F, [](const BasicBlock &BB) {
7109467b48Spatrick     return BB.getSingleSuccessor() != nullptr;
7209467b48Spatrick   });
7309467b48Spatrick }
7409467b48Spatrick 
7509467b48Spatrick // BlockFreqQuery Implementations
7609467b48Spatrick 
numBBToGet(size_t numBB)7709467b48Spatrick size_t BlockFreqQuery::numBBToGet(size_t numBB) {
7809467b48Spatrick   // small CFG
7909467b48Spatrick   if (numBB < 4)
8009467b48Spatrick     return numBB;
8109467b48Spatrick   // mid-size CFG
8209467b48Spatrick   else if (numBB < 20)
8309467b48Spatrick     return (numBB / 2);
8409467b48Spatrick   else
8509467b48Spatrick     return (numBB / 2) + (numBB / 4);
8609467b48Spatrick }
8709467b48Spatrick 
operator ()(Function & F)8809467b48Spatrick BlockFreqQuery::ResultTy BlockFreqQuery::operator()(Function &F) {
8909467b48Spatrick   DenseMap<StringRef, DenseSet<StringRef>> CallerAndCalles;
9009467b48Spatrick   DenseSet<StringRef> Calles;
9109467b48Spatrick   SmallVector<std::pair<const BasicBlock *, uint64_t>, 8> BBFreqs;
9209467b48Spatrick 
9309467b48Spatrick   PassBuilder PB;
9409467b48Spatrick   FunctionAnalysisManager FAM;
9509467b48Spatrick   PB.registerFunctionAnalyses(FAM);
9609467b48Spatrick 
9709467b48Spatrick   auto IBBs = findBBwithCalls(F);
9809467b48Spatrick 
9909467b48Spatrick   if (IBBs.empty())
100*d415bd75Srobert     return std::nullopt;
10109467b48Spatrick 
10209467b48Spatrick   auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
10309467b48Spatrick 
10409467b48Spatrick   for (const auto I : IBBs)
10509467b48Spatrick     BBFreqs.push_back({I, BFI.getBlockFreq(I).getFrequency()});
10609467b48Spatrick 
10709467b48Spatrick   assert(IBBs.size() == BBFreqs.size() && "BB Count Mismatch");
10809467b48Spatrick 
10973471bf0Spatrick   llvm::sort(BBFreqs, [](decltype(BBFreqs)::const_reference BBF,
11009467b48Spatrick                          decltype(BBFreqs)::const_reference BBS) {
11109467b48Spatrick     return BBF.second > BBS.second ? true : false;
11209467b48Spatrick   });
11309467b48Spatrick 
11409467b48Spatrick   // ignoring number of direct calls in a BB
11509467b48Spatrick   auto Topk = numBBToGet(BBFreqs.size());
11609467b48Spatrick 
11709467b48Spatrick   for (size_t i = 0; i < Topk; i++)
11809467b48Spatrick     findCalles(BBFreqs[i].first, Calles);
11909467b48Spatrick 
12009467b48Spatrick   assert(!Calles.empty() && "Running Analysis on Function with no calls?");
12109467b48Spatrick 
12209467b48Spatrick   CallerAndCalles.insert({F.getName(), std::move(Calles)});
12309467b48Spatrick 
12409467b48Spatrick   return CallerAndCalles;
12509467b48Spatrick }
12609467b48Spatrick 
12709467b48Spatrick // SequenceBBQuery Implementation
getHottestBlocks(std::size_t TotalBlocks)12809467b48Spatrick std::size_t SequenceBBQuery::getHottestBlocks(std::size_t TotalBlocks) {
12909467b48Spatrick   if (TotalBlocks == 1)
13009467b48Spatrick     return TotalBlocks;
13109467b48Spatrick   return TotalBlocks / 2;
13209467b48Spatrick }
13309467b48Spatrick 
13409467b48Spatrick // FIXME : find good implementation.
13509467b48Spatrick SequenceBBQuery::BlockListTy
rearrangeBB(const Function & F,const BlockListTy & BBList)13609467b48Spatrick SequenceBBQuery::rearrangeBB(const Function &F, const BlockListTy &BBList) {
13709467b48Spatrick   BlockListTy RearrangedBBSet;
13809467b48Spatrick 
139*d415bd75Srobert   for (auto &Block : F)
14009467b48Spatrick     if (llvm::is_contained(BBList, &Block))
14109467b48Spatrick       RearrangedBBSet.push_back(&Block);
14209467b48Spatrick 
14309467b48Spatrick   assert(RearrangedBBSet.size() == BBList.size() &&
14409467b48Spatrick          "BasicBlock missing while rearranging?");
14509467b48Spatrick   return RearrangedBBSet;
14609467b48Spatrick }
14709467b48Spatrick 
traverseToEntryBlock(const BasicBlock * AtBB,const BlockListTy & CallerBlocks,const BackEdgesInfoTy & BackEdgesInfo,const BranchProbabilityInfo * BPI,VisitedBlocksInfoTy & VisitedBlocks)14809467b48Spatrick void SequenceBBQuery::traverseToEntryBlock(const BasicBlock *AtBB,
14909467b48Spatrick                                            const BlockListTy &CallerBlocks,
15009467b48Spatrick                                            const BackEdgesInfoTy &BackEdgesInfo,
15109467b48Spatrick                                            const BranchProbabilityInfo *BPI,
15209467b48Spatrick                                            VisitedBlocksInfoTy &VisitedBlocks) {
15309467b48Spatrick   auto Itr = VisitedBlocks.find(AtBB);
15409467b48Spatrick   if (Itr != VisitedBlocks.end()) { // already visited.
15509467b48Spatrick     if (!Itr->second.Upward)
15609467b48Spatrick       return;
15709467b48Spatrick     Itr->second.Upward = false;
15809467b48Spatrick   } else {
15909467b48Spatrick     // Create hint for newly discoverd blocks.
16009467b48Spatrick     WalkDirection BlockHint;
16109467b48Spatrick     BlockHint.Upward = false;
16209467b48Spatrick     // FIXME: Expensive Check
16309467b48Spatrick     if (llvm::is_contained(CallerBlocks, AtBB))
16409467b48Spatrick       BlockHint.CallerBlock = true;
16509467b48Spatrick     VisitedBlocks.insert(std::make_pair(AtBB, BlockHint));
16609467b48Spatrick   }
16709467b48Spatrick 
16809467b48Spatrick   const_pred_iterator PIt = pred_begin(AtBB), EIt = pred_end(AtBB);
16909467b48Spatrick   // Move this check to top, when we have code setup to launch speculative
17009467b48Spatrick   // compiles for function in entry BB, this triggers the speculative compiles
17109467b48Spatrick   // before running the program.
17209467b48Spatrick   if (PIt == EIt) // No Preds.
17309467b48Spatrick     return;
17409467b48Spatrick 
17509467b48Spatrick   DenseSet<const BasicBlock *> PredSkipNodes;
17609467b48Spatrick 
17709467b48Spatrick   // Since we are checking for predecessor's backedges, this Block
17809467b48Spatrick   // occurs in second position.
17909467b48Spatrick   for (auto &I : BackEdgesInfo)
18009467b48Spatrick     if (I.second == AtBB)
18109467b48Spatrick       PredSkipNodes.insert(I.first);
18209467b48Spatrick 
18309467b48Spatrick   // Skip predecessors which source of back-edges.
18409467b48Spatrick   for (; PIt != EIt; ++PIt)
18509467b48Spatrick     // checking EdgeHotness is cheaper
18609467b48Spatrick     if (BPI->isEdgeHot(*PIt, AtBB) && !PredSkipNodes.count(*PIt))
18709467b48Spatrick       traverseToEntryBlock(*PIt, CallerBlocks, BackEdgesInfo, BPI,
18809467b48Spatrick                            VisitedBlocks);
18909467b48Spatrick }
19009467b48Spatrick 
traverseToExitBlock(const BasicBlock * AtBB,const BlockListTy & CallerBlocks,const BackEdgesInfoTy & BackEdgesInfo,const BranchProbabilityInfo * BPI,VisitedBlocksInfoTy & VisitedBlocks)19109467b48Spatrick void SequenceBBQuery::traverseToExitBlock(const BasicBlock *AtBB,
19209467b48Spatrick                                           const BlockListTy &CallerBlocks,
19309467b48Spatrick                                           const BackEdgesInfoTy &BackEdgesInfo,
19409467b48Spatrick                                           const BranchProbabilityInfo *BPI,
19509467b48Spatrick                                           VisitedBlocksInfoTy &VisitedBlocks) {
19609467b48Spatrick   auto Itr = VisitedBlocks.find(AtBB);
19709467b48Spatrick   if (Itr != VisitedBlocks.end()) { // already visited.
19809467b48Spatrick     if (!Itr->second.Downward)
19909467b48Spatrick       return;
20009467b48Spatrick     Itr->second.Downward = false;
20109467b48Spatrick   } else {
20209467b48Spatrick     // Create hint for newly discoverd blocks.
20309467b48Spatrick     WalkDirection BlockHint;
20409467b48Spatrick     BlockHint.Downward = false;
20509467b48Spatrick     // FIXME: Expensive Check
20609467b48Spatrick     if (llvm::is_contained(CallerBlocks, AtBB))
20709467b48Spatrick       BlockHint.CallerBlock = true;
20809467b48Spatrick     VisitedBlocks.insert(std::make_pair(AtBB, BlockHint));
20909467b48Spatrick   }
21009467b48Spatrick 
211097a140dSpatrick   const_succ_iterator PIt = succ_begin(AtBB), EIt = succ_end(AtBB);
21209467b48Spatrick   if (PIt == EIt) // No succs.
21309467b48Spatrick     return;
21409467b48Spatrick 
21509467b48Spatrick   // If there are hot edges, then compute SuccSkipNodes.
21609467b48Spatrick   DenseSet<const BasicBlock *> SuccSkipNodes;
21709467b48Spatrick 
21809467b48Spatrick   // Since we are checking for successor's backedges, this Block
21909467b48Spatrick   // occurs in first position.
22009467b48Spatrick   for (auto &I : BackEdgesInfo)
22109467b48Spatrick     if (I.first == AtBB)
22209467b48Spatrick       SuccSkipNodes.insert(I.second);
22309467b48Spatrick 
22409467b48Spatrick   for (; PIt != EIt; ++PIt)
22509467b48Spatrick     if (BPI->isEdgeHot(AtBB, *PIt) && !SuccSkipNodes.count(*PIt))
22609467b48Spatrick       traverseToExitBlock(*PIt, CallerBlocks, BackEdgesInfo, BPI,
22709467b48Spatrick                           VisitedBlocks);
22809467b48Spatrick }
22909467b48Spatrick 
23009467b48Spatrick // Get Block frequencies for blocks and take most frquently executed block,
23109467b48Spatrick // walk towards the entry block from those blocks and discover the basic blocks
23209467b48Spatrick // with call.
23309467b48Spatrick SequenceBBQuery::BlockListTy
queryCFG(Function & F,const BlockListTy & CallerBlocks)23409467b48Spatrick SequenceBBQuery::queryCFG(Function &F, const BlockListTy &CallerBlocks) {
23509467b48Spatrick 
23609467b48Spatrick   BlockFreqInfoTy BBFreqs;
23709467b48Spatrick   VisitedBlocksInfoTy VisitedBlocks;
23809467b48Spatrick   BackEdgesInfoTy BackEdgesInfo;
23909467b48Spatrick 
24009467b48Spatrick   PassBuilder PB;
24109467b48Spatrick   FunctionAnalysisManager FAM;
24209467b48Spatrick   PB.registerFunctionAnalyses(FAM);
24309467b48Spatrick 
24409467b48Spatrick   auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
24509467b48Spatrick 
24609467b48Spatrick   llvm::FindFunctionBackedges(F, BackEdgesInfo);
24709467b48Spatrick 
24809467b48Spatrick   for (const auto I : CallerBlocks)
24909467b48Spatrick     BBFreqs.push_back({I, BFI.getBlockFreq(I).getFrequency()});
25009467b48Spatrick 
25109467b48Spatrick   llvm::sort(BBFreqs, [](decltype(BBFreqs)::const_reference Bbf,
25209467b48Spatrick                          decltype(BBFreqs)::const_reference Bbs) {
25309467b48Spatrick     return Bbf.second > Bbs.second;
25409467b48Spatrick   });
25509467b48Spatrick 
25609467b48Spatrick   ArrayRef<std::pair<const BasicBlock *, uint64_t>> HotBlocksRef(BBFreqs);
25709467b48Spatrick   HotBlocksRef =
25809467b48Spatrick       HotBlocksRef.drop_back(BBFreqs.size() - getHottestBlocks(BBFreqs.size()));
25909467b48Spatrick 
26009467b48Spatrick   BranchProbabilityInfo *BPI =
26109467b48Spatrick       FAM.getCachedResult<BranchProbabilityAnalysis>(F);
26209467b48Spatrick 
26309467b48Spatrick   // visit NHotBlocks,
26409467b48Spatrick   // traverse upwards to entry
26509467b48Spatrick   // traverse downwards to end.
26609467b48Spatrick 
26709467b48Spatrick   for (auto I : HotBlocksRef) {
26809467b48Spatrick     traverseToEntryBlock(I.first, CallerBlocks, BackEdgesInfo, BPI,
26909467b48Spatrick                          VisitedBlocks);
27009467b48Spatrick     traverseToExitBlock(I.first, CallerBlocks, BackEdgesInfo, BPI,
27109467b48Spatrick                         VisitedBlocks);
27209467b48Spatrick   }
27309467b48Spatrick 
27409467b48Spatrick   BlockListTy MinCallerBlocks;
27509467b48Spatrick   for (auto &I : VisitedBlocks)
27609467b48Spatrick     if (I.second.CallerBlock)
27709467b48Spatrick       MinCallerBlocks.push_back(std::move(I.first));
27809467b48Spatrick 
27909467b48Spatrick   return rearrangeBB(F, MinCallerBlocks);
28009467b48Spatrick }
28109467b48Spatrick 
operator ()(Function & F)28209467b48Spatrick SpeculateQuery::ResultTy SequenceBBQuery::operator()(Function &F) {
28309467b48Spatrick   // reduce the number of lists!
28409467b48Spatrick   DenseMap<StringRef, DenseSet<StringRef>> CallerAndCalles;
28509467b48Spatrick   DenseSet<StringRef> Calles;
28609467b48Spatrick   BlockListTy SequencedBlocks;
28709467b48Spatrick   BlockListTy CallerBlocks;
28809467b48Spatrick 
28909467b48Spatrick   CallerBlocks = findBBwithCalls(F);
29009467b48Spatrick   if (CallerBlocks.empty())
291*d415bd75Srobert     return std::nullopt;
29209467b48Spatrick 
29309467b48Spatrick   if (isStraightLine(F))
29409467b48Spatrick     SequencedBlocks = rearrangeBB(F, CallerBlocks);
29509467b48Spatrick   else
29609467b48Spatrick     SequencedBlocks = queryCFG(F, CallerBlocks);
29709467b48Spatrick 
298*d415bd75Srobert   for (const auto *BB : SequencedBlocks)
29909467b48Spatrick     findCalles(BB, Calles);
30009467b48Spatrick 
30109467b48Spatrick   CallerAndCalles.insert({F.getName(), std::move(Calles)});
30209467b48Spatrick   return CallerAndCalles;
30309467b48Spatrick }
30409467b48Spatrick 
30509467b48Spatrick } // namespace orc
30609467b48Spatrick } // namespace llvm
307