1*82d56013Sjoerg //===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
2*82d56013Sjoerg //
3*82d56013Sjoerg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*82d56013Sjoerg // See https://llvm.org/LICENSE.txt for license information.
5*82d56013Sjoerg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*82d56013Sjoerg //
7*82d56013Sjoerg //===----------------------------------------------------------------------===//
8*82d56013Sjoerg //
9*82d56013Sjoerg // Adjust optimization to make the code more kernel verifier friendly.
10*82d56013Sjoerg //
11*82d56013Sjoerg //===----------------------------------------------------------------------===//
12*82d56013Sjoerg
13*82d56013Sjoerg #include "BPF.h"
14*82d56013Sjoerg #include "BPFCORE.h"
15*82d56013Sjoerg #include "BPFTargetMachine.h"
16*82d56013Sjoerg #include "llvm/IR/Instruction.h"
17*82d56013Sjoerg #include "llvm/IR/Instructions.h"
18*82d56013Sjoerg #include "llvm/IR/Module.h"
19*82d56013Sjoerg #include "llvm/IR/PatternMatch.h"
20*82d56013Sjoerg #include "llvm/IR/Type.h"
21*82d56013Sjoerg #include "llvm/IR/User.h"
22*82d56013Sjoerg #include "llvm/IR/Value.h"
23*82d56013Sjoerg #include "llvm/Pass.h"
24*82d56013Sjoerg #include "llvm/Transforms/Utils/BasicBlockUtils.h"
25*82d56013Sjoerg
26*82d56013Sjoerg #define DEBUG_TYPE "bpf-adjust-opt"
27*82d56013Sjoerg
28*82d56013Sjoerg using namespace llvm;
29*82d56013Sjoerg using namespace llvm::PatternMatch;
30*82d56013Sjoerg
31*82d56013Sjoerg static cl::opt<bool>
32*82d56013Sjoerg DisableBPFserializeICMP("bpf-disable-serialize-icmp", cl::Hidden,
33*82d56013Sjoerg cl::desc("BPF: Disable Serializing ICMP insns."),
34*82d56013Sjoerg cl::init(false));
35*82d56013Sjoerg
36*82d56013Sjoerg static cl::opt<bool> DisableBPFavoidSpeculation(
37*82d56013Sjoerg "bpf-disable-avoid-speculation", cl::Hidden,
38*82d56013Sjoerg cl::desc("BPF: Disable Avoiding Speculative Code Motion."),
39*82d56013Sjoerg cl::init(false));
40*82d56013Sjoerg
41*82d56013Sjoerg namespace {
42*82d56013Sjoerg
43*82d56013Sjoerg class BPFAdjustOpt final : public ModulePass {
44*82d56013Sjoerg public:
45*82d56013Sjoerg static char ID;
46*82d56013Sjoerg
BPFAdjustOpt()47*82d56013Sjoerg BPFAdjustOpt() : ModulePass(ID) {}
48*82d56013Sjoerg bool runOnModule(Module &M) override;
49*82d56013Sjoerg };
50*82d56013Sjoerg
51*82d56013Sjoerg class BPFAdjustOptImpl {
52*82d56013Sjoerg struct PassThroughInfo {
53*82d56013Sjoerg Instruction *Input;
54*82d56013Sjoerg Instruction *UsedInst;
55*82d56013Sjoerg uint32_t OpIdx;
PassThroughInfo__anon5a0d34ee0111::BPFAdjustOptImpl::PassThroughInfo56*82d56013Sjoerg PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
57*82d56013Sjoerg : Input(I), UsedInst(U), OpIdx(Idx) {}
58*82d56013Sjoerg };
59*82d56013Sjoerg
60*82d56013Sjoerg public:
BPFAdjustOptImpl(Module * M)61*82d56013Sjoerg BPFAdjustOptImpl(Module *M) : M(M) {}
62*82d56013Sjoerg
63*82d56013Sjoerg bool run();
64*82d56013Sjoerg
65*82d56013Sjoerg private:
66*82d56013Sjoerg Module *M;
67*82d56013Sjoerg SmallVector<PassThroughInfo, 16> PassThroughs;
68*82d56013Sjoerg
69*82d56013Sjoerg void adjustBasicBlock(BasicBlock &BB);
70*82d56013Sjoerg bool serializeICMPCrossBB(BasicBlock &BB);
71*82d56013Sjoerg void adjustInst(Instruction &I);
72*82d56013Sjoerg bool serializeICMPInBB(Instruction &I);
73*82d56013Sjoerg bool avoidSpeculation(Instruction &I);
74*82d56013Sjoerg bool insertPassThrough();
75*82d56013Sjoerg };
76*82d56013Sjoerg
77*82d56013Sjoerg } // End anonymous namespace
78*82d56013Sjoerg
79*82d56013Sjoerg char BPFAdjustOpt::ID = 0;
80*82d56013Sjoerg INITIALIZE_PASS(BPFAdjustOpt, "bpf-adjust-opt", "BPF Adjust Optimization",
81*82d56013Sjoerg false, false)
82*82d56013Sjoerg
createBPFAdjustOpt()83*82d56013Sjoerg ModulePass *llvm::createBPFAdjustOpt() { return new BPFAdjustOpt(); }
84*82d56013Sjoerg
runOnModule(Module & M)85*82d56013Sjoerg bool BPFAdjustOpt::runOnModule(Module &M) { return BPFAdjustOptImpl(&M).run(); }
86*82d56013Sjoerg
run()87*82d56013Sjoerg bool BPFAdjustOptImpl::run() {
88*82d56013Sjoerg for (Function &F : *M)
89*82d56013Sjoerg for (auto &BB : F) {
90*82d56013Sjoerg adjustBasicBlock(BB);
91*82d56013Sjoerg for (auto &I : BB)
92*82d56013Sjoerg adjustInst(I);
93*82d56013Sjoerg }
94*82d56013Sjoerg
95*82d56013Sjoerg return insertPassThrough();
96*82d56013Sjoerg }
97*82d56013Sjoerg
insertPassThrough()98*82d56013Sjoerg bool BPFAdjustOptImpl::insertPassThrough() {
99*82d56013Sjoerg for (auto &Info : PassThroughs) {
100*82d56013Sjoerg auto *CI = BPFCoreSharedInfo::insertPassThrough(
101*82d56013Sjoerg M, Info.UsedInst->getParent(), Info.Input, Info.UsedInst);
102*82d56013Sjoerg Info.UsedInst->setOperand(Info.OpIdx, CI);
103*82d56013Sjoerg }
104*82d56013Sjoerg
105*82d56013Sjoerg return !PassThroughs.empty();
106*82d56013Sjoerg }
107*82d56013Sjoerg
108*82d56013Sjoerg // To avoid combining conditionals in the same basic block by
109*82d56013Sjoerg // instrcombine optimization.
serializeICMPInBB(Instruction & I)110*82d56013Sjoerg bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) {
111*82d56013Sjoerg // For:
112*82d56013Sjoerg // comp1 = icmp <opcode> ...;
113*82d56013Sjoerg // comp2 = icmp <opcode> ...;
114*82d56013Sjoerg // ... or comp1 comp2 ...
115*82d56013Sjoerg // changed to:
116*82d56013Sjoerg // comp1 = icmp <opcode> ...;
117*82d56013Sjoerg // comp2 = icmp <opcode> ...;
118*82d56013Sjoerg // new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
119*82d56013Sjoerg // ... or new_comp1 comp2 ...
120*82d56013Sjoerg Value *Op0, *Op1;
121*82d56013Sjoerg // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`)
122*82d56013Sjoerg if (!match(&I, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
123*82d56013Sjoerg return false;
124*82d56013Sjoerg auto *Icmp1 = dyn_cast<ICmpInst>(Op0);
125*82d56013Sjoerg if (!Icmp1)
126*82d56013Sjoerg return false;
127*82d56013Sjoerg auto *Icmp2 = dyn_cast<ICmpInst>(Op1);
128*82d56013Sjoerg if (!Icmp2)
129*82d56013Sjoerg return false;
130*82d56013Sjoerg
131*82d56013Sjoerg Value *Icmp1Op0 = Icmp1->getOperand(0);
132*82d56013Sjoerg Value *Icmp2Op0 = Icmp2->getOperand(0);
133*82d56013Sjoerg if (Icmp1Op0 != Icmp2Op0)
134*82d56013Sjoerg return false;
135*82d56013Sjoerg
136*82d56013Sjoerg // Now we got two icmp instructions which feed into
137*82d56013Sjoerg // an "or" instruction.
138*82d56013Sjoerg PassThroughInfo Info(Icmp1, &I, 0);
139*82d56013Sjoerg PassThroughs.push_back(Info);
140*82d56013Sjoerg return true;
141*82d56013Sjoerg }
142*82d56013Sjoerg
143*82d56013Sjoerg // To avoid combining conditionals in the same basic block by
144*82d56013Sjoerg // instrcombine optimization.
serializeICMPCrossBB(BasicBlock & BB)145*82d56013Sjoerg bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) {
146*82d56013Sjoerg // For:
147*82d56013Sjoerg // B1:
148*82d56013Sjoerg // comp1 = icmp <opcode> ...;
149*82d56013Sjoerg // if (comp1) goto B2 else B3;
150*82d56013Sjoerg // B2:
151*82d56013Sjoerg // comp2 = icmp <opcode> ...;
152*82d56013Sjoerg // if (comp2) goto B4 else B5;
153*82d56013Sjoerg // B4:
154*82d56013Sjoerg // ...
155*82d56013Sjoerg // changed to:
156*82d56013Sjoerg // B1:
157*82d56013Sjoerg // comp1 = icmp <opcode> ...;
158*82d56013Sjoerg // comp1 = __builtin_bpf_passthrough(seq_num, comp1);
159*82d56013Sjoerg // if (comp1) goto B2 else B3;
160*82d56013Sjoerg // B2:
161*82d56013Sjoerg // comp2 = icmp <opcode> ...;
162*82d56013Sjoerg // if (comp2) goto B4 else B5;
163*82d56013Sjoerg // B4:
164*82d56013Sjoerg // ...
165*82d56013Sjoerg
166*82d56013Sjoerg // Check basic predecessors, if two of them (say B1, B2) are using
167*82d56013Sjoerg // icmp instructions to generate conditions and one is the predesessor
168*82d56013Sjoerg // of another (e.g., B1 is the predecessor of B2). Add a passthrough
169*82d56013Sjoerg // barrier after icmp inst of block B1.
170*82d56013Sjoerg BasicBlock *B2 = BB.getSinglePredecessor();
171*82d56013Sjoerg if (!B2)
172*82d56013Sjoerg return false;
173*82d56013Sjoerg
174*82d56013Sjoerg BasicBlock *B1 = B2->getSinglePredecessor();
175*82d56013Sjoerg if (!B1)
176*82d56013Sjoerg return false;
177*82d56013Sjoerg
178*82d56013Sjoerg Instruction *TI = B2->getTerminator();
179*82d56013Sjoerg auto *BI = dyn_cast<BranchInst>(TI);
180*82d56013Sjoerg if (!BI || !BI->isConditional())
181*82d56013Sjoerg return false;
182*82d56013Sjoerg auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
183*82d56013Sjoerg if (!Cond || B2->getFirstNonPHI() != Cond)
184*82d56013Sjoerg return false;
185*82d56013Sjoerg Value *B2Op0 = Cond->getOperand(0);
186*82d56013Sjoerg auto Cond2Op = Cond->getPredicate();
187*82d56013Sjoerg
188*82d56013Sjoerg TI = B1->getTerminator();
189*82d56013Sjoerg BI = dyn_cast<BranchInst>(TI);
190*82d56013Sjoerg if (!BI || !BI->isConditional())
191*82d56013Sjoerg return false;
192*82d56013Sjoerg Cond = dyn_cast<ICmpInst>(BI->getCondition());
193*82d56013Sjoerg if (!Cond)
194*82d56013Sjoerg return false;
195*82d56013Sjoerg Value *B1Op0 = Cond->getOperand(0);
196*82d56013Sjoerg auto Cond1Op = Cond->getPredicate();
197*82d56013Sjoerg
198*82d56013Sjoerg if (B1Op0 != B2Op0)
199*82d56013Sjoerg return false;
200*82d56013Sjoerg
201*82d56013Sjoerg if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
202*82d56013Sjoerg if (Cond2Op != ICmpInst::ICMP_SLT && Cond1Op != ICmpInst::ICMP_SLE)
203*82d56013Sjoerg return false;
204*82d56013Sjoerg } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
205*82d56013Sjoerg if (Cond2Op != ICmpInst::ICMP_SGT && Cond1Op != ICmpInst::ICMP_SGE)
206*82d56013Sjoerg return false;
207*82d56013Sjoerg } else {
208*82d56013Sjoerg return false;
209*82d56013Sjoerg }
210*82d56013Sjoerg
211*82d56013Sjoerg PassThroughInfo Info(Cond, BI, 0);
212*82d56013Sjoerg PassThroughs.push_back(Info);
213*82d56013Sjoerg
214*82d56013Sjoerg return true;
215*82d56013Sjoerg }
216*82d56013Sjoerg
217*82d56013Sjoerg // To avoid speculative hoisting certain computations out of
218*82d56013Sjoerg // a basic block.
avoidSpeculation(Instruction & I)219*82d56013Sjoerg bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) {
220*82d56013Sjoerg if (auto *LdInst = dyn_cast<LoadInst>(&I)) {
221*82d56013Sjoerg if (auto *GV = dyn_cast<GlobalVariable>(LdInst->getOperand(0))) {
222*82d56013Sjoerg if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
223*82d56013Sjoerg GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
224*82d56013Sjoerg return false;
225*82d56013Sjoerg }
226*82d56013Sjoerg }
227*82d56013Sjoerg
228*82d56013Sjoerg if (!isa<LoadInst>(&I) && !isa<CallInst>(&I))
229*82d56013Sjoerg return false;
230*82d56013Sjoerg
231*82d56013Sjoerg // For:
232*82d56013Sjoerg // B1:
233*82d56013Sjoerg // var = ...
234*82d56013Sjoerg // ...
235*82d56013Sjoerg // /* icmp may not be in the same block as var = ... */
236*82d56013Sjoerg // comp1 = icmp <opcode> var, <const>;
237*82d56013Sjoerg // if (comp1) goto B2 else B3;
238*82d56013Sjoerg // B2:
239*82d56013Sjoerg // ... var ...
240*82d56013Sjoerg // change to:
241*82d56013Sjoerg // B1:
242*82d56013Sjoerg // var = ...
243*82d56013Sjoerg // ...
244*82d56013Sjoerg // /* icmp may not be in the same block as var = ... */
245*82d56013Sjoerg // comp1 = icmp <opcode> var, <const>;
246*82d56013Sjoerg // if (comp1) goto B2 else B3;
247*82d56013Sjoerg // B2:
248*82d56013Sjoerg // var = __builtin_bpf_passthrough(seq_num, var);
249*82d56013Sjoerg // ... var ...
250*82d56013Sjoerg bool isCandidate = false;
251*82d56013Sjoerg SmallVector<PassThroughInfo, 4> Candidates;
252*82d56013Sjoerg for (User *U : I.users()) {
253*82d56013Sjoerg Instruction *Inst = dyn_cast<Instruction>(U);
254*82d56013Sjoerg if (!Inst)
255*82d56013Sjoerg continue;
256*82d56013Sjoerg
257*82d56013Sjoerg // May cover a little bit more than the
258*82d56013Sjoerg // above pattern.
259*82d56013Sjoerg if (auto *Icmp1 = dyn_cast<ICmpInst>(Inst)) {
260*82d56013Sjoerg Value *Icmp1Op1 = Icmp1->getOperand(1);
261*82d56013Sjoerg if (!isa<Constant>(Icmp1Op1))
262*82d56013Sjoerg return false;
263*82d56013Sjoerg isCandidate = true;
264*82d56013Sjoerg continue;
265*82d56013Sjoerg }
266*82d56013Sjoerg
267*82d56013Sjoerg // Ignore the use in the same basic block as the definition.
268*82d56013Sjoerg if (Inst->getParent() == I.getParent())
269*82d56013Sjoerg continue;
270*82d56013Sjoerg
271*82d56013Sjoerg // use in a different basic block, If there is a call or
272*82d56013Sjoerg // load/store insn before this instruction in this basic
273*82d56013Sjoerg // block. Most likely it cannot be hoisted out. Skip it.
274*82d56013Sjoerg for (auto &I2 : *Inst->getParent()) {
275*82d56013Sjoerg if (isa<CallInst>(&I2))
276*82d56013Sjoerg return false;
277*82d56013Sjoerg if (isa<LoadInst>(&I2) || isa<StoreInst>(&I2))
278*82d56013Sjoerg return false;
279*82d56013Sjoerg if (&I2 == Inst)
280*82d56013Sjoerg break;
281*82d56013Sjoerg }
282*82d56013Sjoerg
283*82d56013Sjoerg // It should be used in a GEP or a simple arithmetic like
284*82d56013Sjoerg // ZEXT/SEXT which is used for GEP.
285*82d56013Sjoerg if (Inst->getOpcode() == Instruction::ZExt ||
286*82d56013Sjoerg Inst->getOpcode() == Instruction::SExt) {
287*82d56013Sjoerg PassThroughInfo Info(&I, Inst, 0);
288*82d56013Sjoerg Candidates.push_back(Info);
289*82d56013Sjoerg } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
290*82d56013Sjoerg // traverse GEP inst to find Use operand index
291*82d56013Sjoerg unsigned i, e;
292*82d56013Sjoerg for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
293*82d56013Sjoerg Value *V = GI->getOperand(i);
294*82d56013Sjoerg if (V == &I)
295*82d56013Sjoerg break;
296*82d56013Sjoerg }
297*82d56013Sjoerg if (i == e)
298*82d56013Sjoerg continue;
299*82d56013Sjoerg
300*82d56013Sjoerg PassThroughInfo Info(&I, GI, i);
301*82d56013Sjoerg Candidates.push_back(Info);
302*82d56013Sjoerg }
303*82d56013Sjoerg }
304*82d56013Sjoerg
305*82d56013Sjoerg if (!isCandidate || Candidates.empty())
306*82d56013Sjoerg return false;
307*82d56013Sjoerg
308*82d56013Sjoerg llvm::append_range(PassThroughs, Candidates);
309*82d56013Sjoerg return true;
310*82d56013Sjoerg }
311*82d56013Sjoerg
adjustBasicBlock(BasicBlock & BB)312*82d56013Sjoerg void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) {
313*82d56013Sjoerg if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB))
314*82d56013Sjoerg return;
315*82d56013Sjoerg }
316*82d56013Sjoerg
adjustInst(Instruction & I)317*82d56013Sjoerg void BPFAdjustOptImpl::adjustInst(Instruction &I) {
318*82d56013Sjoerg if (!DisableBPFserializeICMP && serializeICMPInBB(I))
319*82d56013Sjoerg return;
320*82d56013Sjoerg if (!DisableBPFavoidSpeculation && avoidSpeculation(I))
321*82d56013Sjoerg return;
322*82d56013Sjoerg }
323*82d56013Sjoerg
run(Module & M,ModuleAnalysisManager & AM)324*82d56013Sjoerg PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) {
325*82d56013Sjoerg return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none()
326*82d56013Sjoerg : PreservedAnalyses::all();
327*82d56013Sjoerg }
328