xref: /llvm-project/llvm/lib/Transforms/Utils/MisExpect.cpp (revision 013f4a46d1978e370f940df3cbd04fb0399a04fe)
1bac6cd5bSPaul Kirth //===--- MisExpect.cpp - Check the use of llvm.expect with PGO data -------===//
2bac6cd5bSPaul Kirth //
3bac6cd5bSPaul Kirth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bac6cd5bSPaul Kirth // See https://llvm.org/LICENSE.txt for license information.
5bac6cd5bSPaul Kirth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bac6cd5bSPaul Kirth //
7bac6cd5bSPaul Kirth //===----------------------------------------------------------------------===//
8bac6cd5bSPaul Kirth //
// This contains code to emit warnings for potentially incorrect usage of the
// llvm.expect intrinsic. This utility extracts the threshold values from
// metadata associated with the instrumented Branch or Switch instruction. The
// threshold values are then used to determine if a warning should be emitted.
13bac6cd5bSPaul Kirth //
14bac6cd5bSPaul Kirth // MisExpect's implementation relies on two assumptions about how branch weights
15bac6cd5bSPaul Kirth // are managed in LLVM.
16bac6cd5bSPaul Kirth //
17bac6cd5bSPaul Kirth // 1) Frontend profiling weights are always in place before llvm.expect is
18bac6cd5bSPaul Kirth // lowered in LowerExpectIntrinsic.cpp. Frontend based instrumentation therefore
19bac6cd5bSPaul Kirth // needs to extract the branch weights and then compare them to the weights
20bac6cd5bSPaul Kirth // being added by the llvm.expect intrinsic lowering.
21bac6cd5bSPaul Kirth //
22bac6cd5bSPaul Kirth // 2) Sampling and IR based profiles will *only* have branch weight metadata
23bac6cd5bSPaul Kirth // before profiling data is consulted if they are from a lowered llvm.expect
24bac6cd5bSPaul Kirth // intrinsic. These profiles thus always extract the expected weights and then
25bac6cd5bSPaul Kirth // compare them to the weights collected during profiling to determine if a
26bac6cd5bSPaul Kirth // diagnostic message is warranted.
27bac6cd5bSPaul Kirth //
28bac6cd5bSPaul Kirth //===----------------------------------------------------------------------===//
29bac6cd5bSPaul Kirth 
30bac6cd5bSPaul Kirth #include "llvm/Transforms/Utils/MisExpect.h"
31bac6cd5bSPaul Kirth #include "llvm/ADT/Twine.h"
32bac6cd5bSPaul Kirth #include "llvm/Analysis/OptimizationRemarkEmitter.h"
33bac6cd5bSPaul Kirth #include "llvm/IR/DiagnosticInfo.h"
34bac6cd5bSPaul Kirth #include "llvm/IR/Instruction.h"
35bac6cd5bSPaul Kirth #include "llvm/IR/Instructions.h"
36bac6cd5bSPaul Kirth #include "llvm/IR/LLVMContext.h"
37d434e40fSPaul Kirth #include "llvm/IR/ProfDataUtils.h"
38bac6cd5bSPaul Kirth #include "llvm/Support/BranchProbability.h"
39bac6cd5bSPaul Kirth #include "llvm/Support/CommandLine.h"
40bac6cd5bSPaul Kirth #include "llvm/Support/FormatVariadic.h"
41f3a55a1dSJoe Loser #include <algorithm>
42bac6cd5bSPaul Kirth #include <cstdint>
43bac6cd5bSPaul Kirth #include <functional>
44bac6cd5bSPaul Kirth #include <numeric>
45bac6cd5bSPaul Kirth 
46bac6cd5bSPaul Kirth #define DEBUG_TYPE "misexpect"
47bac6cd5bSPaul Kirth 
48bac6cd5bSPaul Kirth using namespace llvm;
49bac6cd5bSPaul Kirth using namespace misexpect;
50bac6cd5bSPaul Kirth 
51bac6cd5bSPaul Kirth namespace llvm {
52bac6cd5bSPaul Kirth 
53bac6cd5bSPaul Kirth // Command line option to enable/disable the warning when profile data suggests
54bac6cd5bSPaul Kirth // a mismatch with the use of the llvm.expect intrinsic
55bac6cd5bSPaul Kirth static cl::opt<bool> PGOWarnMisExpect(
56bac6cd5bSPaul Kirth     "pgo-warn-misexpect", cl::init(false), cl::Hidden,
57bac6cd5bSPaul Kirth     cl::desc("Use this option to turn on/off "
58bac6cd5bSPaul Kirth              "warnings about incorrect usage of llvm.expect intrinsics."));
59bac6cd5bSPaul Kirth 
6089a080cbSPaul Kirth // Command line option for setting the diagnostic tolerance threshold
61656c5d65SPaul Kirth static cl::opt<uint32_t> MisExpectTolerance(
62bac6cd5bSPaul Kirth     "misexpect-tolerance", cl::init(0),
6389a080cbSPaul Kirth     cl::desc("Prevents emitting diagnostics when profile counts are "
64bac6cd5bSPaul Kirth              "within N% of the threshold.."));
65bac6cd5bSPaul Kirth 
66bac6cd5bSPaul Kirth } // namespace llvm
67bac6cd5bSPaul Kirth 
68bac6cd5bSPaul Kirth namespace {
69bac6cd5bSPaul Kirth 
70bac6cd5bSPaul Kirth bool isMisExpectDiagEnabled(LLVMContext &Ctx) {
71bac6cd5bSPaul Kirth   return PGOWarnMisExpect || Ctx.getMisExpectWarningRequested();
72bac6cd5bSPaul Kirth }
73bac6cd5bSPaul Kirth 
74f3a55a1dSJoe Loser uint32_t getMisExpectTolerance(LLVMContext &Ctx) {
75656c5d65SPaul Kirth   return std::max(static_cast<uint32_t>(MisExpectTolerance),
76bac6cd5bSPaul Kirth                   Ctx.getDiagnosticsMisExpectTolerance());
77bac6cd5bSPaul Kirth }
78bac6cd5bSPaul Kirth 
79bac6cd5bSPaul Kirth Instruction *getInstCondition(Instruction *I) {
80bac6cd5bSPaul Kirth   assert(I != nullptr && "MisExpect target Instruction cannot be nullptr");
81bac6cd5bSPaul Kirth   Instruction *Ret = nullptr;
82bac6cd5bSPaul Kirth   if (auto *B = dyn_cast<BranchInst>(I)) {
83bac6cd5bSPaul Kirth     Ret = dyn_cast<Instruction>(B->getCondition());
84bac6cd5bSPaul Kirth   }
85bac6cd5bSPaul Kirth   // TODO: Find a way to resolve condition location for switches
86bac6cd5bSPaul Kirth   // Using the condition of the switch seems to often resolve to an earlier
87bac6cd5bSPaul Kirth   // point in the program, i.e. the calculation of the switch condition, rather
88bac6cd5bSPaul Kirth   // than the switch's location in the source code. Thus, we should use the
89bac6cd5bSPaul Kirth   // instruction to get source code locations rather than the condition to
90bac6cd5bSPaul Kirth   // improve diagnostic output, such as the caret. If the same problem exists
91bac6cd5bSPaul Kirth   // for branch instructions, then we should remove this function and directly
92bac6cd5bSPaul Kirth   // use the instruction
93bac6cd5bSPaul Kirth   //
94bac6cd5bSPaul Kirth   else if (auto *S = dyn_cast<SwitchInst>(I)) {
95bac6cd5bSPaul Kirth     Ret = dyn_cast<Instruction>(S->getCondition());
96bac6cd5bSPaul Kirth   }
97bac6cd5bSPaul Kirth   return Ret ? Ret : I;
98bac6cd5bSPaul Kirth }
99bac6cd5bSPaul Kirth 
100bac6cd5bSPaul Kirth void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx,
101bac6cd5bSPaul Kirth                              uint64_t ProfCount, uint64_t TotalCount) {
102bac6cd5bSPaul Kirth   double PercentageCorrect = (double)ProfCount / TotalCount;
103bac6cd5bSPaul Kirth   auto PerString =
104bac6cd5bSPaul Kirth       formatv("{0:P} ({1} / {2})", PercentageCorrect, ProfCount, TotalCount);
105bac6cd5bSPaul Kirth   auto RemStr = formatv(
106bac6cd5bSPaul Kirth       "Potential performance regression from use of the llvm.expect intrinsic: "
107bac6cd5bSPaul Kirth       "Annotation was correct on {0} of profiled executions.",
108bac6cd5bSPaul Kirth       PerString);
109bac6cd5bSPaul Kirth   Twine Msg(PerString);
110bac6cd5bSPaul Kirth   Instruction *Cond = getInstCondition(I);
111bac6cd5bSPaul Kirth   if (isMisExpectDiagEnabled(Ctx))
112bac6cd5bSPaul Kirth     Ctx.diagnose(DiagnosticInfoMisExpect(Cond, Msg));
113bac6cd5bSPaul Kirth   OptimizationRemarkEmitter ORE(I->getParent()->getParent());
114bac6cd5bSPaul Kirth   ORE.emit(OptimizationRemark(DEBUG_TYPE, "misexpect", Cond) << RemStr.str());
115bac6cd5bSPaul Kirth }
116bac6cd5bSPaul Kirth 
117bac6cd5bSPaul Kirth } // namespace
118bac6cd5bSPaul Kirth 
119bac6cd5bSPaul Kirth namespace llvm {
120bac6cd5bSPaul Kirth namespace misexpect {
121bac6cd5bSPaul Kirth 
122bac6cd5bSPaul Kirth void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights,
123bac6cd5bSPaul Kirth                      ArrayRef<uint32_t> ExpectedWeights) {
124bac6cd5bSPaul Kirth   // To determine if we emit a diagnostic, we need to compare the branch weights
125bac6cd5bSPaul Kirth   // from the profile to those added by the llvm.expect intrinsic.
126bac6cd5bSPaul Kirth   // So first, we extract the "likely" and "unlikely" weights from
127bac6cd5bSPaul Kirth   // ExpectedWeights And determine the correct weight in the profile to compare
128bac6cd5bSPaul Kirth   // against.
129bac6cd5bSPaul Kirth   uint64_t LikelyBranchWeight = 0,
130bac6cd5bSPaul Kirth            UnlikelyBranchWeight = std::numeric_limits<uint32_t>::max();
131bac6cd5bSPaul Kirth   size_t MaxIndex = 0;
132bac6cd5bSPaul Kirth   for (size_t Idx = 0, End = ExpectedWeights.size(); Idx < End; Idx++) {
133bac6cd5bSPaul Kirth     uint32_t V = ExpectedWeights[Idx];
134bac6cd5bSPaul Kirth     if (LikelyBranchWeight < V) {
135bac6cd5bSPaul Kirth       LikelyBranchWeight = V;
136bac6cd5bSPaul Kirth       MaxIndex = Idx;
137bac6cd5bSPaul Kirth     }
138bac6cd5bSPaul Kirth     if (UnlikelyBranchWeight > V) {
139bac6cd5bSPaul Kirth       UnlikelyBranchWeight = V;
140bac6cd5bSPaul Kirth     }
141bac6cd5bSPaul Kirth   }
142bac6cd5bSPaul Kirth 
143bac6cd5bSPaul Kirth   const uint64_t ProfiledWeight = RealWeights[MaxIndex];
144bac6cd5bSPaul Kirth   const uint64_t RealWeightsTotal =
145bac6cd5bSPaul Kirth       std::accumulate(RealWeights.begin(), RealWeights.end(), (uint64_t)0,
146bac6cd5bSPaul Kirth                       std::plus<uint64_t>());
147bac6cd5bSPaul Kirth   const uint64_t NumUnlikelyTargets = RealWeights.size() - 1;
148bac6cd5bSPaul Kirth 
149bac6cd5bSPaul Kirth   uint64_t TotalBranchWeight =
150bac6cd5bSPaul Kirth       LikelyBranchWeight + (UnlikelyBranchWeight * NumUnlikelyTargets);
151bac6cd5bSPaul Kirth 
152*a8cabb6fSPaul Kirth   // Failing this assert means that we have corrupted metadata.
153*a8cabb6fSPaul Kirth   assert((TotalBranchWeight >= LikelyBranchWeight) && (TotalBranchWeight > 0) &&
154*a8cabb6fSPaul Kirth          "TotalBranchWeight is less than the Likely branch weight");
155bac6cd5bSPaul Kirth 
156bac6cd5bSPaul Kirth   // To determine our threshold value we need to obtain the branch probability
157bac6cd5bSPaul Kirth   // for the weights added by llvm.expect and use that proportion to calculate
158bac6cd5bSPaul Kirth   // our threshold based on the collected profile data.
159bac6cd5bSPaul Kirth   auto LikelyProbablilty = BranchProbability::getBranchProbability(
160bac6cd5bSPaul Kirth       LikelyBranchWeight, TotalBranchWeight);
161bac6cd5bSPaul Kirth 
162bac6cd5bSPaul Kirth   uint64_t ScaledThreshold = LikelyProbablilty.scale(RealWeightsTotal);
163bac6cd5bSPaul Kirth 
164bac6cd5bSPaul Kirth   // clamp tolerance range to [0, 100)
165bac6cd5bSPaul Kirth   auto Tolerance = getMisExpectTolerance(I.getContext());
166f3a55a1dSJoe Loser   Tolerance = std::clamp(Tolerance, 0u, 99u);
167bac6cd5bSPaul Kirth 
168bac6cd5bSPaul Kirth   // Allow users to relax checking by N%  i.e., if they use a 5% tolerance,
169bac6cd5bSPaul Kirth   // then we check against 0.95*ScaledThreshold
170bac6cd5bSPaul Kirth   if (Tolerance > 0)
171bac6cd5bSPaul Kirth     ScaledThreshold *= (1.0 - Tolerance / 100.0);
172bac6cd5bSPaul Kirth 
173bac6cd5bSPaul Kirth   // When the profile weight is below the threshold, we emit the diagnostic
174bac6cd5bSPaul Kirth   if (ProfiledWeight < ScaledThreshold)
175bac6cd5bSPaul Kirth     emitMisexpectDiagnostic(&I, I.getContext(), ProfiledWeight,
176bac6cd5bSPaul Kirth                             RealWeightsTotal);
177bac6cd5bSPaul Kirth }
178bac6cd5bSPaul Kirth 
179bac6cd5bSPaul Kirth void checkBackendInstrumentation(Instruction &I,
180bac6cd5bSPaul Kirth                                  const ArrayRef<uint32_t> RealWeights) {
181*a8cabb6fSPaul Kirth   // Backend checking assumes any existing weight comes from an `llvm.expect`
182*a8cabb6fSPaul Kirth   // intrinsic. However, SampleProfiling + ThinLTO add branch weights  multiple
183*a8cabb6fSPaul Kirth   // times, leading to an invalid assumption in our checking. Backend checks
184*a8cabb6fSPaul Kirth   // should only operate on branch weights that carry the "!expected" field,
185*a8cabb6fSPaul Kirth   // since they are guaranteed to be added by the LowerExpectIntrinsic pass.
186*a8cabb6fSPaul Kirth   if (!hasBranchWeightOrigin(I))
187*a8cabb6fSPaul Kirth     return;
188d434e40fSPaul Kirth   SmallVector<uint32_t> ExpectedWeights;
189d434e40fSPaul Kirth   if (!extractBranchWeights(I, ExpectedWeights))
190bac6cd5bSPaul Kirth     return;
191bac6cd5bSPaul Kirth   verifyMisExpect(I, RealWeights, ExpectedWeights);
192bac6cd5bSPaul Kirth }
193bac6cd5bSPaul Kirth 
194bac6cd5bSPaul Kirth void checkFrontendInstrumentation(Instruction &I,
195bac6cd5bSPaul Kirth                                   const ArrayRef<uint32_t> ExpectedWeights) {
196d434e40fSPaul Kirth   SmallVector<uint32_t> RealWeights;
197d434e40fSPaul Kirth   if (!extractBranchWeights(I, RealWeights))
198bac6cd5bSPaul Kirth     return;
199bac6cd5bSPaul Kirth   verifyMisExpect(I, RealWeights, ExpectedWeights);
200bac6cd5bSPaul Kirth }
201bac6cd5bSPaul Kirth 
202bac6cd5bSPaul Kirth void checkExpectAnnotations(Instruction &I,
203bac6cd5bSPaul Kirth                             const ArrayRef<uint32_t> ExistingWeights,
20427105e2fSSimon Pilgrim                             bool IsFrontend) {
20527105e2fSSimon Pilgrim   if (IsFrontend) {
206bac6cd5bSPaul Kirth     checkFrontendInstrumentation(I, ExistingWeights);
207bac6cd5bSPaul Kirth   } else {
208bac6cd5bSPaul Kirth     checkBackendInstrumentation(I, ExistingWeights);
209bac6cd5bSPaul Kirth   }
210bac6cd5bSPaul Kirth }
211bac6cd5bSPaul Kirth 
212bac6cd5bSPaul Kirth } // namespace misexpect
213bac6cd5bSPaul Kirth } // namespace llvm
214bac6cd5bSPaul Kirth #undef DEBUG_TYPE
215