1 //===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Adjust optimization to make the code more kernel verifier friendly.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "BPF.h"
14 #include "BPFCORE.h"
15 #include "BPFTargetMachine.h"
16 #include "llvm/IR/Instruction.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/IntrinsicsBPF.h"
19 #include "llvm/IR/Module.h"
20 #include "llvm/IR/PatternMatch.h"
21 #include "llvm/IR/Type.h"
22 #include "llvm/IR/User.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/Pass.h"
25 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
26
27 #define DEBUG_TYPE "bpf-adjust-opt"
28
29 using namespace llvm;
30 using namespace llvm::PatternMatch;
31
32 static cl::opt<bool>
33 DisableBPFserializeICMP("bpf-disable-serialize-icmp", cl::Hidden,
34 cl::desc("BPF: Disable Serializing ICMP insns."),
35 cl::init(false));
36
37 static cl::opt<bool> DisableBPFavoidSpeculation(
38 "bpf-disable-avoid-speculation", cl::Hidden,
39 cl::desc("BPF: Disable Avoiding Speculative Code Motion."),
40 cl::init(false));
41
42 namespace {
43
44 class BPFAdjustOpt final : public ModulePass {
45 public:
46 static char ID;
47
BPFAdjustOpt()48 BPFAdjustOpt() : ModulePass(ID) {}
49 bool runOnModule(Module &M) override;
50 };
51
52 class BPFAdjustOptImpl {
53 struct PassThroughInfo {
54 Instruction *Input;
55 Instruction *UsedInst;
56 uint32_t OpIdx;
PassThroughInfo__anon8b199e3a0111::BPFAdjustOptImpl::PassThroughInfo57 PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
58 : Input(I), UsedInst(U), OpIdx(Idx) {}
59 };
60
61 public:
BPFAdjustOptImpl(Module * M)62 BPFAdjustOptImpl(Module *M) : M(M) {}
63
64 bool run();
65
66 private:
67 Module *M;
68 SmallVector<PassThroughInfo, 16> PassThroughs;
69
70 bool adjustICmpToBuiltin();
71 void adjustBasicBlock(BasicBlock &BB);
72 bool serializeICMPCrossBB(BasicBlock &BB);
73 void adjustInst(Instruction &I);
74 bool serializeICMPInBB(Instruction &I);
75 bool avoidSpeculation(Instruction &I);
76 bool insertPassThrough();
77 };
78
79 } // End anonymous namespace
80
81 char BPFAdjustOpt::ID = 0;
82 INITIALIZE_PASS(BPFAdjustOpt, "bpf-adjust-opt", "BPF Adjust Optimization",
83 false, false)
84
createBPFAdjustOpt()85 ModulePass *llvm::createBPFAdjustOpt() { return new BPFAdjustOpt(); }
86
runOnModule(Module & M)87 bool BPFAdjustOpt::runOnModule(Module &M) { return BPFAdjustOptImpl(&M).run(); }
88
run()89 bool BPFAdjustOptImpl::run() {
90 bool Changed = adjustICmpToBuiltin();
91
92 for (Function &F : *M)
93 for (auto &BB : F) {
94 adjustBasicBlock(BB);
95 for (auto &I : BB)
96 adjustInst(I);
97 }
98 return insertPassThrough() || Changed;
99 }
100
// Commit acabad9ff6bf ("[InstCombine] try to canonicalize icmp with
// trunc op into mask and cmp") added a transformation to
// convert "(conv)a < power_2_const" to "a & <const>" in certain
// cases. The bpf kernel verifier has to handle the resulting code
// conservatively, which may reject an otherwise legitimate program.
// Here, we change the related icmp code to a builtin, which will
// be restored to the original icmp code later, to prevent that
// InstCombine transformation.
adjustICmpToBuiltin()109 bool BPFAdjustOptImpl::adjustICmpToBuiltin() {
110 bool Changed = false;
111 ICmpInst *ToBeDeleted = nullptr;
112 for (Function &F : *M)
113 for (auto &BB : F)
114 for (auto &I : BB) {
115 if (ToBeDeleted) {
116 ToBeDeleted->eraseFromParent();
117 ToBeDeleted = nullptr;
118 }
119
120 auto *Icmp = dyn_cast<ICmpInst>(&I);
121 if (!Icmp)
122 continue;
123
124 Value *Op0 = Icmp->getOperand(0);
125 if (!isa<TruncInst>(Op0))
126 continue;
127
128 auto ConstOp1 = dyn_cast<ConstantInt>(Icmp->getOperand(1));
129 if (!ConstOp1)
130 continue;
131
132 auto ConstOp1Val = ConstOp1->getValue().getZExtValue();
133 auto Op = Icmp->getPredicate();
134 if (Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_UGE) {
135 if ((ConstOp1Val - 1) & ConstOp1Val)
136 continue;
137 } else if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_UGT) {
138 if (ConstOp1Val & (ConstOp1Val + 1))
139 continue;
140 } else {
141 continue;
142 }
143
144 Constant *Opcode =
145 ConstantInt::get(Type::getInt32Ty(BB.getContext()), Op);
146 Function *Fn = Intrinsic::getDeclaration(
147 M, Intrinsic::bpf_compare, {Op0->getType(), ConstOp1->getType()});
148 auto *NewInst = CallInst::Create(Fn, {Opcode, Op0, ConstOp1});
149 NewInst->insertBefore(&I);
150 Icmp->replaceAllUsesWith(NewInst);
151 Changed = true;
152 ToBeDeleted = Icmp;
153 }
154
155 return Changed;
156 }
157
insertPassThrough()158 bool BPFAdjustOptImpl::insertPassThrough() {
159 for (auto &Info : PassThroughs) {
160 auto *CI = BPFCoreSharedInfo::insertPassThrough(
161 M, Info.UsedInst->getParent(), Info.Input, Info.UsedInst);
162 Info.UsedInst->setOperand(Info.OpIdx, CI);
163 }
164
165 return !PassThroughs.empty();
166 }
167
// To avoid combining conditionals in the same basic block by
// the InstCombine optimization.
serializeICMPInBB(Instruction & I)170 bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) {
171 // For:
172 // comp1 = icmp <opcode> ...;
173 // comp2 = icmp <opcode> ...;
174 // ... or comp1 comp2 ...
175 // changed to:
176 // comp1 = icmp <opcode> ...;
177 // comp2 = icmp <opcode> ...;
178 // new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
179 // ... or new_comp1 comp2 ...
180 Value *Op0, *Op1;
181 // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`)
182 if (!match(&I, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
183 return false;
184 auto *Icmp1 = dyn_cast<ICmpInst>(Op0);
185 if (!Icmp1)
186 return false;
187 auto *Icmp2 = dyn_cast<ICmpInst>(Op1);
188 if (!Icmp2)
189 return false;
190
191 Value *Icmp1Op0 = Icmp1->getOperand(0);
192 Value *Icmp2Op0 = Icmp2->getOperand(0);
193 if (Icmp1Op0 != Icmp2Op0)
194 return false;
195
196 // Now we got two icmp instructions which feed into
197 // an "or" instruction.
198 PassThroughInfo Info(Icmp1, &I, 0);
199 PassThroughs.push_back(Info);
200 return true;
201 }
202
// To avoid combining conditionals across basic blocks (a block and
// its single predecessor) by the InstCombine optimization.
serializeICMPCrossBB(BasicBlock & BB)205 bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) {
206 // For:
207 // B1:
208 // comp1 = icmp <opcode> ...;
209 // if (comp1) goto B2 else B3;
210 // B2:
211 // comp2 = icmp <opcode> ...;
212 // if (comp2) goto B4 else B5;
213 // B4:
214 // ...
215 // changed to:
216 // B1:
217 // comp1 = icmp <opcode> ...;
218 // comp1 = __builtin_bpf_passthrough(seq_num, comp1);
219 // if (comp1) goto B2 else B3;
220 // B2:
221 // comp2 = icmp <opcode> ...;
222 // if (comp2) goto B4 else B5;
223 // B4:
224 // ...
225
226 // Check basic predecessors, if two of them (say B1, B2) are using
227 // icmp instructions to generate conditions and one is the predesessor
228 // of another (e.g., B1 is the predecessor of B2). Add a passthrough
229 // barrier after icmp inst of block B1.
230 BasicBlock *B2 = BB.getSinglePredecessor();
231 if (!B2)
232 return false;
233
234 BasicBlock *B1 = B2->getSinglePredecessor();
235 if (!B1)
236 return false;
237
238 Instruction *TI = B2->getTerminator();
239 auto *BI = dyn_cast<BranchInst>(TI);
240 if (!BI || !BI->isConditional())
241 return false;
242 auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
243 if (!Cond || B2->getFirstNonPHI() != Cond)
244 return false;
245 Value *B2Op0 = Cond->getOperand(0);
246 auto Cond2Op = Cond->getPredicate();
247
248 TI = B1->getTerminator();
249 BI = dyn_cast<BranchInst>(TI);
250 if (!BI || !BI->isConditional())
251 return false;
252 Cond = dyn_cast<ICmpInst>(BI->getCondition());
253 if (!Cond)
254 return false;
255 Value *B1Op0 = Cond->getOperand(0);
256 auto Cond1Op = Cond->getPredicate();
257
258 if (B1Op0 != B2Op0)
259 return false;
260
261 if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
262 if (Cond2Op != ICmpInst::ICMP_SLT && Cond2Op != ICmpInst::ICMP_SLE)
263 return false;
264 } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
265 if (Cond2Op != ICmpInst::ICMP_SGT && Cond2Op != ICmpInst::ICMP_SGE)
266 return false;
267 } else if (Cond1Op == ICmpInst::ICMP_ULT || Cond1Op == ICmpInst::ICMP_ULE) {
268 if (Cond2Op != ICmpInst::ICMP_UGT && Cond2Op != ICmpInst::ICMP_UGE)
269 return false;
270 } else if (Cond1Op == ICmpInst::ICMP_UGT || Cond1Op == ICmpInst::ICMP_UGE) {
271 if (Cond2Op != ICmpInst::ICMP_ULT && Cond2Op != ICmpInst::ICMP_ULE)
272 return false;
273 } else {
274 return false;
275 }
276
277 PassThroughInfo Info(Cond, BI, 0);
278 PassThroughs.push_back(Info);
279
280 return true;
281 }
282
// To avoid speculatively hoisting certain computations out of
// a basic block.
avoidSpeculation(Instruction & I)285 bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) {
286 if (auto *LdInst = dyn_cast<LoadInst>(&I)) {
287 if (auto *GV = dyn_cast<GlobalVariable>(LdInst->getOperand(0))) {
288 if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
289 GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
290 return false;
291 }
292 }
293
294 if (!isa<LoadInst>(&I) && !isa<CallInst>(&I))
295 return false;
296
297 // For:
298 // B1:
299 // var = ...
300 // ...
301 // /* icmp may not be in the same block as var = ... */
302 // comp1 = icmp <opcode> var, <const>;
303 // if (comp1) goto B2 else B3;
304 // B2:
305 // ... var ...
306 // change to:
307 // B1:
308 // var = ...
309 // ...
310 // /* icmp may not be in the same block as var = ... */
311 // comp1 = icmp <opcode> var, <const>;
312 // if (comp1) goto B2 else B3;
313 // B2:
314 // var = __builtin_bpf_passthrough(seq_num, var);
315 // ... var ...
316 bool isCandidate = false;
317 SmallVector<PassThroughInfo, 4> Candidates;
318 for (User *U : I.users()) {
319 Instruction *Inst = dyn_cast<Instruction>(U);
320 if (!Inst)
321 continue;
322
323 // May cover a little bit more than the
324 // above pattern.
325 if (auto *Icmp1 = dyn_cast<ICmpInst>(Inst)) {
326 Value *Icmp1Op1 = Icmp1->getOperand(1);
327 if (!isa<Constant>(Icmp1Op1))
328 return false;
329 isCandidate = true;
330 continue;
331 }
332
333 // Ignore the use in the same basic block as the definition.
334 if (Inst->getParent() == I.getParent())
335 continue;
336
337 // use in a different basic block, If there is a call or
338 // load/store insn before this instruction in this basic
339 // block. Most likely it cannot be hoisted out. Skip it.
340 for (auto &I2 : *Inst->getParent()) {
341 if (isa<CallInst>(&I2))
342 return false;
343 if (isa<LoadInst>(&I2) || isa<StoreInst>(&I2))
344 return false;
345 if (&I2 == Inst)
346 break;
347 }
348
349 // It should be used in a GEP or a simple arithmetic like
350 // ZEXT/SEXT which is used for GEP.
351 if (Inst->getOpcode() == Instruction::ZExt ||
352 Inst->getOpcode() == Instruction::SExt) {
353 PassThroughInfo Info(&I, Inst, 0);
354 Candidates.push_back(Info);
355 } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
356 // traverse GEP inst to find Use operand index
357 unsigned i, e;
358 for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
359 Value *V = GI->getOperand(i);
360 if (V == &I)
361 break;
362 }
363 if (i == e)
364 continue;
365
366 PassThroughInfo Info(&I, GI, i);
367 Candidates.push_back(Info);
368 }
369 }
370
371 if (!isCandidate || Candidates.empty())
372 return false;
373
374 llvm::append_range(PassThroughs, Candidates);
375 return true;
376 }
377
adjustBasicBlock(BasicBlock & BB)378 void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) {
379 if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB))
380 return;
381 }
382
adjustInst(Instruction & I)383 void BPFAdjustOptImpl::adjustInst(Instruction &I) {
384 if (!DisableBPFserializeICMP && serializeICMPInBB(I))
385 return;
386 if (!DisableBPFavoidSpeculation && avoidSpeculation(I))
387 return;
388 }
389
run(Module & M,ModuleAnalysisManager & AM)390 PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) {
391 return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none()
392 : PreservedAnalyses::all();
393 }
394