xref: /netbsd-src/external/apache2/llvm/dist/llvm/lib/Target/BPF/BPFAdjustOpt.cpp (revision 82d56013d7b633d116a93943de88e08335357a7c)
1*82d56013Sjoerg //===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
2*82d56013Sjoerg //
3*82d56013Sjoerg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*82d56013Sjoerg // See https://llvm.org/LICENSE.txt for license information.
5*82d56013Sjoerg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*82d56013Sjoerg //
7*82d56013Sjoerg //===----------------------------------------------------------------------===//
8*82d56013Sjoerg //
9*82d56013Sjoerg // Adjust optimization to make the code more kernel verifier friendly.
10*82d56013Sjoerg //
11*82d56013Sjoerg //===----------------------------------------------------------------------===//
12*82d56013Sjoerg 
13*82d56013Sjoerg #include "BPF.h"
14*82d56013Sjoerg #include "BPFCORE.h"
15*82d56013Sjoerg #include "BPFTargetMachine.h"
16*82d56013Sjoerg #include "llvm/IR/Instruction.h"
17*82d56013Sjoerg #include "llvm/IR/Instructions.h"
18*82d56013Sjoerg #include "llvm/IR/Module.h"
19*82d56013Sjoerg #include "llvm/IR/PatternMatch.h"
20*82d56013Sjoerg #include "llvm/IR/Type.h"
21*82d56013Sjoerg #include "llvm/IR/User.h"
22*82d56013Sjoerg #include "llvm/IR/Value.h"
23*82d56013Sjoerg #include "llvm/Pass.h"
24*82d56013Sjoerg #include "llvm/Transforms/Utils/BasicBlockUtils.h"
25*82d56013Sjoerg 
26*82d56013Sjoerg #define DEBUG_TYPE "bpf-adjust-opt"
27*82d56013Sjoerg 
28*82d56013Sjoerg using namespace llvm;
29*82d56013Sjoerg using namespace llvm::PatternMatch;
30*82d56013Sjoerg 
31*82d56013Sjoerg static cl::opt<bool>
32*82d56013Sjoerg     DisableBPFserializeICMP("bpf-disable-serialize-icmp", cl::Hidden,
33*82d56013Sjoerg                             cl::desc("BPF: Disable Serializing ICMP insns."),
34*82d56013Sjoerg                             cl::init(false));
35*82d56013Sjoerg 
36*82d56013Sjoerg static cl::opt<bool> DisableBPFavoidSpeculation(
37*82d56013Sjoerg     "bpf-disable-avoid-speculation", cl::Hidden,
38*82d56013Sjoerg     cl::desc("BPF: Disable Avoiding Speculative Code Motion."),
39*82d56013Sjoerg     cl::init(false));
40*82d56013Sjoerg 
41*82d56013Sjoerg namespace {
42*82d56013Sjoerg 
43*82d56013Sjoerg class BPFAdjustOpt final : public ModulePass {
44*82d56013Sjoerg public:
45*82d56013Sjoerg   static char ID;
46*82d56013Sjoerg 
BPFAdjustOpt()47*82d56013Sjoerg   BPFAdjustOpt() : ModulePass(ID) {}
48*82d56013Sjoerg   bool runOnModule(Module &M) override;
49*82d56013Sjoerg };
50*82d56013Sjoerg 
51*82d56013Sjoerg class BPFAdjustOptImpl {
52*82d56013Sjoerg   struct PassThroughInfo {
53*82d56013Sjoerg     Instruction *Input;
54*82d56013Sjoerg     Instruction *UsedInst;
55*82d56013Sjoerg     uint32_t OpIdx;
PassThroughInfo__anon5a0d34ee0111::BPFAdjustOptImpl::PassThroughInfo56*82d56013Sjoerg     PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
57*82d56013Sjoerg         : Input(I), UsedInst(U), OpIdx(Idx) {}
58*82d56013Sjoerg   };
59*82d56013Sjoerg 
60*82d56013Sjoerg public:
BPFAdjustOptImpl(Module * M)61*82d56013Sjoerg   BPFAdjustOptImpl(Module *M) : M(M) {}
62*82d56013Sjoerg 
63*82d56013Sjoerg   bool run();
64*82d56013Sjoerg 
65*82d56013Sjoerg private:
66*82d56013Sjoerg   Module *M;
67*82d56013Sjoerg   SmallVector<PassThroughInfo, 16> PassThroughs;
68*82d56013Sjoerg 
69*82d56013Sjoerg   void adjustBasicBlock(BasicBlock &BB);
70*82d56013Sjoerg   bool serializeICMPCrossBB(BasicBlock &BB);
71*82d56013Sjoerg   void adjustInst(Instruction &I);
72*82d56013Sjoerg   bool serializeICMPInBB(Instruction &I);
73*82d56013Sjoerg   bool avoidSpeculation(Instruction &I);
74*82d56013Sjoerg   bool insertPassThrough();
75*82d56013Sjoerg };
76*82d56013Sjoerg 
77*82d56013Sjoerg } // End anonymous namespace
78*82d56013Sjoerg 
79*82d56013Sjoerg char BPFAdjustOpt::ID = 0;
80*82d56013Sjoerg INITIALIZE_PASS(BPFAdjustOpt, "bpf-adjust-opt", "BPF Adjust Optimization",
81*82d56013Sjoerg                 false, false)
82*82d56013Sjoerg 
createBPFAdjustOpt()83*82d56013Sjoerg ModulePass *llvm::createBPFAdjustOpt() { return new BPFAdjustOpt(); }
84*82d56013Sjoerg 
runOnModule(Module & M)85*82d56013Sjoerg bool BPFAdjustOpt::runOnModule(Module &M) { return BPFAdjustOptImpl(&M).run(); }
86*82d56013Sjoerg 
run()87*82d56013Sjoerg bool BPFAdjustOptImpl::run() {
88*82d56013Sjoerg   for (Function &F : *M)
89*82d56013Sjoerg     for (auto &BB : F) {
90*82d56013Sjoerg       adjustBasicBlock(BB);
91*82d56013Sjoerg       for (auto &I : BB)
92*82d56013Sjoerg         adjustInst(I);
93*82d56013Sjoerg     }
94*82d56013Sjoerg 
95*82d56013Sjoerg   return insertPassThrough();
96*82d56013Sjoerg }
97*82d56013Sjoerg 
insertPassThrough()98*82d56013Sjoerg bool BPFAdjustOptImpl::insertPassThrough() {
99*82d56013Sjoerg   for (auto &Info : PassThroughs) {
100*82d56013Sjoerg     auto *CI = BPFCoreSharedInfo::insertPassThrough(
101*82d56013Sjoerg         M, Info.UsedInst->getParent(), Info.Input, Info.UsedInst);
102*82d56013Sjoerg     Info.UsedInst->setOperand(Info.OpIdx, CI);
103*82d56013Sjoerg   }
104*82d56013Sjoerg 
105*82d56013Sjoerg   return !PassThroughs.empty();
106*82d56013Sjoerg }
107*82d56013Sjoerg 
108*82d56013Sjoerg // To avoid combining conditionals in the same basic block by
109*82d56013Sjoerg // instrcombine optimization.
serializeICMPInBB(Instruction & I)110*82d56013Sjoerg bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) {
111*82d56013Sjoerg   // For:
112*82d56013Sjoerg   //   comp1 = icmp <opcode> ...;
113*82d56013Sjoerg   //   comp2 = icmp <opcode> ...;
114*82d56013Sjoerg   //   ... or comp1 comp2 ...
115*82d56013Sjoerg   // changed to:
116*82d56013Sjoerg   //   comp1 = icmp <opcode> ...;
117*82d56013Sjoerg   //   comp2 = icmp <opcode> ...;
118*82d56013Sjoerg   //   new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
119*82d56013Sjoerg   //   ... or new_comp1 comp2 ...
120*82d56013Sjoerg   Value *Op0, *Op1;
121*82d56013Sjoerg   // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`)
122*82d56013Sjoerg   if (!match(&I, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
123*82d56013Sjoerg     return false;
124*82d56013Sjoerg   auto *Icmp1 = dyn_cast<ICmpInst>(Op0);
125*82d56013Sjoerg   if (!Icmp1)
126*82d56013Sjoerg     return false;
127*82d56013Sjoerg   auto *Icmp2 = dyn_cast<ICmpInst>(Op1);
128*82d56013Sjoerg   if (!Icmp2)
129*82d56013Sjoerg     return false;
130*82d56013Sjoerg 
131*82d56013Sjoerg   Value *Icmp1Op0 = Icmp1->getOperand(0);
132*82d56013Sjoerg   Value *Icmp2Op0 = Icmp2->getOperand(0);
133*82d56013Sjoerg   if (Icmp1Op0 != Icmp2Op0)
134*82d56013Sjoerg     return false;
135*82d56013Sjoerg 
136*82d56013Sjoerg   // Now we got two icmp instructions which feed into
137*82d56013Sjoerg   // an "or" instruction.
138*82d56013Sjoerg   PassThroughInfo Info(Icmp1, &I, 0);
139*82d56013Sjoerg   PassThroughs.push_back(Info);
140*82d56013Sjoerg   return true;
141*82d56013Sjoerg }
142*82d56013Sjoerg 
143*82d56013Sjoerg // To avoid combining conditionals in the same basic block by
144*82d56013Sjoerg // instrcombine optimization.
serializeICMPCrossBB(BasicBlock & BB)145*82d56013Sjoerg bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) {
146*82d56013Sjoerg   // For:
147*82d56013Sjoerg   //   B1:
148*82d56013Sjoerg   //     comp1 = icmp <opcode> ...;
149*82d56013Sjoerg   //     if (comp1) goto B2 else B3;
150*82d56013Sjoerg   //   B2:
151*82d56013Sjoerg   //     comp2 = icmp <opcode> ...;
152*82d56013Sjoerg   //     if (comp2) goto B4 else B5;
153*82d56013Sjoerg   //   B4:
154*82d56013Sjoerg   //     ...
155*82d56013Sjoerg   // changed to:
156*82d56013Sjoerg   //   B1:
157*82d56013Sjoerg   //     comp1 = icmp <opcode> ...;
158*82d56013Sjoerg   //     comp1 = __builtin_bpf_passthrough(seq_num, comp1);
159*82d56013Sjoerg   //     if (comp1) goto B2 else B3;
160*82d56013Sjoerg   //   B2:
161*82d56013Sjoerg   //     comp2 = icmp <opcode> ...;
162*82d56013Sjoerg   //     if (comp2) goto B4 else B5;
163*82d56013Sjoerg   //   B4:
164*82d56013Sjoerg   //     ...
165*82d56013Sjoerg 
166*82d56013Sjoerg   // Check basic predecessors, if two of them (say B1, B2) are using
167*82d56013Sjoerg   // icmp instructions to generate conditions and one is the predesessor
168*82d56013Sjoerg   // of another (e.g., B1 is the predecessor of B2). Add a passthrough
169*82d56013Sjoerg   // barrier after icmp inst of block B1.
170*82d56013Sjoerg   BasicBlock *B2 = BB.getSinglePredecessor();
171*82d56013Sjoerg   if (!B2)
172*82d56013Sjoerg     return false;
173*82d56013Sjoerg 
174*82d56013Sjoerg   BasicBlock *B1 = B2->getSinglePredecessor();
175*82d56013Sjoerg   if (!B1)
176*82d56013Sjoerg     return false;
177*82d56013Sjoerg 
178*82d56013Sjoerg   Instruction *TI = B2->getTerminator();
179*82d56013Sjoerg   auto *BI = dyn_cast<BranchInst>(TI);
180*82d56013Sjoerg   if (!BI || !BI->isConditional())
181*82d56013Sjoerg     return false;
182*82d56013Sjoerg   auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
183*82d56013Sjoerg   if (!Cond || B2->getFirstNonPHI() != Cond)
184*82d56013Sjoerg     return false;
185*82d56013Sjoerg   Value *B2Op0 = Cond->getOperand(0);
186*82d56013Sjoerg   auto Cond2Op = Cond->getPredicate();
187*82d56013Sjoerg 
188*82d56013Sjoerg   TI = B1->getTerminator();
189*82d56013Sjoerg   BI = dyn_cast<BranchInst>(TI);
190*82d56013Sjoerg   if (!BI || !BI->isConditional())
191*82d56013Sjoerg     return false;
192*82d56013Sjoerg   Cond = dyn_cast<ICmpInst>(BI->getCondition());
193*82d56013Sjoerg   if (!Cond)
194*82d56013Sjoerg     return false;
195*82d56013Sjoerg   Value *B1Op0 = Cond->getOperand(0);
196*82d56013Sjoerg   auto Cond1Op = Cond->getPredicate();
197*82d56013Sjoerg 
198*82d56013Sjoerg   if (B1Op0 != B2Op0)
199*82d56013Sjoerg     return false;
200*82d56013Sjoerg 
201*82d56013Sjoerg   if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
202*82d56013Sjoerg     if (Cond2Op != ICmpInst::ICMP_SLT && Cond1Op != ICmpInst::ICMP_SLE)
203*82d56013Sjoerg       return false;
204*82d56013Sjoerg   } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
205*82d56013Sjoerg     if (Cond2Op != ICmpInst::ICMP_SGT && Cond1Op != ICmpInst::ICMP_SGE)
206*82d56013Sjoerg       return false;
207*82d56013Sjoerg   } else {
208*82d56013Sjoerg     return false;
209*82d56013Sjoerg   }
210*82d56013Sjoerg 
211*82d56013Sjoerg   PassThroughInfo Info(Cond, BI, 0);
212*82d56013Sjoerg   PassThroughs.push_back(Info);
213*82d56013Sjoerg 
214*82d56013Sjoerg   return true;
215*82d56013Sjoerg }
216*82d56013Sjoerg 
217*82d56013Sjoerg // To avoid speculative hoisting certain computations out of
218*82d56013Sjoerg // a basic block.
avoidSpeculation(Instruction & I)219*82d56013Sjoerg bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) {
220*82d56013Sjoerg   if (auto *LdInst = dyn_cast<LoadInst>(&I)) {
221*82d56013Sjoerg     if (auto *GV = dyn_cast<GlobalVariable>(LdInst->getOperand(0))) {
222*82d56013Sjoerg       if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
223*82d56013Sjoerg           GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
224*82d56013Sjoerg         return false;
225*82d56013Sjoerg     }
226*82d56013Sjoerg   }
227*82d56013Sjoerg 
228*82d56013Sjoerg   if (!isa<LoadInst>(&I) && !isa<CallInst>(&I))
229*82d56013Sjoerg     return false;
230*82d56013Sjoerg 
231*82d56013Sjoerg   // For:
232*82d56013Sjoerg   //   B1:
233*82d56013Sjoerg   //     var = ...
234*82d56013Sjoerg   //     ...
235*82d56013Sjoerg   //     /* icmp may not be in the same block as var = ... */
236*82d56013Sjoerg   //     comp1 = icmp <opcode> var, <const>;
237*82d56013Sjoerg   //     if (comp1) goto B2 else B3;
238*82d56013Sjoerg   //   B2:
239*82d56013Sjoerg   //     ... var ...
240*82d56013Sjoerg   // change to:
241*82d56013Sjoerg   //   B1:
242*82d56013Sjoerg   //     var = ...
243*82d56013Sjoerg   //     ...
244*82d56013Sjoerg   //     /* icmp may not be in the same block as var = ... */
245*82d56013Sjoerg   //     comp1 = icmp <opcode> var, <const>;
246*82d56013Sjoerg   //     if (comp1) goto B2 else B3;
247*82d56013Sjoerg   //   B2:
248*82d56013Sjoerg   //     var = __builtin_bpf_passthrough(seq_num, var);
249*82d56013Sjoerg   //     ... var ...
250*82d56013Sjoerg   bool isCandidate = false;
251*82d56013Sjoerg   SmallVector<PassThroughInfo, 4> Candidates;
252*82d56013Sjoerg   for (User *U : I.users()) {
253*82d56013Sjoerg     Instruction *Inst = dyn_cast<Instruction>(U);
254*82d56013Sjoerg     if (!Inst)
255*82d56013Sjoerg       continue;
256*82d56013Sjoerg 
257*82d56013Sjoerg     // May cover a little bit more than the
258*82d56013Sjoerg     // above pattern.
259*82d56013Sjoerg     if (auto *Icmp1 = dyn_cast<ICmpInst>(Inst)) {
260*82d56013Sjoerg       Value *Icmp1Op1 = Icmp1->getOperand(1);
261*82d56013Sjoerg       if (!isa<Constant>(Icmp1Op1))
262*82d56013Sjoerg         return false;
263*82d56013Sjoerg       isCandidate = true;
264*82d56013Sjoerg       continue;
265*82d56013Sjoerg     }
266*82d56013Sjoerg 
267*82d56013Sjoerg     // Ignore the use in the same basic block as the definition.
268*82d56013Sjoerg     if (Inst->getParent() == I.getParent())
269*82d56013Sjoerg       continue;
270*82d56013Sjoerg 
271*82d56013Sjoerg     // use in a different basic block, If there is a call or
272*82d56013Sjoerg     // load/store insn before this instruction in this basic
273*82d56013Sjoerg     // block. Most likely it cannot be hoisted out. Skip it.
274*82d56013Sjoerg     for (auto &I2 : *Inst->getParent()) {
275*82d56013Sjoerg       if (isa<CallInst>(&I2))
276*82d56013Sjoerg         return false;
277*82d56013Sjoerg       if (isa<LoadInst>(&I2) || isa<StoreInst>(&I2))
278*82d56013Sjoerg         return false;
279*82d56013Sjoerg       if (&I2 == Inst)
280*82d56013Sjoerg         break;
281*82d56013Sjoerg     }
282*82d56013Sjoerg 
283*82d56013Sjoerg     // It should be used in a GEP or a simple arithmetic like
284*82d56013Sjoerg     // ZEXT/SEXT which is used for GEP.
285*82d56013Sjoerg     if (Inst->getOpcode() == Instruction::ZExt ||
286*82d56013Sjoerg         Inst->getOpcode() == Instruction::SExt) {
287*82d56013Sjoerg       PassThroughInfo Info(&I, Inst, 0);
288*82d56013Sjoerg       Candidates.push_back(Info);
289*82d56013Sjoerg     } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
290*82d56013Sjoerg       // traverse GEP inst to find Use operand index
291*82d56013Sjoerg       unsigned i, e;
292*82d56013Sjoerg       for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
293*82d56013Sjoerg         Value *V = GI->getOperand(i);
294*82d56013Sjoerg         if (V == &I)
295*82d56013Sjoerg           break;
296*82d56013Sjoerg       }
297*82d56013Sjoerg       if (i == e)
298*82d56013Sjoerg         continue;
299*82d56013Sjoerg 
300*82d56013Sjoerg       PassThroughInfo Info(&I, GI, i);
301*82d56013Sjoerg       Candidates.push_back(Info);
302*82d56013Sjoerg     }
303*82d56013Sjoerg   }
304*82d56013Sjoerg 
305*82d56013Sjoerg   if (!isCandidate || Candidates.empty())
306*82d56013Sjoerg     return false;
307*82d56013Sjoerg 
308*82d56013Sjoerg   llvm::append_range(PassThroughs, Candidates);
309*82d56013Sjoerg   return true;
310*82d56013Sjoerg }
311*82d56013Sjoerg 
adjustBasicBlock(BasicBlock & BB)312*82d56013Sjoerg void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) {
313*82d56013Sjoerg   if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB))
314*82d56013Sjoerg     return;
315*82d56013Sjoerg }
316*82d56013Sjoerg 
adjustInst(Instruction & I)317*82d56013Sjoerg void BPFAdjustOptImpl::adjustInst(Instruction &I) {
318*82d56013Sjoerg   if (!DisableBPFserializeICMP && serializeICMPInBB(I))
319*82d56013Sjoerg     return;
320*82d56013Sjoerg   if (!DisableBPFavoidSpeculation && avoidSpeculation(I))
321*82d56013Sjoerg     return;
322*82d56013Sjoerg }
323*82d56013Sjoerg 
run(Module & M,ModuleAnalysisManager & AM)324*82d56013Sjoerg PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) {
325*82d56013Sjoerg   return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none()
326*82d56013Sjoerg                                     : PreservedAnalyses::all();
327*82d56013Sjoerg }
328