xref: /llvm-project/llvm/lib/Transforms/Instrumentation/LowerAllowCheckPass.cpp (revision fa9ac62d02fd6b5028f301ee398c3d3a1c0eacae)
1 //===- LowerAllowCheckPass.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Instrumentation/LowerAllowCheckPass.h"
10 
11 #include "llvm/ADT/SmallVector.h"
12 #include "llvm/ADT/Statistic.h"
13 #include "llvm/ADT/StringRef.h"
14 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
15 #include "llvm/Analysis/ProfileSummaryInfo.h"
16 #include "llvm/IR/Constants.h"
17 #include "llvm/IR/DiagnosticInfo.h"
18 #include "llvm/IR/Instructions.h"
19 #include "llvm/IR/IntrinsicInst.h"
20 #include "llvm/IR/Intrinsics.h"
21 #include "llvm/IR/Metadata.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/RandomNumberGenerator.h"
25 #include <memory>
26 #include <random>
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "lower-allow-check"
31 
32 static cl::opt<int>
33     HotPercentileCutoff("lower-allow-check-percentile-cutoff-hot",
34                         cl::desc("Hot percentile cutoff."));
35 
36 static cl::opt<float>
37     RandomRate("lower-allow-check-random-rate",
38                cl::desc("Probability value in the range [0.0, 1.0] of "
39                         "unconditional pseudo-random checks."));
40 
41 STATISTIC(NumChecksTotal, "Number of checks");
42 STATISTIC(NumChecksRemoved, "Number of removed checks");
43 
44 struct RemarkInfo {
45   ore::NV Kind;
46   ore::NV F;
47   ore::NV BB;
48   explicit RemarkInfo(IntrinsicInst *II)
49       : Kind("Kind", II->getArgOperand(0)),
50         F("Function", II->getParent()->getParent()),
51         BB("Block", II->getParent()->getName()) {}
52 };
53 
54 static void emitRemark(IntrinsicInst *II, OptimizationRemarkEmitter &ORE,
55                        bool Removed) {
56   if (Removed) {
57     ORE.emit([&]() {
58       RemarkInfo Info(II);
59       return OptimizationRemark(DEBUG_TYPE, "Removed", II)
60              << "Removed check: Kind=" << Info.Kind << " F=" << Info.F
61              << " BB=" << Info.BB;
62     });
63   } else {
64     ORE.emit([&]() {
65       RemarkInfo Info(II);
66       return OptimizationRemarkMissed(DEBUG_TYPE, "Allowed", II)
67              << "Allowed check: Kind=" << Info.Kind << " F=" << Info.F
68              << " BB=" << Info.BB;
69     });
70   }
71 }
72 
73 static bool removeUbsanTraps(Function &F, const BlockFrequencyInfo &BFI,
74                              const ProfileSummaryInfo *PSI,
75                              OptimizationRemarkEmitter &ORE,
76                              const std::vector<unsigned int> &cutoffs) {
77   SmallVector<std::pair<IntrinsicInst *, bool>, 16> ReplaceWithValue;
78   std::unique_ptr<RandomNumberGenerator> Rng;
79 
80   auto GetRng = [&]() -> RandomNumberGenerator & {
81     if (!Rng)
82       Rng = F.getParent()->createRNG(F.getName());
83     return *Rng;
84   };
85 
86   auto GetCutoff = [&](const IntrinsicInst *II) -> unsigned {
87     if (HotPercentileCutoff.getNumOccurrences())
88       return HotPercentileCutoff;
89     else if (II->getIntrinsicID() == Intrinsic::allow_ubsan_check) {
90       auto *Kind = cast<ConstantInt>(II->getArgOperand(0));
91       if (Kind->getZExtValue() < cutoffs.size())
92         return cutoffs[Kind->getZExtValue()];
93     }
94 
95     return 0;
96   };
97 
98   auto ShouldRemoveHot = [&](const BasicBlock &BB, unsigned int cutoff) {
99     return (cutoff == 1000000) ||
100            (PSI && PSI->isHotCountNthPercentile(
101                        cutoff, BFI.getBlockProfileCount(&BB).value_or(0)));
102   };
103 
104   auto ShouldRemoveRandom = [&]() {
105     return RandomRate.getNumOccurrences() &&
106            !std::bernoulli_distribution(RandomRate)(GetRng());
107   };
108 
109   auto ShouldRemove = [&](const IntrinsicInst *II) {
110     unsigned int cutoff = GetCutoff(II);
111     return ShouldRemoveRandom() || ShouldRemoveHot(*(II->getParent()), cutoff);
112   };
113 
114   for (BasicBlock &BB : F) {
115     for (Instruction &I : BB) {
116       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
117       if (!II)
118         continue;
119       auto ID = II->getIntrinsicID();
120       switch (ID) {
121       case Intrinsic::allow_ubsan_check:
122       case Intrinsic::allow_runtime_check: {
123         ++NumChecksTotal;
124 
125         bool ToRemove = ShouldRemove(II);
126 
127         ReplaceWithValue.push_back({
128             II,
129             ToRemove,
130         });
131         if (ToRemove)
132           ++NumChecksRemoved;
133         emitRemark(II, ORE, ToRemove);
134         break;
135       }
136       default:
137         break;
138       }
139     }
140   }
141 
142   for (auto [I, V] : ReplaceWithValue) {
143     I->replaceAllUsesWith(ConstantInt::getBool(I->getType(), !V));
144     I->eraseFromParent();
145   }
146 
147   return !ReplaceWithValue.empty();
148 }
149 
150 PreservedAnalyses LowerAllowCheckPass::run(Function &F,
151                                            FunctionAnalysisManager &AM) {
152   if (F.isDeclaration())
153     return PreservedAnalyses::all();
154   auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
155   ProfileSummaryInfo *PSI =
156       MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
157   BlockFrequencyInfo &BFI = AM.getResult<BlockFrequencyAnalysis>(F);
158   OptimizationRemarkEmitter &ORE =
159       AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
160 
161   return removeUbsanTraps(F, BFI, PSI, ORE, Opts.cutoffs)
162              ? PreservedAnalyses::none()
163              : PreservedAnalyses::all();
164 }
165 
166 bool LowerAllowCheckPass::IsRequested() {
167   return RandomRate.getNumOccurrences() ||
168          HotPercentileCutoff.getNumOccurrences();
169 }
170 
171 void LowerAllowCheckPass::printPipeline(
172     raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
173   static_cast<PassInfoMixin<LowerAllowCheckPass> *>(this)->printPipeline(
174       OS, MapClassName2PassName);
175   OS << "<";
176 
177   // Format is <cutoffs[0,1,2]=70000;cutoffs[5,6,8]=90000>
178   // but it's equally valid to specify
179   //   cutoffs[0]=70000;cutoffs[1]=70000;cutoffs[2]=70000;cutoffs[5]=90000;...
180   // and that's what we do here. It is verbose but valid and easy to verify
181   // correctness.
182   // TODO: print shorter output by combining adjacent runs, etc.
183   int i = 0;
184   for (unsigned int cutoff : Opts.cutoffs) {
185     if (cutoff > 0) {
186       if (i > 0)
187         OS << ";";
188       OS << "cutoffs[" << i << "]=" << cutoff;
189     }
190 
191     i++;
192   }
193   OS << '>';
194 }
195