1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25
26 using namespace clang;
27 using namespace CodeGen;
28
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)29 void CodeGenPGO::setFuncName(StringRef Name,
30 llvm::GlobalValue::LinkageTypes Linkage) {
31 StringRef RawFuncName = Name;
32
33 // Function names may be prefixed with a binary '1' to indicate
34 // that the backend should not modify the symbols due to any platform
35 // naming convention. Do not include that '1' in the PGO profile name.
36 if (RawFuncName[0] == '\1')
37 RawFuncName = RawFuncName.substr(1);
38
39 FuncName = RawFuncName;
40 if (llvm::GlobalValue::isLocalLinkage(Linkage)) {
41 // For local symbols, prepend the main file name to distinguish them.
42 // Do not include the full path in the file name since there's no guarantee
43 // that it will stay the same, e.g., if the files are checked out from
44 // version control in different locations.
45 if (CGM.getCodeGenOpts().MainFileName.empty())
46 FuncName = FuncName.insert(0, "<unknown>:");
47 else
48 FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":");
49 }
50
51 // If we're generating a profile, create a variable for the name.
52 if (CGM.getCodeGenOpts().ProfileInstrGenerate)
53 createFuncNameVar(Linkage);
54 }
55
setFuncName(llvm::Function * Fn)56 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
57 setFuncName(Fn->getName(), Fn->getLinkage());
58 }
59
createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage)60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) {
61 // Usually, we want to match the function's linkage, but
62 // available_externally and extern_weak both have the wrong semantics.
63 if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)
64 Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;
65 else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)
66 Linkage = llvm::GlobalValue::LinkOnceODRLinkage;
67
68 auto *Value =
69 llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false);
70 FuncNameVar =
71 new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage,
72 Value, "__llvm_profile_name_" + FuncName);
73
74 // Hide the symbol so that we correctly get a copy for each executable.
75 if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage()))
76 FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility);
77 }
78
79 namespace {
80 /// \brief Stable hasher for PGO region counters.
81 ///
82 /// PGOHash produces a stable hash of a given function's control flow.
83 ///
84 /// Changing the output of this hash will invalidate all previously generated
85 /// profiles -- i.e., don't do it.
86 ///
87 /// \note When this hash does eventually change (years?), we still need to
88 /// support old hashes. We'll need to pull in the version number from the
89 /// profile data format and use the matching hash function.
90 class PGOHash {
91 uint64_t Working;
92 unsigned Count;
93 llvm::MD5 MD5;
94
95 static const int NumBitsPerType = 6;
96 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
97 static const unsigned TooBig = 1u << NumBitsPerType;
98
99 public:
100 /// \brief Hash values for AST nodes.
101 ///
102 /// Distinct values for AST nodes that have region counters attached.
103 ///
104 /// These values must be stable. All new members must be added at the end,
105 /// and no members should be removed. Changing the enumeration value for an
106 /// AST node will affect the hash of every function that contains that node.
107 enum HashType : unsigned char {
108 None = 0,
109 LabelStmt = 1,
110 WhileStmt,
111 DoStmt,
112 ForStmt,
113 CXXForRangeStmt,
114 ObjCForCollectionStmt,
115 SwitchStmt,
116 CaseStmt,
117 DefaultStmt,
118 IfStmt,
119 CXXTryStmt,
120 CXXCatchStmt,
121 ConditionalOperator,
122 BinaryOperatorLAnd,
123 BinaryOperatorLOr,
124 BinaryConditionalOperator,
125
126 // Keep this last. It's for the static assert that follows.
127 LastHashType
128 };
129 static_assert(LastHashType <= TooBig, "Too many types in HashType");
130
131 // TODO: When this format changes, take in a version number here, and use the
132 // old hash calculation for file formats that used the old hash.
PGOHash()133 PGOHash() : Working(0), Count(0) {}
134 void combine(HashType Type);
135 uint64_t finalize();
136 };
137 const int PGOHash::NumBitsPerType;
138 const unsigned PGOHash::NumTypesPerWord;
139 const unsigned PGOHash::TooBig;
140
141 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
142 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
143 /// The next counter value to assign.
144 unsigned NextCounter;
145 /// The function hash.
146 PGOHash Hash;
147 /// The map of statements to counters.
148 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
149
MapRegionCounters__anon7f44950c0111::MapRegionCounters150 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
151 : NextCounter(0), CounterMap(CounterMap) {}
152
153 // Blocks and lambdas are handled as separate functions, so we need not
154 // traverse them in the parent context.
TraverseBlockExpr__anon7f44950c0111::MapRegionCounters155 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaBody__anon7f44950c0111::MapRegionCounters156 bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
TraverseCapturedStmt__anon7f44950c0111::MapRegionCounters157 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
158
VisitDecl__anon7f44950c0111::MapRegionCounters159 bool VisitDecl(const Decl *D) {
160 switch (D->getKind()) {
161 default:
162 break;
163 case Decl::Function:
164 case Decl::CXXMethod:
165 case Decl::CXXConstructor:
166 case Decl::CXXDestructor:
167 case Decl::CXXConversion:
168 case Decl::ObjCMethod:
169 case Decl::Block:
170 case Decl::Captured:
171 CounterMap[D->getBody()] = NextCounter++;
172 break;
173 }
174 return true;
175 }
176
VisitStmt__anon7f44950c0111::MapRegionCounters177 bool VisitStmt(const Stmt *S) {
178 auto Type = getHashType(S);
179 if (Type == PGOHash::None)
180 return true;
181
182 CounterMap[S] = NextCounter++;
183 Hash.combine(Type);
184 return true;
185 }
getHashType__anon7f44950c0111::MapRegionCounters186 PGOHash::HashType getHashType(const Stmt *S) {
187 switch (S->getStmtClass()) {
188 default:
189 break;
190 case Stmt::LabelStmtClass:
191 return PGOHash::LabelStmt;
192 case Stmt::WhileStmtClass:
193 return PGOHash::WhileStmt;
194 case Stmt::DoStmtClass:
195 return PGOHash::DoStmt;
196 case Stmt::ForStmtClass:
197 return PGOHash::ForStmt;
198 case Stmt::CXXForRangeStmtClass:
199 return PGOHash::CXXForRangeStmt;
200 case Stmt::ObjCForCollectionStmtClass:
201 return PGOHash::ObjCForCollectionStmt;
202 case Stmt::SwitchStmtClass:
203 return PGOHash::SwitchStmt;
204 case Stmt::CaseStmtClass:
205 return PGOHash::CaseStmt;
206 case Stmt::DefaultStmtClass:
207 return PGOHash::DefaultStmt;
208 case Stmt::IfStmtClass:
209 return PGOHash::IfStmt;
210 case Stmt::CXXTryStmtClass:
211 return PGOHash::CXXTryStmt;
212 case Stmt::CXXCatchStmtClass:
213 return PGOHash::CXXCatchStmt;
214 case Stmt::ConditionalOperatorClass:
215 return PGOHash::ConditionalOperator;
216 case Stmt::BinaryConditionalOperatorClass:
217 return PGOHash::BinaryConditionalOperator;
218 case Stmt::BinaryOperatorClass: {
219 const BinaryOperator *BO = cast<BinaryOperator>(S);
220 if (BO->getOpcode() == BO_LAnd)
221 return PGOHash::BinaryOperatorLAnd;
222 if (BO->getOpcode() == BO_LOr)
223 return PGOHash::BinaryOperatorLOr;
224 break;
225 }
226 }
227 return PGOHash::None;
228 }
229 };
230
231 /// A StmtVisitor that propagates the raw counts through the AST and
232 /// records the count at statements where the value may change.
233 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
234 /// PGO state.
235 CodeGenPGO &PGO;
236
237 /// A flag that is set when the current count should be recorded on the
238 /// next statement, such as at the exit of a loop.
239 bool RecordNextStmtCount;
240
241 /// The map of statements to count values.
242 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
243
244 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
245 struct BreakContinue {
246 uint64_t BreakCount;
247 uint64_t ContinueCount;
BreakContinue__anon7f44950c0111::ComputeRegionCounts::BreakContinue248 BreakContinue() : BreakCount(0), ContinueCount(0) {}
249 };
250 SmallVector<BreakContinue, 8> BreakContinueStack;
251
ComputeRegionCounts__anon7f44950c0111::ComputeRegionCounts252 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
253 CodeGenPGO &PGO)
254 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
255
RecordStmtCount__anon7f44950c0111::ComputeRegionCounts256 void RecordStmtCount(const Stmt *S) {
257 if (RecordNextStmtCount) {
258 CountMap[S] = PGO.getCurrentRegionCount();
259 RecordNextStmtCount = false;
260 }
261 }
262
VisitStmt__anon7f44950c0111::ComputeRegionCounts263 void VisitStmt(const Stmt *S) {
264 RecordStmtCount(S);
265 for (Stmt::const_child_range I = S->children(); I; ++I) {
266 if (*I)
267 this->Visit(*I);
268 }
269 }
270
VisitFunctionDecl__anon7f44950c0111::ComputeRegionCounts271 void VisitFunctionDecl(const FunctionDecl *D) {
272 // Counter tracks entry to the function body.
273 RegionCounter Cnt(PGO, D->getBody());
274 Cnt.beginRegion();
275 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
276 Visit(D->getBody());
277 }
278
279 // Skip lambda expressions. We visit these as FunctionDecls when we're
280 // generating them and aren't interested in the body when generating a
281 // parent context.
VisitLambdaExpr__anon7f44950c0111::ComputeRegionCounts282 void VisitLambdaExpr(const LambdaExpr *LE) {}
283
VisitCapturedDecl__anon7f44950c0111::ComputeRegionCounts284 void VisitCapturedDecl(const CapturedDecl *D) {
285 // Counter tracks entry to the capture body.
286 RegionCounter Cnt(PGO, D->getBody());
287 Cnt.beginRegion();
288 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
289 Visit(D->getBody());
290 }
291
VisitObjCMethodDecl__anon7f44950c0111::ComputeRegionCounts292 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
293 // Counter tracks entry to the method body.
294 RegionCounter Cnt(PGO, D->getBody());
295 Cnt.beginRegion();
296 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
297 Visit(D->getBody());
298 }
299
VisitBlockDecl__anon7f44950c0111::ComputeRegionCounts300 void VisitBlockDecl(const BlockDecl *D) {
301 // Counter tracks entry to the block body.
302 RegionCounter Cnt(PGO, D->getBody());
303 Cnt.beginRegion();
304 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
305 Visit(D->getBody());
306 }
307
VisitReturnStmt__anon7f44950c0111::ComputeRegionCounts308 void VisitReturnStmt(const ReturnStmt *S) {
309 RecordStmtCount(S);
310 if (S->getRetValue())
311 Visit(S->getRetValue());
312 PGO.setCurrentRegionUnreachable();
313 RecordNextStmtCount = true;
314 }
315
VisitGotoStmt__anon7f44950c0111::ComputeRegionCounts316 void VisitGotoStmt(const GotoStmt *S) {
317 RecordStmtCount(S);
318 PGO.setCurrentRegionUnreachable();
319 RecordNextStmtCount = true;
320 }
321
VisitLabelStmt__anon7f44950c0111::ComputeRegionCounts322 void VisitLabelStmt(const LabelStmt *S) {
323 RecordNextStmtCount = false;
324 // Counter tracks the block following the label.
325 RegionCounter Cnt(PGO, S);
326 Cnt.beginRegion();
327 CountMap[S] = PGO.getCurrentRegionCount();
328 Visit(S->getSubStmt());
329 }
330
VisitBreakStmt__anon7f44950c0111::ComputeRegionCounts331 void VisitBreakStmt(const BreakStmt *S) {
332 RecordStmtCount(S);
333 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
334 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
335 PGO.setCurrentRegionUnreachable();
336 RecordNextStmtCount = true;
337 }
338
VisitContinueStmt__anon7f44950c0111::ComputeRegionCounts339 void VisitContinueStmt(const ContinueStmt *S) {
340 RecordStmtCount(S);
341 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
342 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
343 PGO.setCurrentRegionUnreachable();
344 RecordNextStmtCount = true;
345 }
346
VisitWhileStmt__anon7f44950c0111::ComputeRegionCounts347 void VisitWhileStmt(const WhileStmt *S) {
348 RecordStmtCount(S);
349 // Counter tracks the body of the loop.
350 RegionCounter Cnt(PGO, S);
351 BreakContinueStack.push_back(BreakContinue());
352 // Visit the body region first so the break/continue adjustments can be
353 // included when visiting the condition.
354 Cnt.beginRegion();
355 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
356 Visit(S->getBody());
357 Cnt.adjustForControlFlow();
358
359 // ...then go back and propagate counts through the condition. The count
360 // at the start of the condition is the sum of the incoming edges,
361 // the backedge from the end of the loop body, and the edges from
362 // continue statements.
363 BreakContinue BC = BreakContinueStack.pop_back_val();
364 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
365 Cnt.getAdjustedCount() + BC.ContinueCount);
366 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
367 Visit(S->getCond());
368 Cnt.adjustForControlFlow();
369 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
370 RecordNextStmtCount = true;
371 }
372
VisitDoStmt__anon7f44950c0111::ComputeRegionCounts373 void VisitDoStmt(const DoStmt *S) {
374 RecordStmtCount(S);
375 // Counter tracks the body of the loop.
376 RegionCounter Cnt(PGO, S);
377 BreakContinueStack.push_back(BreakContinue());
378 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
379 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
380 Visit(S->getBody());
381 Cnt.adjustForControlFlow();
382
383 BreakContinue BC = BreakContinueStack.pop_back_val();
384 // The count at the start of the condition is equal to the count at the
385 // end of the body. The adjusted count does not include either the
386 // fall-through count coming into the loop or the continue count, so add
387 // both of those separately. This is coincidentally the same equation as
388 // with while loops but for different reasons.
389 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
390 Cnt.getAdjustedCount() + BC.ContinueCount);
391 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
392 Visit(S->getCond());
393 Cnt.adjustForControlFlow();
394 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
395 RecordNextStmtCount = true;
396 }
397
VisitForStmt__anon7f44950c0111::ComputeRegionCounts398 void VisitForStmt(const ForStmt *S) {
399 RecordStmtCount(S);
400 if (S->getInit())
401 Visit(S->getInit());
402 // Counter tracks the body of the loop.
403 RegionCounter Cnt(PGO, S);
404 BreakContinueStack.push_back(BreakContinue());
405 // Visit the body region first. (This is basically the same as a while
406 // loop; see further comments in VisitWhileStmt.)
407 Cnt.beginRegion();
408 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
409 Visit(S->getBody());
410 Cnt.adjustForControlFlow();
411
412 // The increment is essentially part of the body but it needs to include
413 // the count for all the continue statements.
414 if (S->getInc()) {
415 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
416 BreakContinueStack.back().ContinueCount);
417 CountMap[S->getInc()] = PGO.getCurrentRegionCount();
418 Visit(S->getInc());
419 Cnt.adjustForControlFlow();
420 }
421
422 BreakContinue BC = BreakContinueStack.pop_back_val();
423
424 // ...then go back and propagate counts through the condition.
425 if (S->getCond()) {
426 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
427 Cnt.getAdjustedCount() +
428 BC.ContinueCount);
429 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
430 Visit(S->getCond());
431 Cnt.adjustForControlFlow();
432 }
433 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
434 RecordNextStmtCount = true;
435 }
436
VisitCXXForRangeStmt__anon7f44950c0111::ComputeRegionCounts437 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
438 RecordStmtCount(S);
439 Visit(S->getRangeStmt());
440 Visit(S->getBeginEndStmt());
441 // Counter tracks the body of the loop.
442 RegionCounter Cnt(PGO, S);
443 BreakContinueStack.push_back(BreakContinue());
444 // Visit the body region first. (This is basically the same as a while
445 // loop; see further comments in VisitWhileStmt.)
446 Cnt.beginRegion();
447 CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
448 Visit(S->getLoopVarStmt());
449 Visit(S->getBody());
450 Cnt.adjustForControlFlow();
451
452 // The increment is essentially part of the body but it needs to include
453 // the count for all the continue statements.
454 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
455 BreakContinueStack.back().ContinueCount);
456 CountMap[S->getInc()] = PGO.getCurrentRegionCount();
457 Visit(S->getInc());
458 Cnt.adjustForControlFlow();
459
460 BreakContinue BC = BreakContinueStack.pop_back_val();
461
462 // ...then go back and propagate counts through the condition.
463 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
464 Cnt.getAdjustedCount() +
465 BC.ContinueCount);
466 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
467 Visit(S->getCond());
468 Cnt.adjustForControlFlow();
469 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
470 RecordNextStmtCount = true;
471 }
472
VisitObjCForCollectionStmt__anon7f44950c0111::ComputeRegionCounts473 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
474 RecordStmtCount(S);
475 Visit(S->getElement());
476 // Counter tracks the body of the loop.
477 RegionCounter Cnt(PGO, S);
478 BreakContinueStack.push_back(BreakContinue());
479 Cnt.beginRegion();
480 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
481 Visit(S->getBody());
482 BreakContinue BC = BreakContinueStack.pop_back_val();
483 Cnt.adjustForControlFlow();
484 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
485 RecordNextStmtCount = true;
486 }
487
VisitSwitchStmt__anon7f44950c0111::ComputeRegionCounts488 void VisitSwitchStmt(const SwitchStmt *S) {
489 RecordStmtCount(S);
490 Visit(S->getCond());
491 PGO.setCurrentRegionUnreachable();
492 BreakContinueStack.push_back(BreakContinue());
493 Visit(S->getBody());
494 // If the switch is inside a loop, add the continue counts.
495 BreakContinue BC = BreakContinueStack.pop_back_val();
496 if (!BreakContinueStack.empty())
497 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
498 // Counter tracks the exit block of the switch.
499 RegionCounter ExitCnt(PGO, S);
500 ExitCnt.beginRegion();
501 RecordNextStmtCount = true;
502 }
503
VisitCaseStmt__anon7f44950c0111::ComputeRegionCounts504 void VisitCaseStmt(const CaseStmt *S) {
505 RecordNextStmtCount = false;
506 // Counter for this particular case. This counts only jumps from the
507 // switch header and does not include fallthrough from the case before
508 // this one.
509 RegionCounter Cnt(PGO, S);
510 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
511 CountMap[S] = Cnt.getCount();
512 RecordNextStmtCount = true;
513 Visit(S->getSubStmt());
514 }
515
VisitDefaultStmt__anon7f44950c0111::ComputeRegionCounts516 void VisitDefaultStmt(const DefaultStmt *S) {
517 RecordNextStmtCount = false;
518 // Counter for this default case. This does not include fallthrough from
519 // the previous case.
520 RegionCounter Cnt(PGO, S);
521 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
522 CountMap[S] = Cnt.getCount();
523 RecordNextStmtCount = true;
524 Visit(S->getSubStmt());
525 }
526
VisitIfStmt__anon7f44950c0111::ComputeRegionCounts527 void VisitIfStmt(const IfStmt *S) {
528 RecordStmtCount(S);
529 // Counter tracks the "then" part of an if statement. The count for
530 // the "else" part, if it exists, will be calculated from this counter.
531 RegionCounter Cnt(PGO, S);
532 Visit(S->getCond());
533
534 Cnt.beginRegion();
535 CountMap[S->getThen()] = PGO.getCurrentRegionCount();
536 Visit(S->getThen());
537 Cnt.adjustForControlFlow();
538
539 if (S->getElse()) {
540 Cnt.beginElseRegion();
541 CountMap[S->getElse()] = PGO.getCurrentRegionCount();
542 Visit(S->getElse());
543 Cnt.adjustForControlFlow();
544 }
545 Cnt.applyAdjustmentsToRegion(0);
546 RecordNextStmtCount = true;
547 }
548
VisitCXXTryStmt__anon7f44950c0111::ComputeRegionCounts549 void VisitCXXTryStmt(const CXXTryStmt *S) {
550 RecordStmtCount(S);
551 Visit(S->getTryBlock());
552 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
553 Visit(S->getHandler(I));
554 // Counter tracks the continuation block of the try statement.
555 RegionCounter Cnt(PGO, S);
556 Cnt.beginRegion();
557 RecordNextStmtCount = true;
558 }
559
VisitCXXCatchStmt__anon7f44950c0111::ComputeRegionCounts560 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
561 RecordNextStmtCount = false;
562 // Counter tracks the catch statement's handler block.
563 RegionCounter Cnt(PGO, S);
564 Cnt.beginRegion();
565 CountMap[S] = PGO.getCurrentRegionCount();
566 Visit(S->getHandlerBlock());
567 }
568
VisitAbstractConditionalOperator__anon7f44950c0111::ComputeRegionCounts569 void VisitAbstractConditionalOperator(
570 const AbstractConditionalOperator *E) {
571 RecordStmtCount(E);
572 // Counter tracks the "true" part of a conditional operator. The
573 // count in the "false" part will be calculated from this counter.
574 RegionCounter Cnt(PGO, E);
575 Visit(E->getCond());
576
577 Cnt.beginRegion();
578 CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
579 Visit(E->getTrueExpr());
580 Cnt.adjustForControlFlow();
581
582 Cnt.beginElseRegion();
583 CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
584 Visit(E->getFalseExpr());
585 Cnt.adjustForControlFlow();
586
587 Cnt.applyAdjustmentsToRegion(0);
588 RecordNextStmtCount = true;
589 }
590
VisitBinLAnd__anon7f44950c0111::ComputeRegionCounts591 void VisitBinLAnd(const BinaryOperator *E) {
592 RecordStmtCount(E);
593 // Counter tracks the right hand side of a logical and operator.
594 RegionCounter Cnt(PGO, E);
595 Visit(E->getLHS());
596 Cnt.beginRegion();
597 CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
598 Visit(E->getRHS());
599 Cnt.adjustForControlFlow();
600 Cnt.applyAdjustmentsToRegion(0);
601 RecordNextStmtCount = true;
602 }
603
VisitBinLOr__anon7f44950c0111::ComputeRegionCounts604 void VisitBinLOr(const BinaryOperator *E) {
605 RecordStmtCount(E);
606 // Counter tracks the right hand side of a logical or operator.
607 RegionCounter Cnt(PGO, E);
608 Visit(E->getLHS());
609 Cnt.beginRegion();
610 CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
611 Visit(E->getRHS());
612 Cnt.adjustForControlFlow();
613 Cnt.applyAdjustmentsToRegion(0);
614 RecordNextStmtCount = true;
615 }
616 };
617 }
618
combine(HashType Type)619 void PGOHash::combine(HashType Type) {
620 // Check that we never combine 0 and only have six bits.
621 assert(Type && "Hash is invalid: unexpected type 0");
622 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
623
624 // Pass through MD5 if enough work has built up.
625 if (Count && Count % NumTypesPerWord == 0) {
626 using namespace llvm::support;
627 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
628 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
629 Working = 0;
630 }
631
632 // Accumulate the current type.
633 ++Count;
634 Working = Working << NumBitsPerType | Type;
635 }
636
finalize()637 uint64_t PGOHash::finalize() {
638 // Use Working as the hash directly if we never used MD5.
639 if (Count <= NumTypesPerWord)
640 // No need to byte swap here, since none of the math was endian-dependent.
641 // This number will be byte-swapped as required on endianness transitions,
642 // so we will see the same value on the other side.
643 return Working;
644
645 // Check for remaining work in Working.
646 if (Working)
647 MD5.update(Working);
648
649 // Finalize the MD5 and return the hash.
650 llvm::MD5::MD5Result Result;
651 MD5.final(Result);
652 using namespace llvm::support;
653 return endian::read<uint64_t, little, unaligned>(Result);
654 }
655
checkGlobalDecl(GlobalDecl GD)656 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
657 // Make sure we only emit coverage mapping for one constructor/destructor.
658 // Clang emits several functions for the constructor and the destructor of
659 // a class. Every function is instrumented, but we only want to provide
660 // coverage for one of them. Because of that we only emit the coverage mapping
661 // for the base constructor/destructor.
662 if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
663 GD.getCtorType() != Ctor_Base) ||
664 (isa<CXXDestructorDecl>(GD.getDecl()) &&
665 GD.getDtorType() != Dtor_Base)) {
666 SkipCoverageMapping = true;
667 }
668 }
669
assignRegionCounters(const Decl * D,llvm::Function * Fn)670 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
671 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
672 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
673 if (!InstrumentRegions && !PGOReader)
674 return;
675 if (D->isImplicit())
676 return;
677 CGM.ClearUnusedCoverageMapping(D);
678 setFuncName(Fn);
679
680 mapRegionCounters(D);
681 if (CGM.getCodeGenOpts().CoverageMapping)
682 emitCounterRegionMapping(D);
683 if (PGOReader) {
684 SourceManager &SM = CGM.getContext().getSourceManager();
685 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
686 computeRegionCounts(D);
687 applyFunctionAttributes(PGOReader, Fn);
688 }
689 }
690
mapRegionCounters(const Decl * D)691 void CodeGenPGO::mapRegionCounters(const Decl *D) {
692 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
693 MapRegionCounters Walker(*RegionCounterMap);
694 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
695 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
696 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
697 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
698 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
699 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
700 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
701 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
702 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
703 NumRegionCounters = Walker.NextCounter;
704 FunctionHash = Walker.Hash.finalize();
705 }
706
emitCounterRegionMapping(const Decl * D)707 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
708 if (SkipCoverageMapping)
709 return;
710 // Don't map the functions inside the system headers
711 auto Loc = D->getBody()->getLocStart();
712 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
713 return;
714
715 std::string CoverageMapping;
716 llvm::raw_string_ostream OS(CoverageMapping);
717 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
718 CGM.getContext().getSourceManager(),
719 CGM.getLangOpts(), RegionCounterMap.get());
720 MappingGen.emitCounterMapping(D, OS);
721 OS.flush();
722
723 if (CoverageMapping.empty())
724 return;
725
726 CGM.getCoverageMapping()->addFunctionMappingRecord(
727 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
728 }
729
730 void
emitEmptyCounterMapping(const Decl * D,StringRef FuncName,llvm::GlobalValue::LinkageTypes Linkage)731 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName,
732 llvm::GlobalValue::LinkageTypes Linkage) {
733 if (SkipCoverageMapping)
734 return;
735 setFuncName(FuncName, Linkage);
736
737 // Don't map the functions inside the system headers
738 auto Loc = D->getBody()->getLocStart();
739 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
740 return;
741
742 std::string CoverageMapping;
743 llvm::raw_string_ostream OS(CoverageMapping);
744 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
745 CGM.getContext().getSourceManager(),
746 CGM.getLangOpts());
747 MappingGen.emitEmptyMapping(D, OS);
748 OS.flush();
749
750 if (CoverageMapping.empty())
751 return;
752
753 CGM.getCoverageMapping()->addFunctionMappingRecord(
754 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
755 }
756
computeRegionCounts(const Decl * D)757 void CodeGenPGO::computeRegionCounts(const Decl *D) {
758 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
759 ComputeRegionCounts Walker(*StmtCountMap, *this);
760 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
761 Walker.VisitFunctionDecl(FD);
762 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
763 Walker.VisitObjCMethodDecl(MD);
764 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
765 Walker.VisitBlockDecl(BD);
766 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
767 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
768 }
769
770 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)771 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
772 llvm::Function *Fn) {
773 if (!haveRegionCounts())
774 return;
775
776 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
777 uint64_t FunctionCount = getRegionCount(0);
778 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
779 // Turn on InlineHint attribute for hot functions.
780 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
781 Fn->addFnAttr(llvm::Attribute::InlineHint);
782 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
783 // Turn on Cold attribute for cold functions.
784 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
785 Fn->addFnAttr(llvm::Attribute::Cold);
786 }
787
emitCounterIncrement(CGBuilderTy & Builder,unsigned Counter)788 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
789 if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
790 return;
791 if (!Builder.GetInsertPoint())
792 return;
793 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
794 Builder.CreateCall4(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
795 llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
796 Builder.getInt64(FunctionHash),
797 Builder.getInt32(NumRegionCounters),
798 Builder.getInt32(Counter));
799 }
800
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)801 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
802 bool IsInMainFile) {
803 CGM.getPGOStats().addVisited(IsInMainFile);
804 RegionCounts.clear();
805 if (std::error_code EC =
806 PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
807 if (EC == llvm::instrprof_error::unknown_function)
808 CGM.getPGOStats().addMissing(IsInMainFile);
809 else if (EC == llvm::instrprof_error::hash_mismatch)
810 CGM.getPGOStats().addMismatched(IsInMainFile);
811 else if (EC == llvm::instrprof_error::malformed)
812 // TODO: Consider a more specific warning for this case.
813 CGM.getPGOStats().addMismatched(IsInMainFile);
814 RegionCounts.clear();
815 }
816 }
817
/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, return a divisor such that MaxWeight / divisor is
/// strictly less than UINT32_MAX (so every weight still fits in 32 bits after
/// scaleBranchWeight() adds 1). Weights already below UINT32_MAX use 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
825
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() from a weight no
/// less than \c Weight (normally the maximum of all weights being scaled).
/// Otherwise the scaled value may not fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
841
createBranchWeights(uint64_t TrueCount,uint64_t FalseCount)842 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
843 uint64_t FalseCount) {
844 // Check for empty weights.
845 if (!TrueCount && !FalseCount)
846 return nullptr;
847
848 // Calculate how to scale down to 32-bits.
849 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
850
851 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
852 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
853 scaleBranchWeight(FalseCount, Scale));
854 }
855
createBranchWeights(ArrayRef<uint64_t> Weights)856 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
857 // We need at least two elements to create meaningful weights.
858 if (Weights.size() < 2)
859 return nullptr;
860
861 // Check for empty weights.
862 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
863 if (MaxWeight == 0)
864 return nullptr;
865
866 // Calculate how to scale down to 32-bits.
867 uint64_t Scale = calculateWeightScale(MaxWeight);
868
869 SmallVector<uint32_t, 16> ScaledWeights;
870 ScaledWeights.reserve(Weights.size());
871 for (uint64_t W : Weights)
872 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
873
874 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
875 return MDHelper.createBranchWeights(ScaledWeights);
876 }
877
createLoopWeights(const Stmt * Cond,RegionCounter & Cnt)878 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
879 RegionCounter &Cnt) {
880 if (!haveRegionCounts())
881 return nullptr;
882 uint64_t LoopCount = Cnt.getCount();
883 uint64_t CondCount = 0;
884 bool Found = getStmtCount(Cond, CondCount);
885 assert(Found && "missing expected loop condition count");
886 (void)Found;
887 if (CondCount == 0)
888 return nullptr;
889 return createBranchWeights(LoopCount,
890 std::max(CondCount, LoopCount) - LoopCount);
891 }
892