xref: /llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp (revision f0ed31ce4b63a5530fd1de875c0d1467d4d2c6ea)
1 //===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 
10 #include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
11 #include "llvm/Analysis/CtxProfAnalysis.h"
12 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
13 #include "llvm/IR/Analysis.h"
14 #include "llvm/IR/DiagnosticInfo.h"
15 #include "llvm/IR/IRBuilder.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/IR/IntrinsicInst.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/IR/PassManager.h"
20 #include "llvm/ProfileData/InstrProf.h"
21 #include "llvm/Support/CommandLine.h"
22 #include <utility>
23 
24 using namespace llvm;
25 
26 #define DEBUG_TYPE "ctx-instr-lower"
27 
28 static cl::list<std::string> ContextRoots(
29     "profile-context-root", cl::Hidden,
30     cl::desc(
31         "A function name, assumed to be global, which will be treated as the "
32         "root of an interesting graph, which will be profiled independently "
33         "from other similar graphs."));
34 
35 bool PGOCtxProfLoweringPass::isCtxIRPGOInstrEnabled() {
36   return !ContextRoots.empty();
37 }
38 
// Names of the symbols we expect compiler-rt to provide. They live in their
// own namespace purely for readability at the use sites.
namespace CompilerRtAPINames {
// Functions.
static const char *StartCtx = "__llvm_ctx_profile_start_context";
static const char *ReleaseCtx = "__llvm_ctx_profile_release_context";
static const char *GetCtx = "__llvm_ctx_profile_get_context";
// Thread-local variables.
static const char *ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee";
static const char *CallsiteTLS = "__llvm_ctx_profile_callsite";
} // namespace CompilerRtAPINames
48 
49 namespace {
50 // The lowering logic and state.
51 class CtxInstrumentationLowerer final {
52   Module &M;
53   ModuleAnalysisManager &MAM;
54   Type *ContextNodeTy = nullptr;
55   Type *ContextRootTy = nullptr;
56 
57   DenseMap<const Function *, Constant *> ContextRootMap;
58   Function *StartCtx = nullptr;
59   Function *GetCtx = nullptr;
60   Function *ReleaseCtx = nullptr;
61   GlobalVariable *ExpectedCalleeTLS = nullptr;
62   GlobalVariable *CallsiteInfoTLS = nullptr;
63 
64 public:
65   CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
66   // return true if lowering happened (i.e. a change was made)
67   bool lowerFunction(Function &F);
68 };
69 
70 // llvm.instrprof.increment[.step] captures the total number of counters as one
71 // of its parameters, and llvm.instrprof.callsite captures the total number of
72 // callsites. Those values are the same for instances of those intrinsics in
73 // this function. Find the first instance of each and return them.
74 std::pair<uint32_t, uint32_t> getNumCountersAndCallsites(const Function &F) {
75   uint32_t NumCounters = 0;
76   uint32_t NumCallsites = 0;
77   for (const auto &BB : F) {
78     for (const auto &I : BB) {
79       if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
80         uint32_t V =
81             static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
82         assert((!NumCounters || V == NumCounters) &&
83                "expected all llvm.instrprof.increment[.step] intrinsics to "
84                "have the same total nr of counters parameter");
85         NumCounters = V;
86       } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
87         uint32_t V =
88             static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
89         assert((!NumCallsites || V == NumCallsites) &&
90                "expected all llvm.instrprof.callsite intrinsics to have the "
91                "same total nr of callsites parameter");
92         NumCallsites = V;
93       }
94 #if NDEBUG
95       if (NumCounters && NumCallsites)
96         return std::make_pair(NumCounters, NumCallsites);
97 #endif
98     }
99   }
100   return {NumCounters, NumCallsites};
101 }
102 } // namespace
103 
104 // set up tie-in with compiler-rt.
105 // NOTE!!!
106 // These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
107 CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
108                                                      ModuleAnalysisManager &MAM)
109     : M(M), MAM(MAM) {
110   auto *PointerTy = PointerType::get(M.getContext(), 0);
111   auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
112   auto *I32Ty = Type::getInt32Ty(M.getContext());
113   auto *I64Ty = Type::getInt64Ty(M.getContext());
114 
115   // The ContextRoot type
116   ContextRootTy =
117       StructType::get(M.getContext(), {
118                                           PointerTy,          /*FirstNode*/
119                                           PointerTy,          /*FirstMemBlock*/
120                                           PointerTy,          /*CurrentMem*/
121                                           SanitizerMutexType, /*Taken*/
122                                       });
123   // The Context header.
124   ContextNodeTy = StructType::get(M.getContext(), {
125                                                       I64Ty,     /*Guid*/
126                                                       PointerTy, /*Next*/
127                                                       I32Ty,     /*NumCounters*/
128                                                       I32Ty, /*NumCallsites*/
129                                                   });
130 
131   // Define a global for each entrypoint. We'll reuse the entrypoint's name as
132   // prefix. We assume the entrypoint names to be unique.
133   for (const auto &Fname : ContextRoots) {
134     if (const auto *F = M.getFunction(Fname)) {
135       if (F->isDeclaration())
136         continue;
137       auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
138       cast<GlobalVariable>(G)->setInitializer(
139           Constant::getNullValue(ContextRootTy));
140       ContextRootMap.insert(std::make_pair(F, G));
141       for (const auto &BB : *F)
142         for (const auto &I : BB)
143           if (const auto *CB = dyn_cast<CallBase>(&I))
144             if (CB->isMustTailCall()) {
145               M.getContext().emitError(
146                   "The function " + Fname +
147                   " was indicated as a context root, but it features musttail "
148                   "calls, which is not supported.");
149             }
150     }
151   }
152 
153   // Declare the functions we will call.
154   StartCtx = cast<Function>(
155       M.getOrInsertFunction(
156            CompilerRtAPINames::StartCtx,
157            FunctionType::get(PointerTy,
158                              {PointerTy, /*ContextRoot*/
159                               I64Ty, /*Guid*/ I32Ty,
160                               /*NumCounters*/ I32Ty /*NumCallsites*/},
161                              false))
162           .getCallee());
163   GetCtx = cast<Function>(
164       M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
165                             FunctionType::get(PointerTy,
166                                               {PointerTy, /*Callee*/
167                                                I64Ty,     /*Guid*/
168                                                I32Ty,     /*NumCounters*/
169                                                I32Ty},    /*NumCallsites*/
170                                               false))
171           .getCallee());
172   ReleaseCtx = cast<Function>(
173       M.getOrInsertFunction(CompilerRtAPINames::ReleaseCtx,
174                             FunctionType::get(Type::getVoidTy(M.getContext()),
175                                               {
176                                                   PointerTy, /*ContextRoot*/
177                                               },
178                                               false))
179           .getCallee());
180 
181   // Declare the TLSes we will need to use.
182   CallsiteInfoTLS =
183       new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
184                          nullptr, CompilerRtAPINames::CallsiteTLS);
185   CallsiteInfoTLS->setThreadLocal(true);
186   CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
187   ExpectedCalleeTLS =
188       new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
189                          nullptr, CompilerRtAPINames::ExpectedCalleeTLS);
190   ExpectedCalleeTLS->setThreadLocal(true);
191   ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
192 }
193 
194 PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
195                                               ModuleAnalysisManager &MAM) {
196   CtxInstrumentationLowerer Lowerer(M, MAM);
197   bool Changed = false;
198   for (auto &F : M)
199     Changed |= Lowerer.lowerFunction(F);
200   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
201 }
202 
203 bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
204   if (F.isDeclaration())
205     return false;
206   auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
207   auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
208 
209   Value *Guid = nullptr;
210   auto [NumCounters, NumCallsites] = getNumCountersAndCallsites(F);
211 
212   Value *Context = nullptr;
213   Value *RealContext = nullptr;
214 
215   StructType *ThisContextType = nullptr;
216   Value *TheRootContext = nullptr;
217   Value *ExpectedCalleeTLSAddr = nullptr;
218   Value *CallsiteInfoTLSAddr = nullptr;
219 
220   auto &Head = F.getEntryBlock();
221   for (auto &I : Head) {
222     // Find the increment intrinsic in the entry basic block.
223     if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
224       assert(Mark->getIndex()->isZero());
225 
226       IRBuilder<> Builder(Mark);
227 
228       Guid = Builder.getInt64(
229           AssignGUIDPass::getGUID(cast<Function>(*Mark->getNameValue())));
230       // The type of the context of this function is now knowable since we have
231       // NumCallsites and NumCounters. We delcare it here because it's more
232       // convenient - we have the Builder.
233       ThisContextType = StructType::get(
234           F.getContext(),
235           {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NumCounters),
236            ArrayType::get(Builder.getPtrTy(), NumCallsites)});
237       // Figure out which way we obtain the context object for this function -
238       // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
239       // former case, we also set TheRootContext since we need to release it
240       // at the end (plus it can be used to know if we have an entrypoint or a
241       // regular function)
242       auto Iter = ContextRootMap.find(&F);
243       if (Iter != ContextRootMap.end()) {
244         TheRootContext = Iter->second;
245         Context = Builder.CreateCall(
246             StartCtx, {TheRootContext, Guid, Builder.getInt32(NumCounters),
247                        Builder.getInt32(NumCallsites)});
248         ORE.emit(
249             [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
250       } else {
251         Context =
252             Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NumCounters),
253                                         Builder.getInt32(NumCallsites)});
254         ORE.emit([&] {
255           return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
256         });
257       }
258       // The context could be scratch.
259       auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
260       if (NumCallsites > 0) {
261         // Figure out which index of the TLS 2-element buffers to use.
262         // Scratch context => we use index == 1. Real contexts => index == 0.
263         auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
264         // The GEPs corresponding to that index, in the respective TLS.
265         ExpectedCalleeTLSAddr = Builder.CreateGEP(
266             PointerType::getUnqual(F.getContext()),
267             Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
268         CallsiteInfoTLSAddr = Builder.CreateGEP(
269             Builder.getInt32Ty(),
270             Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
271       }
272       // Because the context pointer may have LSB set (to indicate scratch),
273       // clear it for the value we use as base address for the counter vector.
274       // This way, if later we want to have "real" (not clobbered) buffers
275       // acting as scratch, the lowering (at least this part of it that deals
276       // with counters) stays the same.
277       RealContext = Builder.CreateIntToPtr(
278           Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
279           PointerType::getUnqual(F.getContext()));
280       I.eraseFromParent();
281       break;
282     }
283   }
284   if (!Context) {
285     ORE.emit([&] {
286       return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
287              << "Function doesn't have instrumentation, skipping";
288     });
289     return false;
290   }
291 
292   bool ContextWasReleased = false;
293   for (auto &BB : F) {
294     for (auto &I : llvm::make_early_inc_range(BB)) {
295       if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
296         IRBuilder<> Builder(Instr);
297         switch (Instr->getIntrinsicID()) {
298         case llvm::Intrinsic::instrprof_increment:
299         case llvm::Intrinsic::instrprof_increment_step: {
300           // Increments (or increment-steps) are just a typical load - increment
301           // - store in the RealContext.
302           auto *AsStep = cast<InstrProfIncrementInst>(Instr);
303           auto *GEP = Builder.CreateGEP(
304               ThisContextType, RealContext,
305               {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
306           Builder.CreateStore(
307               Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
308                                 AsStep->getStep()),
309               GEP);
310         } break;
311         case llvm::Intrinsic::instrprof_callsite:
312           // callsite lowering: write the called value in the expected callee
313           // TLS we treat the TLS as volatile because of signal handlers and to
314           // avoid these being moved away from the callsite they decorate.
315           auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
316           Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
317                               true);
318           // write the GEP of the slot in the sub-contexts portion of the
319           // context in TLS. Now, here, we use the actual Context value - as
320           // returned from compiler-rt - which may have the LSB set if the
321           // Context was scratch. Since the header of the context object and
322           // then the values are all 8-aligned (or, really, insofar as we care,
323           // they are even) - if the context is scratch (meaning, an odd value),
324           // so will the GEP. This is important because this is then visible to
325           // compiler-rt which will produce scratch contexts for callers that
326           // have a scratch context.
327           Builder.CreateStore(
328               Builder.CreateGEP(ThisContextType, Context,
329                                 {Builder.getInt32(0), Builder.getInt32(2),
330                                  CSIntrinsic->getIndex()}),
331               CallsiteInfoTLSAddr, true);
332           break;
333         }
334         I.eraseFromParent();
335       } else if (TheRootContext && isa<ReturnInst>(I)) {
336         // Remember to release the context if we are an entrypoint.
337         IRBuilder<> Builder(&I);
338         Builder.CreateCall(ReleaseCtx, {TheRootContext});
339         ContextWasReleased = true;
340       }
341     }
342   }
343   // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
344   // to disallow this, (so this then stays as an error), another is to detect
345   // that and then do a wrapper or disallow the tail call. This only affects
346   // instrumentation, when we want to detect the call graph.
347   if (TheRootContext && !ContextWasReleased)
348     F.getContext().emitError(
349         "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
350         "instructions above which to release the context: " +
351         F.getName());
352   return true;
353 }
354