xref: /freebsd-src/contrib/llvm-project/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp (revision 81ad626541db97eb356e2c1d4a20eb2a26a766ab)
1*81ad6265SDimitry Andric //===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
2*81ad6265SDimitry Andric //
3*81ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*81ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*81ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*81ad6265SDimitry Andric //
7*81ad6265SDimitry Andric //===----------------------------------------------------------------------===//
8*81ad6265SDimitry Andric //
// This pass identifies and eliminates redundant TLS loads if the related option
// is set. For an example, see the comment at the head of TLSVariableHoist.h.
11*81ad6265SDimitry Andric //
12*81ad6265SDimitry Andric //===----------------------------------------------------------------------===//
13*81ad6265SDimitry Andric 
14*81ad6265SDimitry Andric #include "llvm/ADT/SmallVector.h"
15*81ad6265SDimitry Andric #include "llvm/IR/BasicBlock.h"
16*81ad6265SDimitry Andric #include "llvm/IR/Dominators.h"
17*81ad6265SDimitry Andric #include "llvm/IR/Function.h"
18*81ad6265SDimitry Andric #include "llvm/IR/InstrTypes.h"
19*81ad6265SDimitry Andric #include "llvm/IR/Instruction.h"
20*81ad6265SDimitry Andric #include "llvm/IR/Instructions.h"
21*81ad6265SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
22*81ad6265SDimitry Andric #include "llvm/IR/Module.h"
23*81ad6265SDimitry Andric #include "llvm/IR/Value.h"
24*81ad6265SDimitry Andric #include "llvm/InitializePasses.h"
25*81ad6265SDimitry Andric #include "llvm/Pass.h"
26*81ad6265SDimitry Andric #include "llvm/Support/Casting.h"
27*81ad6265SDimitry Andric #include "llvm/Support/Debug.h"
28*81ad6265SDimitry Andric #include "llvm/Support/raw_ostream.h"
29*81ad6265SDimitry Andric #include "llvm/Transforms/Scalar.h"
30*81ad6265SDimitry Andric #include "llvm/Transforms/Scalar/TLSVariableHoist.h"
31*81ad6265SDimitry Andric #include <algorithm>
32*81ad6265SDimitry Andric #include <cassert>
33*81ad6265SDimitry Andric #include <cstdint>
34*81ad6265SDimitry Andric #include <iterator>
35*81ad6265SDimitry Andric #include <tuple>
36*81ad6265SDimitry Andric #include <utility>
37*81ad6265SDimitry Andric 
38*81ad6265SDimitry Andric using namespace llvm;
39*81ad6265SDimitry Andric using namespace tlshoist;
40*81ad6265SDimitry Andric 
41*81ad6265SDimitry Andric #define DEBUG_TYPE "tlshoist"
42*81ad6265SDimitry Andric 
43*81ad6265SDimitry Andric static cl::opt<bool> TLSLoadHoist(
44*81ad6265SDimitry Andric     "tls-load-hoist", cl::init(false), cl::Hidden,
45*81ad6265SDimitry Andric     cl::desc("hoist the TLS loads in PIC model to eliminate redundant "
46*81ad6265SDimitry Andric              "TLS address calculation."));
47*81ad6265SDimitry Andric 
48*81ad6265SDimitry Andric namespace {
49*81ad6265SDimitry Andric 
50*81ad6265SDimitry Andric /// The TLS Variable hoist pass.
51*81ad6265SDimitry Andric class TLSVariableHoistLegacyPass : public FunctionPass {
52*81ad6265SDimitry Andric public:
53*81ad6265SDimitry Andric   static char ID; // Pass identification, replacement for typeid
54*81ad6265SDimitry Andric 
55*81ad6265SDimitry Andric   TLSVariableHoistLegacyPass() : FunctionPass(ID) {
56*81ad6265SDimitry Andric     initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
57*81ad6265SDimitry Andric   }
58*81ad6265SDimitry Andric 
59*81ad6265SDimitry Andric   bool runOnFunction(Function &Fn) override;
60*81ad6265SDimitry Andric 
61*81ad6265SDimitry Andric   StringRef getPassName() const override { return "TLS Variable Hoist"; }
62*81ad6265SDimitry Andric 
63*81ad6265SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
64*81ad6265SDimitry Andric     AU.setPreservesCFG();
65*81ad6265SDimitry Andric     AU.addRequired<DominatorTreeWrapperPass>();
66*81ad6265SDimitry Andric     AU.addRequired<LoopInfoWrapperPass>();
67*81ad6265SDimitry Andric   }
68*81ad6265SDimitry Andric 
69*81ad6265SDimitry Andric private:
70*81ad6265SDimitry Andric   TLSVariableHoistPass Impl;
71*81ad6265SDimitry Andric };
72*81ad6265SDimitry Andric 
73*81ad6265SDimitry Andric } // end anonymous namespace
74*81ad6265SDimitry Andric 
75*81ad6265SDimitry Andric char TLSVariableHoistLegacyPass::ID = 0;
76*81ad6265SDimitry Andric 
77*81ad6265SDimitry Andric INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
78*81ad6265SDimitry Andric                       "TLS Variable Hoist", false, false)
79*81ad6265SDimitry Andric INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
80*81ad6265SDimitry Andric INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
81*81ad6265SDimitry Andric INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
82*81ad6265SDimitry Andric                     "TLS Variable Hoist", false, false)
83*81ad6265SDimitry Andric 
84*81ad6265SDimitry Andric FunctionPass *llvm::createTLSVariableHoistPass() {
85*81ad6265SDimitry Andric   return new TLSVariableHoistLegacyPass();
86*81ad6265SDimitry Andric }
87*81ad6265SDimitry Andric 
88*81ad6265SDimitry Andric /// Perform the TLS Variable Hoist optimization for the given function.
89*81ad6265SDimitry Andric bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) {
90*81ad6265SDimitry Andric   if (skipFunction(Fn))
91*81ad6265SDimitry Andric     return false;
92*81ad6265SDimitry Andric 
93*81ad6265SDimitry Andric   LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n");
94*81ad6265SDimitry Andric   LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
95*81ad6265SDimitry Andric 
96*81ad6265SDimitry Andric   bool MadeChange =
97*81ad6265SDimitry Andric       Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
98*81ad6265SDimitry Andric                    getAnalysis<LoopInfoWrapperPass>().getLoopInfo());
99*81ad6265SDimitry Andric 
100*81ad6265SDimitry Andric   if (MadeChange) {
101*81ad6265SDimitry Andric     LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: "
102*81ad6265SDimitry Andric                       << Fn.getName() << '\n');
103*81ad6265SDimitry Andric     LLVM_DEBUG(dbgs() << Fn);
104*81ad6265SDimitry Andric   }
105*81ad6265SDimitry Andric   LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n");
106*81ad6265SDimitry Andric 
107*81ad6265SDimitry Andric   return MadeChange;
108*81ad6265SDimitry Andric }
109*81ad6265SDimitry Andric 
110*81ad6265SDimitry Andric void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) {
111*81ad6265SDimitry Andric   // Skip all cast instructions. They are visited indirectly later on.
112*81ad6265SDimitry Andric   if (Inst->isCast())
113*81ad6265SDimitry Andric     return;
114*81ad6265SDimitry Andric 
115*81ad6265SDimitry Andric   // Scan all operands.
116*81ad6265SDimitry Andric   for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
117*81ad6265SDimitry Andric     auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx));
118*81ad6265SDimitry Andric     if (!GV || !GV->isThreadLocal())
119*81ad6265SDimitry Andric       continue;
120*81ad6265SDimitry Andric 
121*81ad6265SDimitry Andric     // Add Candidate to TLSCandMap (GV --> Candidate).
122*81ad6265SDimitry Andric     TLSCandMap[GV].addUser(Inst, Idx);
123*81ad6265SDimitry Andric   }
124*81ad6265SDimitry Andric }
125*81ad6265SDimitry Andric 
126*81ad6265SDimitry Andric void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) {
127*81ad6265SDimitry Andric   // First, quickly check if there is TLS Variable.
128*81ad6265SDimitry Andric   Module *M = Fn.getParent();
129*81ad6265SDimitry Andric 
130*81ad6265SDimitry Andric   bool HasTLS = llvm::any_of(
131*81ad6265SDimitry Andric       M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); });
132*81ad6265SDimitry Andric 
133*81ad6265SDimitry Andric   // If non, directly return.
134*81ad6265SDimitry Andric   if (!HasTLS)
135*81ad6265SDimitry Andric     return;
136*81ad6265SDimitry Andric 
137*81ad6265SDimitry Andric   TLSCandMap.clear();
138*81ad6265SDimitry Andric 
139*81ad6265SDimitry Andric   // Then, collect TLS Variable info.
140*81ad6265SDimitry Andric   for (BasicBlock &BB : Fn) {
141*81ad6265SDimitry Andric     // Ignore unreachable basic blocks.
142*81ad6265SDimitry Andric     if (!DT->isReachableFromEntry(&BB))
143*81ad6265SDimitry Andric       continue;
144*81ad6265SDimitry Andric 
145*81ad6265SDimitry Andric     for (Instruction &Inst : BB)
146*81ad6265SDimitry Andric       collectTLSCandidate(&Inst);
147*81ad6265SDimitry Andric   }
148*81ad6265SDimitry Andric }
149*81ad6265SDimitry Andric 
150*81ad6265SDimitry Andric static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) {
151*81ad6265SDimitry Andric   if (Cand.Users.size() != 1)
152*81ad6265SDimitry Andric     return false;
153*81ad6265SDimitry Andric 
154*81ad6265SDimitry Andric   BasicBlock *BB = Cand.Users[0].Inst->getParent();
155*81ad6265SDimitry Andric   if (LI->getLoopFor(BB))
156*81ad6265SDimitry Andric     return false;
157*81ad6265SDimitry Andric 
158*81ad6265SDimitry Andric   return true;
159*81ad6265SDimitry Andric }
160*81ad6265SDimitry Andric 
161*81ad6265SDimitry Andric Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB,
162*81ad6265SDimitry Andric                                                          Loop *L) {
163*81ad6265SDimitry Andric   assert(L && "Unexcepted Loop status!");
164*81ad6265SDimitry Andric 
165*81ad6265SDimitry Andric   // Get the outermost loop.
166*81ad6265SDimitry Andric   while (Loop *Parent = L->getParentLoop())
167*81ad6265SDimitry Andric     L = Parent;
168*81ad6265SDimitry Andric 
169*81ad6265SDimitry Andric   BasicBlock *PreHeader = L->getLoopPreheader();
170*81ad6265SDimitry Andric 
171*81ad6265SDimitry Andric   // There is unique predecessor outside the loop.
172*81ad6265SDimitry Andric   if (PreHeader)
173*81ad6265SDimitry Andric     return PreHeader->getTerminator();
174*81ad6265SDimitry Andric 
175*81ad6265SDimitry Andric   BasicBlock *Header = L->getHeader();
176*81ad6265SDimitry Andric   BasicBlock *Dom = Header;
177*81ad6265SDimitry Andric   for (BasicBlock *PredBB : predecessors(Header))
178*81ad6265SDimitry Andric     Dom = DT->findNearestCommonDominator(Dom, PredBB);
179*81ad6265SDimitry Andric 
180*81ad6265SDimitry Andric   assert(Dom && "Not find dominator BB!");
181*81ad6265SDimitry Andric   Instruction *Term = Dom->getTerminator();
182*81ad6265SDimitry Andric 
183*81ad6265SDimitry Andric   return Term;
184*81ad6265SDimitry Andric }
185*81ad6265SDimitry Andric 
186*81ad6265SDimitry Andric Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1,
187*81ad6265SDimitry Andric                                               Instruction *I2) {
188*81ad6265SDimitry Andric   if (!I1)
189*81ad6265SDimitry Andric     return I2;
190*81ad6265SDimitry Andric   if (DT->dominates(I1, I2))
191*81ad6265SDimitry Andric     return I1;
192*81ad6265SDimitry Andric   if (DT->dominates(I2, I1))
193*81ad6265SDimitry Andric     return I2;
194*81ad6265SDimitry Andric 
195*81ad6265SDimitry Andric   // If there is no dominance relation, use common dominator.
196*81ad6265SDimitry Andric   BasicBlock *DomBB =
197*81ad6265SDimitry Andric       DT->findNearestCommonDominator(I1->getParent(), I2->getParent());
198*81ad6265SDimitry Andric 
199*81ad6265SDimitry Andric   Instruction *Dom = DomBB->getTerminator();
200*81ad6265SDimitry Andric   assert(Dom && "Common dominator not found!");
201*81ad6265SDimitry Andric 
202*81ad6265SDimitry Andric   return Dom;
203*81ad6265SDimitry Andric }
204*81ad6265SDimitry Andric 
205*81ad6265SDimitry Andric BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn,
206*81ad6265SDimitry Andric                                                          GlobalVariable *GV,
207*81ad6265SDimitry Andric                                                          BasicBlock *&PosBB) {
208*81ad6265SDimitry Andric   tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
209*81ad6265SDimitry Andric 
210*81ad6265SDimitry Andric   // We should hoist the TLS use out of loop, so choose its nearest instruction
211*81ad6265SDimitry Andric   // which dominate the loop and the outside loops (if exist).
212*81ad6265SDimitry Andric   Instruction *LastPos = nullptr;
213*81ad6265SDimitry Andric   for (auto &User : Cand.Users) {
214*81ad6265SDimitry Andric     BasicBlock *BB = User.Inst->getParent();
215*81ad6265SDimitry Andric     Instruction *Pos = User.Inst;
216*81ad6265SDimitry Andric     if (Loop *L = LI->getLoopFor(BB)) {
217*81ad6265SDimitry Andric       Pos = getNearestLoopDomInst(BB, L);
218*81ad6265SDimitry Andric       assert(Pos && "Not find insert position out of loop!");
219*81ad6265SDimitry Andric     }
220*81ad6265SDimitry Andric     Pos = getDomInst(LastPos, Pos);
221*81ad6265SDimitry Andric     LastPos = Pos;
222*81ad6265SDimitry Andric   }
223*81ad6265SDimitry Andric 
224*81ad6265SDimitry Andric   assert(LastPos && "Unexpected insert position!");
225*81ad6265SDimitry Andric   BasicBlock *Parent = LastPos->getParent();
226*81ad6265SDimitry Andric   PosBB = Parent;
227*81ad6265SDimitry Andric   return LastPos->getIterator();
228*81ad6265SDimitry Andric }
229*81ad6265SDimitry Andric 
230*81ad6265SDimitry Andric // Generate a bitcast (no type change) to replace the uses of TLS Candidate.
231*81ad6265SDimitry Andric Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn,
232*81ad6265SDimitry Andric                                                   GlobalVariable *GV) {
233*81ad6265SDimitry Andric   BasicBlock *PosBB = &Fn.getEntryBlock();
234*81ad6265SDimitry Andric   BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB);
235*81ad6265SDimitry Andric   Type *Ty = GV->getType();
236*81ad6265SDimitry Andric   auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast");
237*81ad6265SDimitry Andric   PosBB->getInstList().insert(Iter, CastInst);
238*81ad6265SDimitry Andric   return CastInst;
239*81ad6265SDimitry Andric }
240*81ad6265SDimitry Andric 
241*81ad6265SDimitry Andric bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn,
242*81ad6265SDimitry Andric                                                   GlobalVariable *GV) {
243*81ad6265SDimitry Andric 
244*81ad6265SDimitry Andric   tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
245*81ad6265SDimitry Andric 
246*81ad6265SDimitry Andric   // If only used 1 time and not in loops, we no need to replace it.
247*81ad6265SDimitry Andric   if (oneUseOutsideLoop(Cand, LI))
248*81ad6265SDimitry Andric     return false;
249*81ad6265SDimitry Andric 
250*81ad6265SDimitry Andric   // Generate a bitcast (no type change)
251*81ad6265SDimitry Andric   auto *CastInst = genBitCastInst(Fn, GV);
252*81ad6265SDimitry Andric 
253*81ad6265SDimitry Andric   // to replace the uses of TLS Candidate
254*81ad6265SDimitry Andric   for (auto &User : Cand.Users)
255*81ad6265SDimitry Andric     User.Inst->setOperand(User.OpndIdx, CastInst);
256*81ad6265SDimitry Andric 
257*81ad6265SDimitry Andric   return true;
258*81ad6265SDimitry Andric }
259*81ad6265SDimitry Andric 
260*81ad6265SDimitry Andric bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) {
261*81ad6265SDimitry Andric   if (TLSCandMap.empty())
262*81ad6265SDimitry Andric     return false;
263*81ad6265SDimitry Andric 
264*81ad6265SDimitry Andric   bool Replaced = false;
265*81ad6265SDimitry Andric   for (auto &GV2Cand : TLSCandMap) {
266*81ad6265SDimitry Andric     GlobalVariable *GV = GV2Cand.first;
267*81ad6265SDimitry Andric     Replaced |= tryReplaceTLSCandidate(Fn, GV);
268*81ad6265SDimitry Andric   }
269*81ad6265SDimitry Andric 
270*81ad6265SDimitry Andric   return Replaced;
271*81ad6265SDimitry Andric }
272*81ad6265SDimitry Andric 
273*81ad6265SDimitry Andric /// Optimize expensive TLS variables in the given function.
274*81ad6265SDimitry Andric bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT,
275*81ad6265SDimitry Andric                                    LoopInfo &LI) {
276*81ad6265SDimitry Andric   if (Fn.hasOptNone())
277*81ad6265SDimitry Andric     return false;
278*81ad6265SDimitry Andric 
279*81ad6265SDimitry Andric   if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist"))
280*81ad6265SDimitry Andric     return false;
281*81ad6265SDimitry Andric 
282*81ad6265SDimitry Andric   this->LI = &LI;
283*81ad6265SDimitry Andric   this->DT = &DT;
284*81ad6265SDimitry Andric   assert(this->LI && this->DT && "Unexcepted requirement!");
285*81ad6265SDimitry Andric 
286*81ad6265SDimitry Andric   // Collect all TLS variable candidates.
287*81ad6265SDimitry Andric   collectTLSCandidates(Fn);
288*81ad6265SDimitry Andric 
289*81ad6265SDimitry Andric   bool MadeChange = tryReplaceTLSCandidates(Fn);
290*81ad6265SDimitry Andric 
291*81ad6265SDimitry Andric   return MadeChange;
292*81ad6265SDimitry Andric }
293*81ad6265SDimitry Andric 
294*81ad6265SDimitry Andric PreservedAnalyses TLSVariableHoistPass::run(Function &F,
295*81ad6265SDimitry Andric                                             FunctionAnalysisManager &AM) {
296*81ad6265SDimitry Andric 
297*81ad6265SDimitry Andric   auto &LI = AM.getResult<LoopAnalysis>(F);
298*81ad6265SDimitry Andric   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
299*81ad6265SDimitry Andric 
300*81ad6265SDimitry Andric   if (!runImpl(F, DT, LI))
301*81ad6265SDimitry Andric     return PreservedAnalyses::all();
302*81ad6265SDimitry Andric 
303*81ad6265SDimitry Andric   PreservedAnalyses PA;
304*81ad6265SDimitry Andric   PA.preserveSet<CFGAnalyses>();
305*81ad6265SDimitry Andric   return PA;
306*81ad6265SDimitry Andric }
307